PyTorch loss function dimensions do not match

Posted 2019-06-14 19:52


I'm trying to train word embeddings (a CBOW model) with mini-batch training, as shown below.

def forward(self, inputs):
    embeds = self.embeddings(inputs)
    out = self.linear1(embeds)
    out = self.activation_function1(out)
    out = self.linear2(out).cuda()
    out = self.activation_function2(out)
    return out.cuda()

Here I'm using a context size of 4, a batch size of 32, an embedding size of 50, a hidden layer size of 64 and a vocabulary size of 9927.

The printed shapes, in the order the layers run, are:

print(inputs.shape) ----> torch.Size([4, 32])
print(embeds.shape) ----> torch.Size([4, 32, 50])
print(out.shape) ----> torch.Size([4, 32, 64])      (after linear1)
print(out.shape) ----> torch.Size([4, 32, 64])      (after activation_function1, ReLU)
print(out.shape) ----> torch.Size([4, 32, 9927])    (after linear2)
print(out.shape) ----> torch.Size([4, 32, 9927])    (after activation_function2, LogSoftmax)
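To see where each of these shapes comes from, here is a minimal CPU-only sketch that pushes a dummy batch with the same sizes through the same kinds of layers. It assumes the input is a tensor of word indices laid out as (context, batch), which is what torch.stack(data) produces here; the dummy indices are random and stand in for real data. nn.Linear and nn.LogSoftmax(dim=-1) act only on the last dimension, so the leading (4, 32) is carried through unchanged.

```python
import torch
import torch.nn as nn

vocab_size, embedding_dim, hidden_dim = 9927, 50, 64
context_size, batch_size = 4, 32

embeddings = nn.Embedding(vocab_size, embedding_dim)
linear1 = nn.Linear(embedding_dim, hidden_dim)
linear2 = nn.Linear(hidden_dim, vocab_size)
log_softmax = nn.LogSoftmax(dim=-1)

# Dummy word indices shaped like torch.stack(data): (context, batch)
inputs = torch.randint(0, vocab_size, (context_size, batch_size))

embeds = embeddings(inputs)            # (4, 32, 50): one vector per index
hidden = torch.relu(linear1(embeds))   # (4, 32, 64): Linear maps only the last dim
out = log_softmax(linear2(hidden))     # (4, 32, 9927)

print(inputs.shape, embeds.shape, hidden.shape, out.shape)
```

So the shapes are consistent with the code; the catch is that the batch dimension (32) sits second rather than first.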

Are the shapes of these correct? I'm quite confused.

Also, training fails with an error:

def train(epoch):
  for batch_idx, (data, target) in enumerate(train_loader, 0):
    output = model(torch.stack(data))
    loss = criterion(output, target)

The error is raised on the line "loss = criterion(output, target)": "Expected input batch_size (4) to match target batch_size (32)." Are the shapes in my forward function correct? I'm not very familiar with batch training. How do I make the dimensions match?
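The message comes from how a negative log-likelihood loss reads its input: dimension 0 is treated as the batch. The sketch below reproduces it, assuming criterion is nn.NLLLoss (the model ends in LogSoftmax, but the question does not show which criterion is used) and shrinking the vocabulary to 10 so it runs quickly.

```python
import torch
import torch.nn as nn

criterion = nn.NLLLoss()  # assumption: the question does not show the criterion

vocab_size, context_size, batch_size = 10, 4, 32
output = torch.randn(context_size, batch_size, vocab_size).log_softmax(dim=-1)  # (4, 32, 10)
target = torch.randint(0, vocab_size, (batch_size,))                              # (32,)

message = None
try:
    # With a 3D input, dim 0 (4) is read as the batch and dim 1 (32) as the classes
    criterion(output, target)
except (ValueError, RuntimeError) as err:
    message = str(err)
print(message)

# With the batch first and one row of scores per example, the shapes line up
fixed = torch.randn(batch_size, vocab_size).log_softmax(dim=-1)  # (32, 10)
loss = criterion(fixed, target)
print(loss.item())
```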

EDIT: here is the __init__ code:

import torch.nn as nn

class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 64)
        self.activation_function1 = nn.ReLU()
        self.linear2 = nn.Linear(64, vocab_size)
        self.activation_function2 = nn.LogSoftmax(dim=-1)


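torch.nn.Linear itself does not care where the batch is: it maps only the last dimension and carries any leading dimensions through, which is why your forward pass runs. The loss, however, reads dimension 0 as the batch, and in your output that dimension is the context size (4), not the batch size (32).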

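You are supplying the batch as the second dimension (the first being the 4 context words, i.e. timesteps). Use permute(1, 0, 2) on the embeddings, or build the input with torch.stack(data, dim=1), to put the batch first.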
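Both reorderings can be checked in a few lines. This sketch assumes the dataset returns each context as a tuple or list of 4 word indices, so the DataLoader's default collate hands back data as a list of 4 tensors of shape (32,); the data list below is a hypothetical stand-in built from random indices.

```python
import torch

batch_size, context_size, embedding_dim = 32, 4, 50

# Hypothetical stand-in for `data`: a list of 4 index tensors, each of shape (32,)
data = [torch.randint(0, 100, (batch_size,)) for _ in range(context_size)]

by_position = torch.stack(data)        # (4, 32): context first, as in the question
by_batch = torch.stack(data, dim=1)    # (32, 4): batch first
assert torch.equal(by_batch, by_position.t())

# The same reordering applied after the embedding layer instead
embeds = torch.randn(context_size, batch_size, embedding_dim)  # (4, 32, 50)
embeds_batch_first = embeds.permute(1, 0, 2)                   # (32, 4, 50)
print(by_position.shape, by_batch.shape, embeds_batch_first.shape)
```

Either route gives a batch-first layout, but the result is still 3D, (32, 4, ...), so the context dimension still has to be dealt with before the loss.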

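Furthermore, with a target of shape (batch,), the loss expects a 2D input of shape (batch, vocab_size), one row of scores per example, whereas yours is 3D because of the context-word dimension. For CBOW, the usual fix is to sum or average the context embeddings over that dimension before linear1; if word order matters, a recurrent neural network (e.g. torch.nn.LSTM) could encode the context instead.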
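A minimal sketch of the CBOW variant, assuming the standard formulation where the context embeddings are summed (averaging works the same way) so that each example yields a single vector, and assuming nn.NLLLoss as the criterion. The random batch is a hypothetical stand-in for one iteration of train_loader, the vocabulary is shrunk to 100 for speed, and SGD with lr=0.1 is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn


class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim=64):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.activation_function1 = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, vocab_size)
        self.activation_function2 = nn.LogSoftmax(dim=-1)

    def forward(self, inputs):
        # inputs: (batch, context) word indices
        embeds = self.embeddings(inputs).sum(dim=1)             # (batch, embedding_dim)
        out = self.activation_function1(self.linear1(embeds))   # (batch, hidden_dim)
        return self.activation_function2(self.linear2(out))     # (batch, vocab_size)


vocab_size, embedding_dim, context_size, batch_size = 100, 50, 4, 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model once instead of calling .cuda() inside forward
model = CBOW(vocab_size, embedding_dim).to(device)
criterion = nn.NLLLoss()  # assumption: matches the LogSoftmax output
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical batch in the layout the DataLoader yields: 4 tensors of shape (32,)
data = [torch.randint(0, vocab_size, (batch_size,)) for _ in range(context_size)]
target = torch.randint(0, vocab_size, (batch_size,))

inputs = torch.stack(data, dim=1).to(device)   # (32, 4): batch first
output = model(inputs)                          # (32, 100)
loss = criterion(output, target.to(device))     # target: (32,)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(output.shape, loss.item())
```

Summing discards word order, which is the usual CBOW trade-off; an LSTM run over the (batch, context, embedding_dim) tensor would keep it at the cost of a heavier model.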