Using a stateful Keras model in pure TensorFlow

Posted 2019-08-07 10:32

Question:

I have a stateful RNN model with several GRU layers that was created in Keras.

I now have to run this model from Java, so I dumped the model as a protobuf and I'm loading it with Java TensorFlow.

This model must be stateful because features will be fed one timestep at a time.

As far as I understand, to achieve statefulness in a TensorFlow model I must somehow feed in the last state every time I execute the session runner, and have the run return the new state after execution.

  1. Is there a way to output the state in the Keras model?
  2. Is there a simpler way altogether to get a stateful Keras model to work as such using TensorFlow?

Many thanks

Answer 1:

An alternative solution is to use the model.state_updates property of the Keras model and add it to the session.run call.

Here is a full example that illustrates this solution with two LSTMs:

import tensorflow as tf

class SimpleLstmModel(tf.keras.Model):
    """ Simple lstm model with two lstm """
    def __init__(self, units=10, stateful=True):
        super(SimpleLstmModel, self).__init__()
        self.lstm_0 = tf.keras.layers.LSTM(units=units, stateful=stateful, return_sequences=True)
        self.lstm_1 = tf.keras.layers.LSTM(units=units, stateful=stateful, return_sequences=True)

    def call(self, inputs):
        """
        :param inputs: [batch_size, seq_len, 1]
        :return: output tensor
        """
        x = self.lstm_0(inputs)
        x = self.lstm_1(x)
        return x

def main():
    model = SimpleLstmModel(units=1, stateful=True)
    x = tf.placeholder(shape=[1, 1, 1], dtype=tf.float32)
    output = model(x)
    sess = tf.Session()

    sess.run(tf.global_variables_initializer())

    # Fetching model.state_updates alongside the output runs the assign ops
    # that write the new LSTM states back into their variables, so the state
    # carries over to the next sess.run call.
    res_at_step_1, _ = sess.run([output, model.state_updates], feed_dict={x: [[[0.1]]]})
    print(res_at_step_1)
    res_at_step_2, _ = sess.run([output, model.state_updates], feed_dict={x: [[[0.1]]]})
    print(res_at_step_2)


if __name__ == "__main__":
    main()

Which produces the following output:

[[[0.00168626]]]
[[[0.00434444]]]

which shows that the LSTM state is preserved between batches. If we set stateful to False, the output becomes:

[[[0.00033928]]]
[[[0.00033928]]]

This shows that the state is not reused.



Answer 2:

OK, so I managed to solve this problem!

What worked for me was creating tf.identity tensors not only for the outputs, as is standard, but also for the state tensors.

In the Keras model, the state tensors can be found with:

model.updates

Which gives something like this:

[(<tf.Variable 'gru_1_1/Variable:0' shape=(1, 70) dtype=float32_ref>,
  <tf.Tensor 'gru_1_1/while/Exit_2:0' shape=(1, 70) dtype=float32>),
 (<tf.Variable 'gru_2_1/Variable:0' shape=(1, 70) dtype=float32_ref>,
  <tf.Tensor 'gru_2_1/while/Exit_2:0' shape=(1, 70) dtype=float32>),
 (<tf.Variable 'gru_3_1/Variable:0' shape=(1, 4) dtype=float32_ref>,
  <tf.Tensor 'gru_3_1/while/Exit_2:0' shape=(1, 4) dtype=float32>)]

The 'Variable' tensors are used to feed in the states, and the 'Exit' tensors are the outputs of the new states. So I created tf.identity ops from the 'Exit' tensors and gave them meaningful names, e.g.:

tf.identity(state_variables[j], name='state'+str(j))

where state_variables contained only the 'Exit' tensors.
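Collecting all of them in one go could look like this (a sketch, assuming model is the loaded Keras model and model.updates yields the (variable, new-state) pairs shown above):

import tensorflow as tf

# Unpack the (state variable, new-state tensor) pairs from model.updates
# and expose each new state under a stable, predictable name.
state_variables = [new_state for (_, new_state) in model.updates]
state_outputs = [tf.identity(t, name='state' + str(j))
                 for j, t in enumerate(state_variables)]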

Then I used the input variables (e.g. gru_1_1/Variable:0) to feed the model state from TensorFlow, and used the identity tensors I created from the 'Exit' tensors to extract the new states after feeding the model at each timestep.
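For completeness, here is a rough Python analogue of the resulting Java inference loop (a sketch: the file name 'model.pb', the 'input_1:0'/'output:0' tensor names, and the dummy feature shape are illustrative assumptions; the variable and state names match the tensors shown above):

import numpy as np
import tensorflow as tf

# Load the graph that was dumped from Keras as a protobuf.
graph_def = tf.GraphDef()
with open('model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # Start from zero states, shaped like the state variables in model.updates.
    states = [np.zeros((1, 70), np.float32),
              np.zeros((1, 70), np.float32),
              np.zeros((1, 4), np.float32)]
    feature_stream = [np.zeros((1, 1, 1), np.float32)] * 3  # dummy timesteps
    for features in feature_stream:  # one timestep at a time
        fetches = ['output:0', 'state0:0', 'state1:0', 'state2:0']
        feed = {'input_1:0': features,
                'gru_1_1/Variable:0': states[0],
                'gru_2_1/Variable:0': states[1],
                'gru_3_1/Variable:0': states[2]}
        # Fetch the output together with the renamed 'Exit' tensors, then
        # feed the returned states back in on the next iteration.
        output, *states = sess.run(fetches, feed_dict=feed)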