Custom linear transformation in Keras

Posted 2019-07-23 06:20

I want to build a custom layer in Keras that applies a linear transformation to the output of the previous layer. For example, if the previous layer outputs X, my new layer should output X.dot(W) + b.

W has shape (49, 10), X should have shape (64, 49), and b has shape (10,).
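
In NumPy terms, the requested layer is an ordinary matrix product plus a bias that is broadcast over every row of the result. Here is a quick shape check, assuming random placeholder values (the data are made up; only the shapes come from the question):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 49))   # presumably 64 channels x 49 spatial positions
W = rng.normal(size=(49, 10))
b = rng.normal(size=(10,))

# (64, 49) . (49, 10) -> (64, 10); b with shape (10,) is broadcast across all 64 rows.
Y = X.dot(W) + b
print(Y.shape)
```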

However, X actually has shape (?, 7, 7, 64), and when I try to reshape it, it becomes shape=(64, ?). What does the question mark mean? What is the proper way to apply a linear transformation to the output of the previous layer?

1 answer
女痞
Reply #2 · 2019-07-23 06:36

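The question mark (None in Keras/TensorFlow shapes) stands for the batch dimension. Its size is not fixed when the model is built, because it depends on how many samples you feed at run time, so it has no effect on the model architecture.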
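
One way to see this, sketched with tf.keras and assuming TensorFlow 2.x: the shape passed to Input excludes the batch dimension, Keras reports that dimension as None, and the same model then accepts batches of any size.

```python
import numpy as np
import tensorflow as tf

# Input() takes the per-sample shape; Keras prepends the batch dimension as None.
inp = tf.keras.Input(shape=(7, 7, 64))
print(inp.shape)

# Reshape's target shape also excludes the batch dimension.
out = tf.keras.layers.Reshape((49, 64))(inp)
model = tf.keras.Model(inp, out)

# The same model runs on whatever batch size is fed at run time.
for batch_size in (1, 8):
    x = np.zeros((batch_size, 7, 7, 64), dtype="float32")
    print(batch_size, tuple(model(x).shape))
```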

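You can reshape X with the keras.layers.Reshape layer, whose target shape excludes the batch dimension. Note that Reshape keeps the elements in row-major order, so Reshape((64,49))(X) on a (7, 7, 64) tensor does not give one row per channel. If each row of X should hold one channel's 49 spatial values, use Reshape((49, 64))(X) followed by Permute((2, 1)).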
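
Keras' Reshape (like tf.reshape) keeps elements in row-major order, which is also NumPy's default, so the ordering issue can be checked on a single sample in plain NumPy. The test array below is hypothetical: every entry equals its channel index, which makes it easy to see which channels end up in each row.

```python
import numpy as np

# Hypothetical feature map for ONE sample: 7x7 spatial positions, 64 channels,
# where every value equals its channel index.
x = np.broadcast_to(np.arange(64), (7, 7, 64)).astype(np.float32)

# Reshaping straight to (64, 49) keeps row-major order, so each row mixes channels.
naive = x.reshape(64, 49)

# Flatten the 7x7 grid first (positions x channels), then swap the two axes:
# row c now holds the 49 values of channel c.
# Per sample, this matches Reshape((49, 64)) followed by Permute((2, 1)).
proper = x.reshape(49, 64).T

print(np.unique(naive[0]))    # row 0 holds channels 0..48
print(np.unique(proper[0]))   # row 0 holds only channel 0
```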

You can wrap arbitrary TensorFlow operations such as tf.matmul in a Lambda layer to include them in your Keras model. Here is a minimal working example:

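import tensorflow as tf
from tensorflow.keras.layers import Dense, Lambda, Input
from tensorflow.keras.models import Model

# W and b are plain tensors captured by the Lambda, not Keras weights,
# so training does NOT update them.
W = tf.random.normal(shape=(128, 20))
b = tf.random.normal(shape=(20,))

inp = Input(shape=(10,))                       # (batch, 10)
x = Dense(128)(inp)                            # (batch, 128)
y = Lambda(lambda t: tf.matmul(t, W) + b)(x)   # (batch, 20) = x.dot(W) + b
model = Model(inp, y)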
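
Applied to the shapes in the question, here is a sketch that assumes TensorFlow 2.x with tf.keras. Because a Lambda only captures W and b as plain tensors, training never changes them; if they are meant to be learned, the usual Keras route is a Layer subclass that creates them with add_weight. LinearTransform is a hypothetical name chosen for illustration, not a built-in layer.

```python
import numpy as np
import tensorflow as tf


class LinearTransform(tf.keras.layers.Layer):
    """Hypothetical custom layer: returns inputs.dot(W) + b over the last axis."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # W: (last input dim, units) and b: (units,) are trainable Keras weights.
        self.W = self.add_weight(name="W", shape=(int(input_shape[-1]), self.units),
                                 initializer="glorot_uniform", trainable=True)
        self.b = self.add_weight(name="b", shape=(self.units,),
                                 initializer="zeros", trainable=True)

    def call(self, inputs):
        # Contract the last axis of inputs with the first axis of W, then add b.
        return tf.tensordot(inputs, self.W, axes=1) + self.b


inp = tf.keras.Input(shape=(7, 7, 64))
x = tf.keras.layers.Reshape((49, 64))(inp)   # (batch, 49, 64): positions x channels
x = tf.keras.layers.Permute((2, 1))(x)        # (batch, 64, 49): channels x positions
linear = LinearTransform(10)
y = linear(x)                                 # (batch, 64, 10) = X.dot(W) + b
model = tf.keras.Model(inp, y)

print(tuple(model(np.zeros((2, 7, 7, 64), dtype="float32")).shape))
print(len(model.trainable_weights))           # W and b
```

On inputs of rank above 2, the built-in Dense(10) also contracts the last axis with its kernel, so Dense(10) placed after the Permute computes the same X.dot(W) + b; the subclass only makes the weights explicit.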