I trained my model following the link here. It achieved almost 90% accuracy during training. I am using the vgg_bn_drop.lua model that you will find at the link. The problem is, I don't know how to test it on a single image.
I know how to test a model in general: by forward-passing an image through the network. So testing the model comes down to modelname:forward(image), where modelname is the model I trained, forward runs the forward pass, and image is the image tensor I want to push through. What I cannot figure out is what dimensions a single image needs to have for this network. What I want to do is take an image, say of dimension [3x32x32], pass it through the network, and get the result. Is that possible with this network?
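Concretely, this is the kind of call I have in mind (just a sketch; the model file name and the image path are placeholders for my setup):

require 'nn'
require 'image'

-- load the trained network (file name is only an example)
local model = torch.load('vgg_bn_drop_trained.t7')

-- load one image and scale it to the 32x32 input size the net was trained on
local img = image.scale(image.load('test.png', 3, 'float'), 32, 32)  -- 3x32x32

-- this is the call I cannot get to work (see the attempts and errors below)
local output = model:forward(img)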
There was no documentation there on how to test it on a single image.
What I tried so far:
1) Declare a tensor of size 3x32x32, call it image (image = torch.Tensor(3,32,32)), and forward-pass it:
model:forward(image)
It produces the error
...h/install/share/lua/5.1/nn/SpatialBatchNormalization.lua:68: only mini-batch supported (4D tensor), got 3D tensor instead
2) I reshaped the image to (1,3,32,32):
image = image:reshape(1,3,32,32)
and forward-passed it:
model:forward(image)
It produces the error
...ch/torch/install/share/lua/5.1/nn/BatchNormalization.lua:67: only mini-batch supported (2D tensor), got 1D tensor instead
So I tried these two approaches but could not figure out how to pass a single image to that network. Can you help me out?
The model definition is:
require 'nn'

local vgg = nn.Sequential()

-- building block
local function ConvBNReLU(nInputPlane, nOutputPlane)
  vgg:add(nn.SpatialConvolution(nInputPlane, nOutputPlane, 3,3, 1,1, 1,1))
  vgg:add(nn.SpatialBatchNormalization(nOutputPlane,1e-3))
  vgg:add(nn.ReLU(true))
  return vgg
end

-- Will use "ceil" MaxPooling because we want to save as much feature space as we can
local MaxPooling = nn.SpatialMaxPooling

ConvBNReLU(3,64):add(nn.Dropout(0.3))
ConvBNReLU(64,64)
vgg:add(MaxPooling(2,2,2,2):ceil())

ConvBNReLU(64,128):add(nn.Dropout(0.4))
ConvBNReLU(128,128)
vgg:add(MaxPooling(2,2,2,2):ceil())

ConvBNReLU(128,256):add(nn.Dropout(0.4))
ConvBNReLU(256,256):add(nn.Dropout(0.4))
ConvBNReLU(256,256)
vgg:add(MaxPooling(2,2,2,2):ceil())

ConvBNReLU(256,512):add(nn.Dropout(0.4))
ConvBNReLU(512,512):add(nn.Dropout(0.4))
ConvBNReLU(512,512)
vgg:add(MaxPooling(2,2,2,2):ceil())

ConvBNReLU(512,512):add(nn.Dropout(0.4))
ConvBNReLU(512,512):add(nn.Dropout(0.4))
ConvBNReLU(512,512)
vgg:add(MaxPooling(2,2,2,2):ceil())

vgg:add(nn.View(512))
vgg:add(nn.Dropout(0.5))
vgg:add(nn.Linear(512,512))
vgg:add(nn.BatchNormalization(512))
vgg:add(nn.ReLU(true))
vgg:add(nn.Dropout(0.5))
vgg:add(nn.Linear(512,10))

-- initialization from MSR
local function MSRinit(net)
  local function init(name)
    for k,v in pairs(net:findModules(name)) do
      local n = v.kW*v.kH*v.nOutputPlane
      v.weight:normal(0,math.sqrt(2/n))
      v.bias:zero()
    end
  end
  init'nn.SpatialConvolution'
end

MSRinit(vgg)

return vgg
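For reference, pushing a mini-batch through this definition goes through fine; it is only the single-image case that fails for me. A quick sanity check with random data (assuming the definition above is saved as vgg_bn_drop.lua):

local vgg = dofile('vgg_bn_drop.lua')

local batch = torch.randn(8, 3, 32, 32)  -- mini-batch of 8 CIFAR-sized images
local out = vgg:forward(batch)
print(out:size())  -- 8x10: one 10-class score vector per image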
Well, the error is clear: nn.BatchNormalization expects a 2D tensor (a mini-batch) as input, but receives a 1D tensor. You did add a batch dimension to your input (image:reshape(1,3,32,32)), but somewhere along the network that dimension was lost. The nn.View module is the culprit.

Suppose the module was instantiated as nn.View(512), i.e. with output_size = 512 as in your model, and it is given an input tensor of shape batch_size x channels x height x width, here 1x512x1x1. The module now has to decide whether it is expected to return a batch or a single non-batch output. If batch_size > 1, the answer is obvious: batch_size*channels*height*width is a multiple of output_size, so the input is a batch and the output must be a batch. But if batch_size == 1, then 1*channels*height*width == output_size, so is the input a batch or not? nn.View assumes it is not and produces a single output without the batch dimension.

To resolve the ambiguity, you can specify the number NB of non-batch dimensions: if the input has NB+1 dimensions, it is treated as a batch. In light of the above, this will solve your problem:
vgg:add(nn.View(512):setNumInputDims(3))
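With that change, single-image inference could look roughly like this (a sketch: the file names are placeholders, and I am assuming you apply the same preprocessing you used during training). Note that the spatial batch-normalization layers still require a 4D input, so keep a leading batch dimension of size 1, and call :evaluate() so dropout is disabled and batch normalization uses its running statistics:

require 'nn'
require 'image'

local model = torch.load('vgg_bn_drop_trained.t7')  -- your trained net (placeholder name)

-- if the saved net was built with a plain nn.View(512), patch it in place instead of retraining
for _, view in ipairs(model:findModules('nn.View')) do
  view:setNumInputDims(3)
end

model:evaluate()  -- test mode: no dropout, BN uses running mean/var

local img = image.scale(image.load('test.png', 3, 'float'), 32, 32)  -- 3x32x32
-- apply here the same normalization (e.g. mean/std) you used for the training data

local scores = model:forward(img:view(1, 3, 32, 32))  -- 1x10 class scores
print(scores)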