How to implement a weighted SoftmaxOutput custom op in MXNet

Posted 2019-04-10 14:02

I want to replace mx.symbol.SoftmaxOutput with a weighted version that assigns each class a different weight according to its label frequency in the whole dataset.

The original function works well, as shown below:

# Built-in (unweighted) softmax loss head. Samples whose label is -1 are
# ignored, and with normalization='valid' the gradient is normalized by the
# number of non-ignored labels (see the MXNet SoftmaxOutput docs).
cls_prob = mx.symbol.SoftmaxOutput(
    data=data,
    label=label,
    name='cls_prob',
    multi_output=True,
    use_ignore=True,
    ignore_label=-1,
    normalization='valid',
)

My current code is below. It runs without errors, but the loss quickly explodes to nan. I am working on a detection problem, and the RCNN L1 loss quickly becomes nan when I use my code as a CustomOp. I also need to ignore label -1, and I am not sure how to do that properly. Any help will be greatly appreciated.

import mxnet as mx
import numpy as np

class WeightedSoftmaxCrossEntropyLoss(mx.operator.CustomOp):
    """Softmax output layer whose backward pass applies per-class loss weights.

    Forward emits class probabilities (softmax over axis 1 of ``data``).
    Backward writes the gradient of the class-weighted cross-entropy
    ``w[y] * -log p[y]`` with respect to ``data``. Samples whose label equals
    ``ignore_label`` contribute no gradient. The result is divided by the
    number of non-ignored labels, mirroring
    ``SoftmaxOutput(use_ignore=True, normalization='valid')``.
    """

    # Per-class weights indexed by class id, taken from the original code;
    # presumably derived from label frequencies in the dataset.
    DEFAULT_CLS_WEIGHT = (
        0.002852781814876101,
        0.30715984513157385,
        1.0932468996115976,
        1.1598757152765971,
        0.20739109264009636,
        1.1984256112776808,
        0.18746186040248036,
        2.9009928470737023,
        0.92140970338602113,
        1.200317380251021,
    )

    def __init__(self, num_class, cls_weight=None, ignore_label=-1):
        """Configure the op.

        Args:
            num_class: Number of classes ``C``. It is converted with ``int``
                because custom-op kwargs may arrive as strings.
            cls_weight: Optional sequence of ``num_class`` per-class weights.
                Defaults to ``DEFAULT_CLS_WEIGHT``.
            ignore_label: Label value whose samples get zero gradient.

        Raises:
            ValueError: If the weight vector length differs from ``num_class``.
        """
        super(WeightedSoftmaxCrossEntropyLoss, self).__init__()
        self.num_class = int(num_class)
        weights = self.DEFAULT_CLS_WEIGHT if cls_weight is None else cls_weight
        self.cls_weight = np.asarray(weights, dtype=np.float64).reshape(-1)
        if self.cls_weight.shape[0] != self.num_class:
            raise ValueError(
                'expected %d class weights, got %d'
                % (self.num_class, self.cls_weight.shape[0]))
        self.ignore_label = int(ignore_label)

    def forward(self, is_train, req, in_data, out_data, aux):
        """Write softmax(data, axis=1) to the output; the label is not needed here."""
        prob = mx.nd.softmax(in_data[0], axis=1)
        self.assign(out_data[0], req[0], prob)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Write the weighted, ignore-masked cross-entropy gradient w.r.t. data.

        Raises:
            ValueError: If the number of labels does not match the number of
                positions in ``data`` (all axes except the class axis 1).
        """
        prob = out_data[0].asnumpy()
        data_shape = prob.shape
        num_class = data_shape[1]

        # (N, C, d1, ...) -> (N * d1 * ..., C): one row per labelled position.
        flat_prob = prob.reshape((data_shape[0], num_class, -1)).transpose((0, 2, 1))
        spatial = flat_prob.shape[1]
        flat_prob = flat_prob.reshape((-1, num_class))

        label = in_data[1].asnumpy().reshape(-1).astype(np.int64)
        if label.shape[0] != flat_prob.shape[0]:
            raise ValueError(
                'label has %d entries but data has %d positions'
                % (label.shape[0], flat_prob.shape[0]))

        valid = label != self.ignore_label
        # Ignored rows get a dummy target 0 so indexing stays in range; their
        # weight is forced to zero below, so they contribute no gradient.
        safe_label = np.where(valid, label, 0)
        one_hot = np.zeros_like(flat_prob)
        one_hot[np.arange(label.shape[0]), safe_label] = 1.0

        # d(w_y * -log p_y)/d(logits) = w_y * (p - onehot(y)). Each row is
        # scaled by the weight of that sample's own label, not by per-column
        # weights.
        row_weight = self.cls_weight[safe_label] * valid
        # Normalize by the count of non-ignored labels ('valid'); guard
        # against an all-ignored batch.
        num_valid = max(int(valid.sum()), 1)
        grad = (flat_prob - one_hot) * (row_weight[:, None] / num_valid)

        # Undo the flattening so the gradient matches data's (N, C, d1, ...) layout.
        grad = (grad.reshape((data_shape[0], spatial, num_class))
                .transpose((0, 2, 1))
                .reshape(data_shape))
        self.assign(in_grad[0], req[0],
                    mx.nd.array(grad, ctx=in_grad[0].context, dtype=in_grad[0].dtype))

@mx.operator.register("weighted_softmax_ce_loss")
class WeightedSoftmaxCrossEntropyLossProp(mx.operator.CustomOpProp):
    """Registration and shape metadata for ``WeightedSoftmaxCrossEntropyLoss``.

    Registered under op type ``"weighted_softmax_ce_loss"``. It takes ``data``
    and ``label`` and produces a single ``output`` with the same shape as
    ``data``.
    """

    def __init__(self, num_class):
        # need_top_grad=False: this is a loss head, so backward does not
        # depend on the incoming output gradient.
        super(WeightedSoftmaxCrossEntropyLossProp, self).__init__(need_top_grad=False)
        self.num_class = num_class

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shapes):
        """Infer shapes: one label per position on every axis except axis 1.

        ``(N, C)`` yields label ``(N,)``; ``(N, C, d1, ...)`` yields label
        ``(N, d1, ...)``. The output has the same shape as the data.
        """
        data_shape = in_shapes[0]
        label_shape = (data_shape[0],) + tuple(data_shape[2:])
        output_shape = data_shape
        return [data_shape, label_shape], [output_shape], []

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Create and return the CustomOp instance."""
        return WeightedSoftmaxCrossEntropyLoss(self.num_class)

Tags: python mxnet
1 answer
乱世女痞
Post #2 · 2019-04-10 14:19

I am not sure a CustomOp is the best choice here, as it may be slow. SoftmaxOutput computes the gradient itself in the backward pass, so it is not convenient to scale the loss by per-class weights as you want to do. However, it is not too complicated with the symbolic API: compute the weighted cross-entropy explicitly, mask out the ignored labels, and wrap the result in MakeLoss. I have attached a toy example; I hope it helps.

import mxnet as mx
import numpy as np
import logging

# Toy problem: learn floor(x) for x drawn uniformly from [-1, -1 + num_classes).
# Samples with x in [-1, 0) get label -1 and play the role of ignored labels.
n = 10000
batch_size = 128
num_classes = 10
x = (np.random.random((n,)) * num_classes) - 1
y = np.floor(x)
print(x[:2])
print(y[:2])

# define graph
data = mx.symbol.Variable('data')
label = mx.symbol.Variable('label')
class_weights = mx.symbol.Variable('class_weights')
fc = mx.sym.FullyConnected(data=data, num_hidden=num_classes)
fc = mx.sym.Activation(data=fc, act_type='relu')
logits = mx.sym.FullyConnected(data=fc, num_hidden=num_classes)
# Probabilities are exposed as a (gradient-blocked) output for prediction.
proba = mx.sym.softmax(logits)
# Cross entropy is -log p[label], not -p[label]. log_softmax gives the
# log-probabilities in a numerically stable way (no log(0) -> -inf).
log_proba = mx.sym.log_softmax(logits)

# Weighted cross entropy: -w[label] * log p[label]. pick() uses
# mode='clip' by default, so label -1 picks a valid column; the mask below
# zeroes those entries.
cross_entropy = -mx.sym.pick(log_proba, label) * mx.sym.pick(class_weights, label)

# mask the loss to zero when label is -1
mask = mx.sym.broadcast_not_equal(label, mx.sym.ones_like(label) * -1)
cross_entropy = cross_entropy * mask

# fit module; class weights are fed per sample as an extra data input (w_k = k + 1)
class_weights = np.array([np.arange(1, 1 + num_classes)] * n)
data_iter = mx.io.NDArrayIter(data={'data': x, 'class_weights': class_weights}, label={'label': y}, batch_size=batch_size)
mod = mx.mod.Module(
    mx.sym.Group([mx.sym.MakeLoss(cross_entropy, name='ce_loss'), mx.sym.BlockGrad(proba)]),
    data_names=[v.name for v in data_iter.provide_data],
    label_names=[v.name for v in data_iter.provide_label]
)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
mod.bind(data_shapes=data_iter.provide_data, label_shapes=data_iter.provide_label)
mod.init_params()
mod.fit(
    data_iter,
    num_epoch=200,
    # MakeLoss sums per-sample gradients; rescale_grad turns that into a batch mean.
    optimizer=mx.optimizer.Adam(learning_rate=0.01, rescale_grad=1.0/batch_size),
    batch_end_callback=mx.callback.Speedometer(batch_size, 200),
    eval_metric=mx.metric.Loss(name="loss", output_names=["ce_loss_output"]))

# show result, -1 are not predicted correctly as we did not compute their loss
probas = mod.predict(data_iter)[1].asnumpy()
# zip() is lazy on Python 3; materialize it so the pairs are actually printed.
print(list(zip(x, np.argmax(probas, axis=1))))
See more
Log in to post an answer