Customize Keras' loss function in a way that t

Posted 2020-03-18 06:48

I'm working on a multi-label classifier. I have many output labels [1, 0, 0, 1...] where 1 indicates that the input belongs to that label and 0 means otherwise.

In my case the loss function I use is based on MSE. I want to change it so that when the output label is -1, that label is replaced by the predicted probability for this label.

The attached images show what I mean. The scenario is: when the output label is -1, I want the MSE for that label to be zero.

This is the scenario: [image not preserved]

And in that case I want it to change to:

[image not preserved]

In that case the MSE for the second label (the middle output) will be zero. This is a special case where I don't want the classifier to learn anything about this label.
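The idea described above (often called masking the loss, i.e. ignoring missing labels) can be sketched in NumPy: every -1 target is swapped for the model's own prediction, so the squared error for that label is exactly zero. The function name `mse_ignore_minus_one` and the numbers are illustrative assumptions, not from the original post; in Keras the same substitution would be written with `K.switch(K.equal(y_true, -1), y_pred, y_true)`.

```python
import numpy as np


def mse_ignore_minus_one(y_true, y_pred):
    """Per-sample MSE in which targets equal to -1 are ignored.

    Each -1 target is replaced by the corresponding prediction, so its
    squared error is exactly zero. (Hypothetical helper for illustration.)
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Where the target is -1, use the prediction itself as the target.
    effective_target = np.where(y_true == -1, y_pred, y_true)
    # Mean over the label axis, one value per sample.
    return np.mean((effective_target - y_pred) ** 2, axis=-1)


y_true = np.array([[1.0, -1.0, 0.0]])
y_pred = np.array([[0.5, 0.9, 0.2]])
print(mse_ignore_minus_one(y_true, y_pred))
```

Here the result is ((0.5 - 1)² + 0 + (0.2 - 0)²) / 3 = 0.29 / 3: the middle label contributes nothing, although it still counts in the denominator of the mean.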

This feels like a commonly needed technique, and I doubt I'm the first to think of it, so first I'd like to know whether this way of training a neural net has a name, and second, how I can implement it.

I understand that I need to change something in the loss function, but I'm a real newbie to Theano and not sure how to find a specific value in a tensor or how to change the tensor's contents.
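With backend tensors you generally don't search for a value and overwrite it in place; you build a new tensor with an elementwise condition instead. The NumPy sketch below mirrors that pattern, assuming the analogy `K.equal` ~ `==`, `K.switch` ~ `np.where`, `K.cast` ~ `astype`. Both helper names are hypothetical and show two equivalent ways to zero out the -1 positions.

```python
import numpy as np

# NumPy stand-ins for the Keras backend calls (an assumed analogy):
#   K.equal(a, b)        ~ a == b            (elementwise boolean tensor)
#   K.switch(c, t, e)    ~ np.where(c, t, e) (elementwise selection)
#   K.cast(c, 'float32') ~ c.astype(np.float32)


def masked_squared_error_switch(y_true, y_pred):
    # Select 0 where the target is -1, otherwise the squared error.
    return np.where(y_true == -1, 0.0, (y_true - y_pred) ** 2)


def masked_squared_error_mask(y_true, y_pred):
    # Same result by multiplying with a 0/1 mask (0 where target is -1).
    mask = (y_true != -1).astype(y_pred.dtype)
    return mask * (y_true - y_pred) ** 2


y_true = np.array([[1.0, -1.0, 0.0], [0.0, 1.0, -1.0]])
y_pred = np.array([[0.5, 0.9, 0.2], [0.1, 0.6, 0.8]])
print(masked_squared_error_switch(y_true, y_pred))
print(masked_squared_error_mask(y_true, y_pred))
```

Neither version modifies `y_true` or `y_pred`; each returns a new array, which is also how the backend functions behave. When writing the `K.switch` version, passing a tensor such as `K.zeros_like(y_pred)` rather than a bare `0` is the safer choice, since the TensorFlow backend may expect both branches to be tensors of matching rank.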

1 answer
小情绪 Triste *
Post #2 · 2020-03-18 07:34

I believe this is what you're looking for.

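import numpy as np
from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential

def customized_loss(y_true, y_pred):
    # Make sure the targets have the same float dtype as the predictions.
    y_true = K.cast(y_true, K.floatx())
    # Squared error per label, but 0 wherever the target is -1, so those
    # labels contribute nothing to the loss. K.zeros_like keeps both
    # branches of K.switch as tensors of the same shape.
    loss = K.switch(K.equal(y_true, -1), K.zeros_like(y_pred), K.square(y_true - y_pred))
    # Sum over the labels of each sample (one loss value per sample).
    return K.sum(loss, axis=-1)

if __name__ == '__main__':
    model = Sequential([Dense(3, input_shape=(4,))])
    model.compile(loss=customized_loss, optimizer='sgd')

    x = np.random.random((1, 4))
    y = np.array([[1, -1, 0]], dtype='float32')

    output = model.predict(x)
    print(output)
    # e.g. [[ 0.47242549 -0.45106074  0.13912249]]  (weights are random)
    print(model.evaluate(x, y))  # Keras's loss
    # e.g. 0.297689884901
    print((output[0, 0] - 1)**2 + 0 + (output[0, 2] - 0)**2)  # double-check
    # e.g. 0.297689929093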
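
One design choice the loss above leaves open: it sums the remaining squared errors, so a sample with more -1 labels automatically gets a smaller loss. If you would rather average over only the labels that are present, divide by the number of non-ignored entries. The NumPy reference below is a sketch under that assumption (the name `masked_mean_squared_error`, the `ignore_value` parameter and the guard for all-ignored samples are my own additions); it can also serve as a check against `model.evaluate`.

```python
import numpy as np


def masked_mean_squared_error(y_true, y_pred, ignore_value=-1):
    """Per-sample MSE averaged only over labels whose target != ignore_value."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # 1.0 for labels that count, 0.0 for ignored ones.
    mask = (y_true != ignore_value).astype(float)
    squared = mask * (y_true - y_pred) ** 2
    # Number of labels that count per sample; at least 1 to avoid 0/0
    # when every label of a sample is ignored.
    n_valid = np.maximum(mask.sum(axis=-1), 1.0)
    return squared.sum(axis=-1) / n_valid


y_true = np.array([[1, -1, 0], [-1, -1, -1]])
y_pred = np.array([[0.5, 0.9, 0.2], [0.3, 0.3, 0.3]])
print(masked_mean_squared_error(y_true, y_pred))
```

For the first sample this gives (0.25 + 0.04) / 2 = 0.145, and 0 for the fully ignored second sample. A Keras version would follow the same steps with `K.cast(K.not_equal(y_true, -1), K.floatx())` for the mask and `K.maximum(K.sum(mask, axis=-1), 1.0)` for the denominator.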