Python keras how to transform a dense layer into a

Posted 2019-04-25 07:49

Question:

I have a problem finding the correct mapping of the weights in order to transform a dense layer into a convolutional layer.

This is an excerpt of a ConvNet that I'm working on:

model.add(Convolution2D(512, 3, 3, activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
model.add(Flatten())
model.add(Dense(4096, activation='relu'))

After the MaxPooling, the input is of shape (512,7,7). I would like to transform the dense layer into a convolutional layer to make it look like this:

model.add(Convolution2D(512, 3, 3, activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
model.add(Convolution2D(4096, 7, 7, activation='relu'))

However, I don't know how to reshape the weights so that the flattened weights map correctly onto the (4096,512,7,7) structure that the convolutional layer needs. Right now, the weights of the dense layer have shape (25088,4096). I need to map these 25088 elements to (512,7,7) while preserving the correct assignment of weights to neurons. So far, I have tried several combinations of reshaping and transposing, but I haven't found the correct mapping.

An example of what I have been trying would be this:

weights[0] = np.transpose(np.reshape(weights[0],(512,7,7,4096)),(3,0,1,2))

but it doesn't map the weights correctly. I checked the mapping by comparing the outputs of both models: if it were correct, the outputs should be the same.

Answer 1:

Still looking for a solution? Here it is:

new_conv_weights = dense_weights.transpose(1,0).reshape(new_conv_shape)[:,:,::-1,::-1]

in your case:

weights[0] = weights[0].transpose(1,0).reshape((4096,512,7,7))[:,:,::-1,::-1]
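As a sanity check, the transpose-then-reshape step can be compared with the reshape-then-transpose attempt from the question. The sketch below uses scaled-down, made-up sizes (3 channels, 2x2 maps, 5 outputs) instead of (512,7,7) and 4096, and assumes a channels-first layout flattened in row-major order. Under those assumptions the two produce the same array, which suggests the only missing piece in the question's attempt is the final flip.

```python
import numpy as np

# Illustrative sizes only; the real case is 512 channels, 7x7 maps, 4096 outputs.
n_out, channels, height, width = 5, 3, 2, 2
n_in = channels * height * width

# Distinct values make any mismatch in the index mapping visible.
dense_w = np.arange(n_in * n_out, dtype=float).reshape(n_in, n_out)

# The attempt from the question: reshape first, then move the output axis to the front.
attempt = np.transpose(np.reshape(dense_w, (channels, height, width, n_out)), (3, 0, 1, 2))

# The answer's mapping, split into its two steps.
unflipped = dense_w.transpose(1, 0).reshape((n_out, channels, height, width))
flipped = unflipped[:, :, ::-1, ::-1]

same_before_flip = np.array_equal(attempt, unflipped)
differs_after_flip = not np.array_equal(attempt, flipped)
print(same_before_flip, differs_after_flip)
```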

The tricky part is flipping the conv filters with [:,:,::-1,::-1]. Theano performs true convolution, not correlation (unlike Caffe, for example). Hence, in Keras with the Theano backend, a filter like:

1 0
0 0

applied to matrix:

1 2 3 4 5
6 7 8 9 0
1 2 3 4 5

results in matrix:

7 8 9 0
2 3 4 5

not this, as one would expect with correlation:

1 2 3 4
6 7 8 9
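The small example above can be reproduced in plain NumPy. This is a minimal sketch with hand-written helpers (`correlate2d_valid` and `convolve2d_valid` are names made up here, not Keras or Theano functions), using "valid" borders so the output is 2x4.

```python
import numpy as np

def correlate2d_valid(image, kernel):
    """Slide the kernel over the image as-is (cross-correlation, no padding)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=image.dtype)
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def convolve2d_valid(image, kernel):
    """True convolution: rotate the kernel by 180 degrees, then correlate."""
    return correlate2d_valid(image, kernel[::-1, ::-1])

kernel = np.array([[1, 0],
                   [0, 0]])
image = np.array([[1, 2, 3, 4, 5],
                  [6, 7, 8, 9, 0],
                  [1, 2, 3, 4, 5]])

convolved = convolve2d_valid(image, kernel)    # what a convolution backend computes
correlated = correlate2d_valid(image, kernel)  # what a correlation backend computes
print(convolved)
print(correlated)
```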

To make things work as expected, you need to rotate the filters by 180 degrees. I just solved this problem for myself; hopefully this helps you or others. Cheers.
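To check the whole mapping without building two Keras models, the dense layer and the full-size convolution can be emulated in NumPy. This is a rough sketch under stated assumptions: the sizes are scaled down, the input is channels-first and flattened in row-major order, biases and the activation are left out, and the convolution is a true (kernel-flipping) one as described for Theano. `dense_to_conv_weights` and `full_size_true_convolution` are hypothetical helper names.

```python
import numpy as np

def dense_to_conv_weights(dense_w, channels, height, width, flip=True):
    """Map (channels*height*width, n_out) dense weights to (n_out, channels, height, width)."""
    n_out = dense_w.shape[1]
    conv_w = dense_w.transpose(1, 0).reshape((n_out, channels, height, width))
    if flip:
        # Pre-rotate each filter by 180 degrees to cancel the flip done by convolution.
        conv_w = conv_w[:, :, ::-1, ::-1]
    return conv_w

def full_size_true_convolution(x, conv_w):
    """Emulate a 'valid' true convolution whose kernel covers the whole input map.

    The result is a single value per filter: the kernel is flipped spatially,
    multiplied elementwise with the input and summed.
    """
    flipped_kernel = conv_w[:, :, ::-1, ::-1]
    return np.einsum('chw,nchw->n', x, flipped_kernel)

rng = np.random.default_rng(0)
channels, height, width, n_out = 3, 2, 2, 4   # illustrative sizes only

x = rng.standard_normal((channels, height, width))
dense_w = rng.standard_normal((channels * height * width, n_out))

dense_out = x.reshape(-1) @ dense_w            # Flatten followed by Dense (no bias)
conv_out = full_size_true_convolution(
    x, dense_to_conv_weights(dense_w, channels, height, width))
conv_out_no_flip = full_size_true_convolution(
    x, dense_to_conv_weights(dense_w, channels, height, width, flip=False))

print(np.allclose(dense_out, conv_out))
print(np.allclose(dense_out, conv_out_no_flip))
```

If the backend computes correlation instead of true convolution, the same reasoning suggests dropping the flip (`flip=False`), since there is no kernel rotation to cancel.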