

Loss function for a class-imbalanced binary classifier in TensorFlow


I am trying to apply deep learning to a binary classification problem with a high imbalance between the target classes (500K vs. 31K examples). I want to write a custom loss function along these lines:
minimize(100-((predicted_smallerclass)/(total_smallerclass))*100)
I would appreciate any pointers on how to build this logic.
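To make the target quantity concrete, here is a minimal sketch in plain Python. It assumes that predicted_smallerclass means the number of minority-class examples that were predicted correctly, and that the minority class is labelled 1; the function name and the sample data are made up for illustration.

```python
def minority_miss_rate(y_true, y_pred, minority_label=1):
    """Return 100 - (correctly predicted minority / total minority) * 100."""
    total_minority = sum(1 for t in y_true if t == minority_label)
    if total_minority == 0:
        raise ValueError("no minority-class examples in y_true")
    predicted_minority = sum(
        1 for t, p in zip(y_true, y_pred)
        if t == minority_label and p == minority_label
    )
    return 100.0 - (predicted_minority / total_minority) * 100.0


# Hypothetical labels and hard (thresholded) predictions.
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 0, 1, 1, 0, 1, 1, 0]
miss_rate = minority_miss_rate(y_true, y_pred)
print(miss_rate)
```

Because this quantity is built from counts of thresholded predictions, it has no useful gradient, which is one reason a weighted cross entropy is commonly used as a differentiable stand-in.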
You can add class weights to the loss function by multiplying the logits.
Regular cross entropy loss is this:
loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
= -x[class] + log(\sum_j exp(x[j]))
In the weighted case:
loss(x, class) = weights[class] * (-x[class] + log(\sum_j exp(x[j])))
So by multiplying logits, you are re-scaling predictions of each class by its class weight.
For example:
import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([ratio, 1.0 - ratio])
logits = ...  # shape [batch_size, 2]
# Scale each class's logit by its class weight (broadcast over the batch).
weighted_logits = tf.multiply(logits, class_weight)  # shape [batch_size, 2]
xent = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=weighted_logits, name="xent_raw")
For a complete solution, you can use this op in skflow: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
The code you proposed seems wrong to me.
The loss should be multiplied by the weight, I agree.
But if you multiply the logits by the class weights, you end up with:
-weights[class] * x[class] + log(\sum_j exp(weights[j] * x[j]))
The second term is not equal to:
weights[class] * log(\sum_j exp(x[j]))
To show this, we can rewrite the latter as:
log((\sum_j exp(x[j])) ^ weights[class])
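A quick numerical check of this point, as a sketch in plain Python rather than TensorFlow. The logits and class weights are made-up values for a single example; the only claim is that the two quantities come out different.

```python
import math


def softmax_xent(logits, cls):
    """Cross entropy for one example: -x[cls] + log(sum_j exp(x[j]))."""
    return -logits[cls] + math.log(sum(math.exp(x) for x in logits))


logits = [2.0, -1.0]    # hypothetical logits for one example
weights = [0.25, 0.75]  # hypothetical class weights
cls = 1                 # index of the true class

# Weighting the loss: weights[cls] * loss(x, cls).
weighted_loss = weights[cls] * softmax_xent(logits, cls)

# Scaling the logits first, then taking the ordinary loss.
scaled_logits = [w * x for w, x in zip(weights, logits)]
loss_of_scaled_logits = softmax_xent(scaled_logits, cls)

print(weighted_loss, loss_of_scaled_logits)
```

The two printed values do not match, so multiplying the logits is not the same operation as multiplying the loss.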
So here is the code I'm proposing:
ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([[ratio, 1.0 - ratio]])  # shape [1, 2]
logits = ...  # shape [batch_size, 2]
# Weight for each datapoint, depending on its label.
weight_per_label = tf.transpose(
    tf.matmul(labels, tf.transpose(class_weight)))  # shape [1, batch_size]
xent = tf.multiply(
    weight_per_label,
    tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xent_raw"))  # shape [1, batch_size]
loss = tf.reduce_mean(xent)  # scalar
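The same computation can be traced by hand on a tiny batch. This is only a sketch in plain Python, assuming one-hot float labels with class 0 as the majority class; the three logit rows are invented for illustration.

```python
import math


def softmax_xent(logits, one_hot):
    """Per-example softmax cross entropy against a one-hot label row."""
    log_norm = math.log(sum(math.exp(x) for x in logits))
    return sum(t * (log_norm - x) for t, x in zip(one_hot, logits))


ratio = 31.0 / (500.0 + 31.0)
class_weight = [ratio, 1.0 - ratio]  # [majority weight, minority weight]

# Hypothetical batch of three examples.
logits = [[3.0, -1.0], [0.5, 0.2], [-2.0, 1.5]]
labels = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]

# Dot product of each one-hot row with class_weight picks that row's weight,
# which is what the matmul in the TensorFlow snippet does.
weight_per_label = [
    sum(t * w for t, w in zip(row, class_weight)) for row in labels
]

# Scale each example's loss by its weight, then average over the batch.
xent = [
    w * softmax_xent(x, t)
    for w, x, t in zip(weight_per_label, logits, labels)
]
loss = sum(xent) / len(xent)
print(weight_per_label, loss)
```

Here the weight multiplies the finished per-example loss, so the logits passed to the cross entropy stay untouched.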
Use tf.nn.weighted_cross_entropy_with_logits() and set pos_weight to 1 / (expected ratio of positives).
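As a rough illustration of that suggestion, the sketch below works through the arithmetic in plain Python. It assumes the 31K class is the positive class (label 1) and follows the per-element formula given in the TensorFlow documentation for this op, labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits)); the helper name is hypothetical.

```python
import math


def weighted_sigmoid_xent(label, logit, pos_weight):
    """Per-element loss in the documented form of the TensorFlow op."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return (pos_weight * label * -math.log(p)
            + (1.0 - label) * -math.log(1.0 - p))


positives, negatives = 31.0, 500.0  # class counts from the question, in thousands
pos_weight = 1.0 / (positives / (positives + negatives))  # 531 / 31

# At logit 0 the model is undecided; a missed positive then costs
# pos_weight times as much as a missed negative.
loss_pos = weighted_sigmoid_xent(1.0, 0.0, pos_weight)
loss_neg = weighted_sigmoid_xent(0.0, 0.0, pos_weight)
print(pos_weight, loss_pos, loss_neg)
```

Note that this op is sigmoid-based, so it expects a single logit per example with a 0/1 label, unlike the two-column softmax setup in the earlier answers.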
I used tf.nn.weighted_cross_entropy_with_logits() for two classes:
classes_weights = tf.constant([0.1, 1.0])
cross_entropy = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=labels, pos_weight=classes_weights)
