Loss function for a class-imbalanced binary classifier in TensorFlow


I am trying to apply deep learning to a binary classification problem with a high class imbalance between the target classes (500K vs. 31K samples). I want to write a custom loss function along the lines of:
minimize(100 - (predicted_smallerclass / total_smallerclass) * 100)
I'd appreciate any pointers on how to build this logic.
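The quantity in the question is 100 minus the recall of the minority class, as a percentage. A minimal NumPy sketch, assuming the minority class is label 1 and that predictions come from thresholding a probability at 0.5 (both assumptions, and the helper names are hypothetical), shows why it cannot be minimized directly: counting hits is piecewise constant, so it gives no useful gradient. Replacing each hit by its predicted probability yields a smooth "soft recall" surrogate; that surrogate is a swapped-in illustration, not something proposed in this thread.

```python
import numpy as np

def minority_miss_rate(y_true, y_prob, threshold=0.5):
    """100 - 100 * recall of the minority class (assumed to be label 1), from hard predictions."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    is_minority = y_true == 1
    # predicted_smallerclass / total_smallerclass from the question
    recall = np.sum(y_pred[is_minority] == 1) / np.sum(is_minority)
    return 100.0 - 100.0 * recall

def soft_minority_miss_rate(y_true, y_prob):
    """Hypothetical differentiable surrogate: each minority sample counts as its predicted probability."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob, dtype=float)
    soft_recall = np.mean(y_prob[y_true == 1])
    return 100.0 - 100.0 * soft_recall

y_true = [1, 1, 0, 0, 1]
y_prob = [0.9, 0.2, 0.1, 0.6, 0.7]
print(minority_miss_rate(y_true, y_prob))       # about 33.33 (2 of 3 minority samples caught)
print(soft_minority_miss_rate(y_true, y_prob))  # about 40.0
```

On its own, the surrogate is minimized by predicting the minority class for every sample, so in practice it would have to be balanced against a term for the majority class.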
You can add class weights to the loss function by multiplying the logits.
The regular cross-entropy loss is:
loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
= -x[class] + log(\sum_j exp(x[j]))
In the weighted case:
loss(x, class) = weights[class] * (-x[class] + log(\sum_j exp(x[j])))
So by multiplying logits, you are re-scaling predictions of each class by its class weight.
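The formulas above can be checked numerically. This is a minimal NumPy sketch, with NumPy standing in for the TensorFlow ops and illustrative function names. It evaluates both forms of the unweighted loss and the weighted loss, where the class weight multiplies the whole per-example loss:

```python
import numpy as np

def cross_entropy(x, cls):
    """-log(softmax(x)[cls]), written as -x[cls] + log(sum_j exp(x[j]))."""
    x = np.asarray(x, dtype=float)
    m = np.max(x)  # subtract the max for a numerically stable log-sum-exp
    return -x[cls] + m + np.log(np.sum(np.exp(x - m)))

def weighted_cross_entropy(x, cls, weights):
    """weights[cls] * (-x[cls] + log(sum_j exp(x[j])))."""
    return weights[cls] * cross_entropy(x, cls)

x = np.array([2.0, -1.0])  # illustrative logits for one example
direct = -np.log(np.exp(x[0]) / np.sum(np.exp(x)))  # first form of the loss, class 0
print(np.isclose(direct, cross_entropy(x, 0)))  # True
```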
For example:
import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([ratio, 1.0 - ratio])
logits = ...  # shape [batch_size, 2]; labels: one-hot, same shape
weighted_logits = tf.multiply(logits, class_weight)  # shape [batch_size, 2] (tf.mul in pre-1.0 TF)
xent = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=weighted_logits, name="xent_raw")
For a complete solution, you can use this op in skflow: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
The code you proposed seems wrong to me.
The loss should be multiplied by the weight, I agree.
But if you multiply the logits by the class weights, you end up with:
weights[class] * -x[class] + log( \sum_j exp(weights[j] * x[j]) )
The second term is not equal to:
weights[class] * log(\sum_j exp(x[j]))
To see this, we can rewrite the latter as:
log( (\sum_j exp(x[j]))^weights[class] )
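A quick numeric check of this argument, as a minimal NumPy sketch with illustrative weights and logits (not values from the thread), compares weighting the loss against weighting the logits for a single example:

```python
import numpy as np

def cross_entropy(x, cls):
    """Softmax cross entropy for one example: -x[cls] + log(sum_j exp(x[j]))."""
    x = np.asarray(x, dtype=float)
    return -x[cls] + np.log(np.sum(np.exp(x)))

weights = np.array([0.1, 1.0])  # illustrative per-class weights
x = np.array([2.0, -1.0])       # illustrative logits for one example
cls = 0                         # true class of this example

loss_weighted = weights[cls] * cross_entropy(x, cls)  # weight the loss
loss_scaled_logits = cross_entropy(x * weights, cls)  # weight the logits
print(loss_weighted, loss_scaled_logits)  # the two values differ
```

Here the class-0 weight of 0.1 shrinks the weighted loss to about 0.005, whereas scaling the logits flattens the softmax and raises the loss to about 0.26, so the two approaches are not interchangeable.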
So here is the code I'm proposing:
import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([[ratio, 1.0 - ratio]])  # shape [1, 2]
logits = ...  # shape [batch_size, 2]
# labels: one-hot float tensor, shape [batch_size, 2]
# weight for each datapoint, depending on its label
weight_per_label = tf.transpose(
    tf.matmul(labels, tf.transpose(class_weight)))  # shape [1, batch_size]
xent = tf.multiply(
    weight_per_label,
    tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xent_raw"))  # shape [1, batch_size]
loss = tf.reduce_mean(xent)  # scalar
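To see what the matmul trick computes, here is a NumPy mirror of the graph above on a made-up batch of three examples (the labels and logits are illustrative). Because column 0 receives the small weight ratio (about 0.058), this weighting implicitly assumes that column 0 of the one-hot labels is the 500K majority class.

```python
import numpy as np

ratio = 31.0 / (500.0 + 31.0)
class_weight = np.array([[ratio, 1.0 - ratio]])  # shape [1, 2]

# Illustrative batch: one-hot labels and logits, shape [batch_size, 2]
labels = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
logits = np.array([[2.0, -1.0], [0.5, 0.3], [-1.0, 1.5]])

# labels @ class_weight.T picks each example's class weight, then transpose
weight_per_label = (labels @ class_weight.T).T  # shape [1, batch_size]

# Per-example softmax cross entropy (what softmax_cross_entropy_with_logits returns)
m = logits.max(axis=1, keepdims=True)
log_softmax = logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))
xent_raw = -(labels * log_softmax).sum(axis=1)  # shape [batch_size]

xent = weight_per_label * xent_raw  # broadcasts to shape [1, batch_size]
loss = xent.mean()
print(weight_per_label)  # [[ratio, 1 - ratio, 1 - ratio]]
```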
Use tf.nn.weighted_cross_entropy_with_logits() and set pos_weight to 1 / (expected ratio of positives).
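The TensorFlow documentation describes this op as labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits)), so pos_weight scales only the positive term. Below is a naive NumPy sketch of that formula. TensorFlow itself uses a numerically stable rearrangement. The helper name and sample values are illustrative, and treating the 31K class as the positives is an assumption.

```python
import numpy as np

def weighted_sigmoid_xent(labels, logits, pos_weight):
    """Naive NumPy version of the documented formula:
    labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits))
    """
    labels = np.asarray(labels, dtype=float)
    logits = np.asarray(logits, dtype=float)
    p = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
    return labels * -np.log(p) * pos_weight + (1.0 - labels) * -np.log(1.0 - p)

# Assumption: positives are the 31K class, negatives the 500K class
pos_ratio = 31.0 / (500.0 + 31.0)
pos_weight = 1.0 / pos_ratio  # about 17.1

labels = np.array([1.0, 0.0, 1.0, 0.0])
logits = np.array([0.3, -2.0, -1.0, 0.5])
print(weighted_sigmoid_xent(labels, logits, pos_weight))
```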
I did this with tf.nn.weighted_cross_entropy_with_logits() for two classes:
import tensorflow as tf
classes_weights = tf.constant([0.1, 1.0])  # broadcasts over the two one-hot label columns
cross_entropy = tf.nn.weighted_cross_entropy_with_logits(
    labels=labels, logits=logits, pos_weight=classes_weights)  # older TF 1.x: targets=labels