Add my custom loss function to torch

Posted 2019-02-14 16:54

Question:

I want to add a loss function to torch that calculates the edit distance between predicted and target values. Is there an easy way to implement this idea? Or do I have to write my own class with backward and forward functions?

Answer 1:

If your criterion can be represented as a composition of existing modules and criteria, the simplest approach is to build that composition using containers. The only problem is that the standard containers are designed to work with modules only, not criteria. The difference is in the :forward method signature:

module:forward(input)
criterion:forward(input, target)
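
For instance (an illustrative snippet, not from the original post, using nn.Linear and nn.MSECriterion):

require 'nn'

local module = nn.Linear(10, 5)               -- a module: forward takes only the input
local out = module:forward(torch.rand(10))

local criterion = nn.MSECriterion()           -- a criterion: forward also takes a target
local loss = criterion:forward(out, torch.rand(5))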

Luckily, we are free to define our own container that is able to work with criteria too, for example a generalized Sequential:

local GeneralizedSequential, _ = torch.class('nn.GeneralizedSequential', 'nn.Sequential')

function GeneralizedSequential:forward(input, target)
    return self:updateOutput(input, target)
end

function GeneralizedSequential:updateOutput(input, target)
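    -- pass the target along with the data so that any criterion in the
    -- chain receives both arguments (plain modules simply ignore it)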
    local currentOutput = input
    for i=1,#self.modules do
        currentOutput = self.modules[i]:updateOutput(currentOutput, target)
    end
    self.output = currentOutput
    return currentOutput
end

Below is an illustration of how to reimplement nn.CrossEntropyCriterion using this generalized sequential container:

function MyCrossEntropyCriterion(weights)
    local criterion = nn.GeneralizedSequential()
    criterion:add(nn.LogSoftMax())
    criterion:add(nn.ClassNLLCriterion(weights))
    return criterion
end

Check whether everything is correct:

output = torch.rand(3,3)
target = torch.Tensor({1, 2, 3})

mycrit = MyCrossEntropyCriterion()
-- print(mycrit)
print(mycrit:forward(output, target))
print(mycrit:backward(output, target))

crit = nn.CrossEntropyCriterion()
-- print(crit)
print(crit:forward(output, target))
print(crit:backward(output, target))
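
If the composition is correct, both pairs of print statements should produce the same loss value and the same gradient tensor.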


Answer 2:

Just to add to the accepted answer, you have to be careful that the loss function you define (edit distance in your case) is differentiable with respect to the network parameters.
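
To make the alternative from the question concrete: if the loss cannot be expressed as a composition of existing modules, you would subclass nn.Criterion and implement the forward and backward passes yourself. The sketch below is illustrative only; the class name nn.MySquaredError is hypothetical, and it uses a squared-error loss as a placeholder, because plain edit distance is a discrete quantity with no gradient and would first need a differentiable surrogate before updateGradInput could be filled in.

local MySquaredError, parent = torch.class('nn.MySquaredError', 'nn.Criterion')

function MySquaredError:__init()
    parent.__init(self)
end

-- forward pass: compute and return the scalar loss
function MySquaredError:updateOutput(input, target)
    self.output = (input - target):pow(2):sum()
    return self.output
end

-- backward pass: gradient of the loss with respect to the input
function MySquaredError:updateGradInput(input, target)
    self.gradInput = (input - target):mul(2)
    return self.gradInput
end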



Tags: torch