

Loss function for a class-imbalanced binary classifier in TensorFlow


I am trying to apply deep learning to a binary classification problem with a high class imbalance between the target classes (500K vs. 31K). I want to write a custom loss function that looks like this:
minimize(100-((predicted_smallerclass)/(total_smallerclass))*100)
I'd appreciate any pointers on how to build this logic.
You can add class weights to the loss function by multiplying the logits.
The regular cross-entropy loss is:
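One way to make the expression above trainable is to replace the hard count of correctly predicted minority examples with the total predicted probability of the minority class on those examples, a "soft recall". The sketch below is plain Python for readability; the function name, the input format (per-example probability lists) and treating class 1 as the minority class are my assumptions, not something stated in the question.

```python
def soft_minority_miss_rate(probs, labels, minority=1):
    """Differentiable stand-in for 100 - recall * 100 on the minority class.

    probs:    per-example class probabilities, e.g. softmax outputs [p0, p1]
    labels:   per-example integer class ids
    minority: id of the smaller class (assumed to be 1 here)
    Assumes the batch contains at least one minority example.
    """
    # Soft version of "predicted_smallerclass": instead of counting hard
    # predictions, add up the probability given to the minority class on
    # the examples that really belong to it.
    hit = sum(p[minority] for p, y in zip(probs, labels) if y == minority)
    # "total_smallerclass": number of true minority examples in the batch.
    total = sum(1 for y in labels if y == minority)
    return 100.0 - (hit / total) * 100.0


probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
labels = [0, 1, 1]
print(soft_minority_miss_rate(probs, labels))  # roughly 40
```

Optimizing this quantity alone has a trivial solution: predicting the minority class for every example gives 100% recall. That is why it is usually combined with a regular cross-entropy term, or replaced by the class-weighted losses discussed in the answers.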
loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
= -x[class] + log(\sum_j exp(x[j]))
In the weighted case:
loss(x, class) = weights[class] * (-x[class] + log(\sum_j exp(x[j])))
So, by multiplying the logits, you rescale the predictions of each class by its class weight.
For example:
import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([ratio, 1.0 - ratio])
logits = ... # shape [batch_size, 2]
labels = ... # one-hot labels, shape [batch_size, 2]
weighted_logits = tf.multiply(logits, class_weight) # shape [batch_size, 2]; tf.mul before TF 1.0
xent = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=weighted_logits, name="xent_raw")
For a complete solution, you can use this op in skflow: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
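A quick plain-Python check of the two cross-entropy formulas in this answer, assuming a single example with two logits (the values are made up): -log(softmax) equals the simplified -x[class] + log(\sum_j exp(x[j])), and the weighted version is that same loss multiplied by weights[class].

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum_j exp(x[j])).
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def xent(x, cls):
    # loss(x, class) = -x[class] + log(sum_j exp(x[j]))
    return -x[cls] + logsumexp(x)


def weighted_xent(x, cls, weights):
    # loss(x, class) = weights[class] * (-x[class] + log(sum_j exp(x[j])))
    return weights[cls] * xent(x, cls)


ratio = 31.0 / (500.0 + 31.0)
weights = [ratio, 1.0 - ratio]
x = [2.0, -1.0]  # illustrative logits for a single example

# The first form, -log(softmax), agrees with the simplified second form.
softmax_1 = math.exp(x[1]) / sum(math.exp(v) for v in x)
print(-math.log(softmax_1), xent(x, 1))
print(weighted_xent(x, 1, weights))
```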
The code you proposed seems wrong to me.
The loss should be multiplied by the weight, I agree.
But if you multiply the logits by the class weights, each logit x[j] is scaled by its own weight weights[j], and you end up with:
-weights[class] * x[class] + log( \sum_j exp(weights[j] * x[j]) )
The second term is not equal to:
weights[class] * log(\sum_j exp(x[j]))
To see this, we can rewrite the latter as:
log( (\sum_j exp(x[j]))^weights[class] )
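This objection can be checked numerically. The plain-Python sketch below, with made-up weights and logits for one example, computes the loss both ways: scaling the logits (the first answer's code) and scaling the loss (the weighted formula). It also checks that scaling the logits gives -weights[class] * x[class] + log(\sum_j exp(weights[j] * x[j])), and that weights[class] * log(\sum_j exp(x[j])) equals log((\sum_j exp(x[j]))^weights[class]).

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum_j exp(x[j])).
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def xent(x, cls):
    # Plain cross entropy: -x[class] + log(sum_j exp(x[j]))
    return -x[cls] + logsumexp(x)


weights = [0.1, 0.9]  # illustrative class weights (not from the thread)
x = [2.0, -1.0]       # illustrative logits for one example
cls = 1               # its true class

# First answer: scale each logit x[j] by weights[j], then apply cross entropy.
scaled_logits = [w * v for w, v in zip(weights, x)]
loss_scaled_logits = xent(scaled_logits, cls)

# Weighted formula: compute cross entropy, then scale the loss by weights[class].
loss_scaled_loss = weights[cls] * xent(x, cls)

print(loss_scaled_logits, loss_scaled_loss)  # roughly 1.39 vs 2.74
```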
So here is the code I'm proposing:
import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([[ratio, 1.0 - ratio]]) # shape [1, 2]
logits = ... # shape [batch_size, 2]
# labels: one-hot, float32 (needed for matmul), shape [batch_size, 2]
# this is the weight for each datapoint, depending on its label
weight_per_label = tf.transpose(
    tf.matmul(labels, tf.transpose(class_weight))) # shape [1, batch_size]
xent = tf.multiply(
    weight_per_label,
    tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xent_raw")) # shape [1, batch_size]
loss = tf.reduce_mean(xent) # scalar
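The same computation without TensorFlow, as a plain-Python sketch with a made-up batch of two examples. With one-hot labels, the matmul against the transposed class_weight simply picks out the weight of each example's true class, and the loss is the mean of the weighted per-example cross entropies. (In TensorFlow the same per-example weights can also be obtained with tf.reduce_sum(labels * class_weight, axis=1), which returns shape [batch_size] directly.)

```python
import math


def softmax_xent(logits, onehot):
    # Per-example quantity computed by tf.nn.softmax_cross_entropy_with_logits:
    # -sum_j labels[j] * log_softmax(logits)[j]
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return -sum(y * (v - lse) for y, v in zip(onehot, logits))


def weighted_mean_xent(batch_logits, batch_labels, class_weight):
    losses = []
    for logits, onehot in zip(batch_logits, batch_labels):
        # One row of tf.matmul(labels, tf.transpose(class_weight)):
        # with one-hot labels this is the weight of the true class.
        w = sum(y * cw for y, cw in zip(onehot, class_weight))
        losses.append(w * softmax_xent(logits, onehot))
    # tf.reduce_mean over the batch.
    return sum(losses) / len(losses)


ratio = 31.0 / (500.0 + 31.0)
class_weight = [ratio, 1.0 - ratio]
batch_logits = [[2.0, -1.0], [0.5, 0.5]]  # illustrative logits
batch_labels = [[1.0, 0.0], [0.0, 1.0]]   # one-hot labels
print(weighted_mean_xent(batch_logits, batch_labels, class_weight))
```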
Use tf.nn.weighted_cross_entropy_with_logits() and set pos_weight to 1 / (expected ratio of positives).
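For reference, the TensorFlow documentation gives the per-element loss of tf.nn.weighted_cross_entropy_with_logits as targets * -log(sigmoid(logits)) * pos_weight + (1 - targets) * -log(1 - sigmoid(logits)). Below is a plain-Python sketch of that formula for a single logit, with pos_weight derived from the class counts in the question (31K vs. 500K), assuming the minority class is the positive one.

```python
import math


def weighted_sigmoid_xent(logit, target, pos_weight):
    # Naive form of the documented formula:
    # targets * -log(sigmoid(x)) * pos_weight + (1 - targets) * -log(1 - sigmoid(x))
    p = 1.0 / (1.0 + math.exp(-logit))
    return target * -math.log(p) * pos_weight + (1 - target) * -math.log(1 - p)


n_pos, n_neg = 31000, 500000  # class counts from the question
expected_pos_ratio = n_pos / (n_pos + n_neg)
pos_weight = 1.0 / expected_pos_ratio  # about 17.1 for these counts

print(pos_weight)
print(weighted_sigmoid_xent(0.0, 1, pos_weight))  # positive example, up-weighted
print(weighted_sigmoid_xent(0.0, 0, pos_weight))  # negative example, unweighted
```

The sketch uses the naive formula for readability; it can overflow or hit log(0) for large-magnitude logits, whereas TensorFlow's implementation uses a numerically stable rearrangement.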
I used tf.nn.weighted_cross_entropy_with_logits() for two classes:
classes_weights = tf.constant([0.1, 1.0])
cross_entropy = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=labels, pos_weight=classes_weights)
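Here pos_weight is a vector with one entry per class. Because pos_weight is combined elementwise with the labels, a shape-[2] vector broadcasts over the class axis: each column of the one-hot labels is treated as its own sigmoid (not a softmax), and only the positive entries of column j are multiplied by classes_weights[j]. A plain-Python sketch of that behaviour, with made-up logits and labels:

```python
import math


def weighted_sigmoid_xent(logit, target, pos_weight):
    # targets * -log(sigmoid(x)) * pos_weight + (1 - targets) * -log(1 - sigmoid(x))
    p = 1.0 / (1.0 + math.exp(-logit))
    return target * -math.log(p) * pos_weight + (1 - target) * -math.log(1 - p)


classes_weights = [0.1, 1.0]
logits = [[1.0, -1.0], [-0.5, 0.5]]  # illustrative [batch_size, 2] logits
labels = [[1.0, 0.0], [0.0, 1.0]]    # one-hot labels

# Column j uses classes_weights[j]; each column is an independent sigmoid.
cross_entropy = [
    [weighted_sigmoid_xent(x, t, w) for x, t, w in zip(row_x, row_t, classes_weights)]
    for row_x, row_t in zip(logits, labels)
]
print(cross_entropy)
```

So with [0.1, 1.0], the positive term for class 0 is scaled down to a tenth, while the negative terms are left unweighted.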
