Python keras how to transform a dense layer into a convolutional layer

Posted 2019-04-25 06:59

I have a problem finding the correct mapping of the weights in order to transform a dense layer into a convolutional layer.

This is an excerpt of a ConvNet that I'm working on:

model.add(Convolution2D(512, 3, 3, activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
model.add(Flatten())
model.add(Dense(4096, activation='relu'))

After the MaxPooling, the input is of shape (512,7,7). I would like to transform the dense layer into a convolutional layer to make it look like this:

model.add(Convolution2D(512, 3, 3, activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
model.add(Convolution2D(4096, 7, 7, activation='relu'))

However, I don't know how to reshape the weights so that the flattened weights map correctly onto the (4096,512,7,7) structure that the convolutional layer needs. Right now, the weights of the dense layer have shape (25088,4096). I need to map these 25088 elements to a shape of (512,7,7) while preserving the correct mapping of the weights to the neurons. So far I have tried several combinations of reshaping and transposing, but I haven't been able to find the correct mapping.

An example of what I have been trying would be this:

weights[0] = np.transpose(np.reshape(weights[0],(512,7,7,4096)),(3,0,1,2))

but it doesn't map the weights correctly. I checked the mapping by comparing the outputs of both models; if it were correct, the outputs should be the same.

1 answer
混吃等死
#2 · 2019-04-25 07:56

Still looking for a solution? Here it is:

new_conv_weights = dense_weights.transpose(1,0).reshape(new_conv_shape)[:,:,::-1,::-1]

in your case:

weights[0] = weights[0].transpose(1,0).reshape((4096,512,7,7))[:,:,::-1,::-1]
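A minimal NumPy sketch to sanity-check this mapping on small shapes, without Keras. The helper names (dense_to_conv_weights, full_size_true_convolution) and the toy dimensions are made up for illustration, the bias is left out, and the sketch assumes channels-first data and a backend that performs a true (flipped-kernel) convolution. It also shows that the reshape/transpose attempt from the question produces the same array as transpose(1,0).reshape(...), so the only missing piece there is the flip.

```python
import numpy as np

def dense_to_conv_weights(dense_weights, conv_shape, flip=True):
    """Map dense weights (C*H*W, N) to conv weights (N, C, H, W).

    Hypothetical helper: flip=True rotates each filter by 180 degrees,
    which is needed when the backend computes a true convolution.
    """
    conv = dense_weights.transpose(1, 0).reshape(conv_shape)
    if flip:
        conv = conv[:, :, ::-1, ::-1]
    return conv

def full_size_true_convolution(x, kernels):
    """'Valid' true convolution of x (C, H, W) with kernels (N, C, H, W).

    The kernel covers the whole input, so each kernel yields one value.
    A true convolution flips the kernel before the multiply-and-sum.
    """
    flipped = kernels[:, :, ::-1, ::-1]
    return np.einsum('chw,nchw->n', x, flipped)

rng = np.random.default_rng(0)
C, H, W, N = 3, 2, 2, 4                      # toy stand-ins for 512, 7, 7, 4096
x = rng.standard_normal((C, H, W))           # output of the pooling layer
dense_weights = rng.standard_normal((C * H * W, N))

# Dense layer on the flattened input (no bias, no activation).
dense_out = x.reshape(-1) @ dense_weights

# Converted weights, with and without the 180-degree rotation.
conv_weights = dense_to_conv_weights(dense_weights, (N, C, H, W))
unflipped = dense_to_conv_weights(dense_weights, (N, C, H, W), flip=False)

conv_out = full_size_true_convolution(x, conv_weights)
unflipped_out = full_size_true_convolution(x, unflipped)

# The attempt from the question, scaled down to the toy shapes.
attempt = np.transpose(np.reshape(dense_weights, (C, H, W, N)), (3, 0, 1, 2))

print(np.allclose(dense_out, conv_out))       # True: flipped weights match
print(np.allclose(dense_out, unflipped_out))  # False: the flip is missing
print(np.array_equal(attempt, unflipped))     # True: same as the unflipped mapping
```

The bias vector (weights[1]) has one entry per output unit in both layers, so it should carry over unchanged.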

The tricky part is the filter flip, [:,:,::-1,::-1]. Theano performs a true convolution, not a cross-correlation (unlike Caffe, for example). Hence, in Keras with the Theano backend, a filter like:

1 0
0 0

applied to matrix:

1 2 3 4 5
6 7 8 9 0
1 2 3 4 5

results in matrix:

7 8 9 0 
2 3 4 5

and not in this one, which is what one would expect with correlation:

1 2 3 4
6 7 8 9
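The two results above can be reproduced with a few lines of plain NumPy. This is only a sketch: correlate2d_valid and convolve2d_valid are hypothetical helper names written for this illustration, not Theano or Keras functions.

```python
import numpy as np

def correlate2d_valid(image, kernel):
    """Cross-correlation over 'valid' positions: slide the kernel as-is."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=image.dtype)
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def convolve2d_valid(image, kernel):
    """True convolution: rotate the kernel by 180 degrees, then correlate."""
    return correlate2d_valid(image, kernel[::-1, ::-1])

image = np.array([[1, 2, 3, 4, 5],
                  [6, 7, 8, 9, 0],
                  [1, 2, 3, 4, 5]])
kernel = np.array([[1, 0],
                   [0, 0]])

conv_result = convolve2d_valid(image, kernel)
corr_result = correlate2d_valid(image, kernel)

print(conv_result)
# [[7 8 9 0]
#  [2 3 4 5]]
print(corr_result)
# [[1 2 3 4]
#  [6 7 8 9]]

# Rotating the filter first makes the convolution behave like correlation.
print(np.array_equal(convolve2d_valid(image, kernel[::-1, ::-1]), corr_result))  # True
```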

To make things work as expected, you need to rotate the filters by 180 degrees. I just solved this problem for myself; hopefully this will be of help to you or to others. Cheers.
