I have a model with a classification-like part and a regression-like part, which I merge with a multiplication layer. Before the multiplication I want to set the outputs of the classification part to 0 or 1 based on a threshold. I tried using a Lambda layer with a custom function, as shown below, but I keep hitting various errors that I don't understand, and resolving them one by one as I go is not adding to my understanding. Can anyone explain how to define a custom Lambda layer function that modifies the values?
My current Lambda layer function (not working, fails with FailedPreconditionError: Attempting to use uninitialized value lstm_32/bias):
def func(x):
    a = x.eval(session=tf.Session())
    a[a < 0.5] = 0
    a[a >= 0.5] = 1
    return K.variable(a)
Regression part:
input1 = Input(shape=(1, ))
model = Sequential()
model.add(Embedding(vocab_size + 1, embedding, input_length=1))
model.add(LSTM(hidden, recurrent_dropout=0.1, return_sequences=True))
model.add(LSTM(6))
model.add(Reshape((3,2)))
model.add(Activation('linear'))
Classification part:
input2 = Input(shape=(1, ))
model2 = Sequential()
model2.add(Embedding(vocab_size + 1, embedding, input_length=1))
model2.add(LSTM(hidden, recurrent_dropout=0.1, return_sequences=True))
model2.add(LSTM(1))
model2.add(Activation('sigmoid'))
model2.add(???) # need to add 0-1 thresholding here
Merge two parts:
reg_head = model(input1)
clf_head = model2(input2)
merge_head = multiply(inputs=[clf_head, reg_head])
m2 = Model(inputs=[input1, input2], outputs=merge_head)
In func, you cannot eval the tensors. The idea of using tensors is that they keep a "connection" (a graph, as they call it) from start to end in the entire model. This connection allows the model to compute gradients. If you eval a tensor and try to use these values, you will break the connection.
Also, for taking the actual values of a tensor, you need the input data, and the input data will only exist when you call fit, predict and similar methods. In the build phase there is no data, only representations and connections.

A possible function using only tensors is:
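(one possible version, built from the Keras backend ops K.greater_equal and K.cast, with the 0.5 threshold from the question)

from keras import backend as K

def func(x):
    greater = K.greater_equal(x, 0.5)  # boolean tensor: True where x >= 0.5
    greater = K.cast(greater, K.floatx())  # cast booleans to floats, giving 0s and 1s
    return greater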
But be careful! This will not be differentiable. These values will be seen as constants from this point on in the model. This means that the weights coming before this point will not be updated during training (you won't train the classification model via m2, but you will still be able to train it from model2). There are some fancy workarounds for this; if you need them, please write a comment.

Use this function in a Lambda layer:
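(a minimal example, assuming the Sequential classification model model2 from the question)

from keras.layers import Lambda

model2.add(Lambda(func, output_shape=lambda shape: shape))

The output shape is unchanged here, since the function only replaces each value with 0 or 1.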