Mini batch training for inputs of variable sizes

Posted 2019-07-08 09:19

Question:

I have a list of LongTensors and another list of labels. I'm new to PyTorch and RNNs, so I'm quite confused about how to implement minibatch training for the data I have. There is much more to this data, but I want to keep it simple so I can understand just the minibatch training part. I'm doing multiclass classification based on the final hidden state of an LSTM/GRU trained on variable-length inputs. I managed to get it working with batch size 1 (basically SGD), but I'm struggling to implement minibatches.

Do I have to pad the sequences to the maximum size and create a new tensor matrix of larger size which holds all the elements? I mean like this:

inputs = pad(sequences)
train = DataLoader(inputs, batch_size=batch_size, shuffle=True)
for i, data in enumerate(train):
    # do stuff using LSTM and/or GRU models

Is this the accepted way of doing minibatch training on custom data? I couldn't find any tutorials on loading custom data using DataLoader (but I assume that's the way to create batches in PyTorch?).

Another doubt I have concerns padding. The reason I'm using an LSTM/GRU is that the input has variable length. Doesn't padding defeat the purpose? Is padding necessary for minibatch training?

Answer 1:

Yes. The issue with minibatch training on variable-length sequences is that you can't stack sequences of different lengths into a single tensor.

Normally one would do:

for e in range(epochs):
    sequences = shuffle(sequences)  # any helper that returns a shuffled copy
    for mb in range(len(sequences) // mb_size):  # integer division: range() needs an int
        batch = torch.stack(sequences[mb*mb_size:(mb+1)*mb_size])

and then apply your neural network to the batch. But because your sequences have different lengths, torch.stack will fail. So yes, you have to pad your sequences with zeros so that they all have the same length (at least within a minibatch). You have 2 options:

1) At the very beginning, pad all your sequences with leading zeros so that they all have the same length as the longest sequence in your whole dataset.

OR

2) On the fly, for each minibatch, before stacking the sequences together, pad all the sequences that will go into the minibatch with leading zeros so that they all have the same length as the longest sequence in that minibatch.
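As a rough sketch of option 2, the per-minibatch padding can be done in a custom collate_fn passed to DataLoader. The names SequenceDataset and pad_collate, the toy data, and the use of index 0 as the padding value are assumptions made for illustration; they are not from the question. The function left-pads each sequence with zeros up to the longest length in the batch, so real padding-aware models may also want to reserve index 0 in the vocabulary.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SequenceDataset(Dataset):
    """Hypothetical wrapper around a list of 1-D LongTensors and integer labels."""

    def __init__(self, sequences, labels):
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]


def pad_collate(batch):
    """Left-pad every sequence in the batch with zeros to the batch maximum."""
    sequences, labels = zip(*batch)
    max_len = max(seq.size(0) for seq in sequences)
    # Start from an all-zero (batch, max_len) tensor and copy each
    # sequence into the rightmost positions of its row.
    padded = torch.zeros(len(sequences), max_len, dtype=torch.long)
    for row, seq in enumerate(sequences):
        padded[row, max_len - seq.size(0):] = seq
    return padded, torch.tensor(labels)


# Toy data (assumed): four sequences of different lengths, two classes.
sequences = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5]),
    torch.tensor([6]),
    torch.tensor([7, 8, 9, 10]),
]
labels = [0, 1, 0, 1]

loader = DataLoader(
    SequenceDataset(sequences, labels),
    batch_size=2,
    shuffle=True,
    collate_fn=pad_collate,
)

for inputs, targets in loader:
    # inputs has shape (batch_size, longest length in this batch)
    print(inputs.shape, targets.shape)
```

Option 1 would amount to running the same padding once over the whole dataset, with max_len taken over all sequences, before building the DataLoader.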