-->

Tensorflow RNN-LSTM - reset hidden state

Published 2020-03-27 14:49

Question:

I'm building a stateful LSTM for language recognition. Because it is stateful, I can train the network on smaller files, and each new batch acts like the next sentence in a discussion. However, for the network to be trained properly, I need to reset the hidden state of the LSTM between some batches.

I'm using a variable to store the hidden state of the LSTM, for performance:

    with tf.variable_scope('Hidden_state'):
        hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size],
                                       tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
        # Arrange it to a tuple of LSTMStateTuple as needed
        l = tf.unstack(hidden_state, axis=0)
        rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
                                for idx in range(self.num_layers)])

    # Build the RNN
    with tf.name_scope('LSTM'):
        rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths,
                                          initial_state=rnn_tuple_state, time_major=True)

Now I'm confused about how to reset the hidden state. I've tried two solutions, but neither works:

First solution

Reset the "hidden_state" variable with:

rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))

It doesn't work, and I think that's because the unstack and tuple construction are not "re-played" in the graph after running the rnn_state_zero_op operation.

Second solution

Following the question "LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow", I tried to reset the cell state with:

rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)

It doesn't seem to work either.

Question

I have another solution in mind, but it's a guess at best: I'm currently discarding the state returned by tf.nn.dynamic_rnn. I've thought of keeping it, but it is a tuple and I can't find a way to build an op that resets the tuple.

At this point I have to admit that I don't quite understand the internal workings of TensorFlow, or whether what I'm trying to do is even possible. Is there a proper way to do it?

Thanks!

Answer 1:

Thanks to this answer to another question, I was able to find a way to have complete control over whether (and when) the internal state of the RNN is reset to 0.

First, you need to define some variables to store the state of the RNN, so that you have control over it:

with tf.variable_scope('Hidden_state'):
    state_variables = []
    for state_c, state_h in cell.zero_state(self.batch_size, tf.float32):
        state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    rnn_tuple_state = tuple(state_variables)

Note that this version directly defines the variables used by the LSTM. This is much better than the version in my question because you don't have to unstack and build the tuple, which adds ops to the graph that you cannot run explicitly.

Second, build the RNN and retrieve the final state:

# Build the RNN
with tf.name_scope('LSTM'):
    rnn_output, new_states = tf.nn.dynamic_rnn(cell, rnn_inputs,
                                               sequence_length=input_seq_lengths,
                                               initial_state=rnn_tuple_state,
                                               time_major=True)

So now you have the new internal state of the RNN. You can define two ops to manage it.

The first one updates the variables for the next batch, so that in the next batch the "initial_state" of the RNN is fed with the final state of the previous batch:

# Define an op to keep the hidden state between batches
update_ops = []
for state_variable, new_state in zip(rnn_tuple_state, new_states):
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(new_state[0]),
                       state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_keep_state_op = tf.tuple(update_ops)

You should add this op to your session.run call any time you want to run a batch and keep the internal state.

Beware: if you run batch 1 with this op, then batch 2 will start with the final state of batch 1; but if you don't run it again with batch 2, then batch 3 will also start with the final state of batch 1. My advice is to run this op every time you run the RNN.
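
The pitfall above can be reproduced without TensorFlow. The following is a toy pure-Python model, not the actual graph: the class `ToyStatefulRNN`, its `state_variable` attribute, and the summing "recurrence" are all made up for illustration, and `keep_state=True` merely stands in for fetching rnn_keep_state_op in the same session.run call.

```python
class ToyStatefulRNN:
    """Pure-Python stand-in for the graph above (hypothetical, not TensorFlow)."""

    def __init__(self):
        # Plays the role of the non-trainable state variables.
        self.state_variable = 0.0

    def run(self, batch, keep_state):
        """Process one batch and return (initial_state_used, final_state)."""
        initial_state = self.state_variable
        # Toy recurrence: the final state is the initial state plus the batch sum.
        final_state = initial_state + sum(batch)
        if keep_state:
            # Stand-in for running rnn_keep_state_op together with the batch.
            self.state_variable = final_state
        return initial_state, final_state

    def reset(self):
        # Stand-in for running rnn_state_zero_op.
        self.state_variable = 0.0


rnn = ToyStatefulRNN()
start1, end1 = rnn.run([1.0, 2.0], keep_state=True)   # state is stored
start2, end2 = rnn.run([10.0], keep_state=False)      # state is NOT stored
start3, end3 = rnn.run([100.0], keep_state=True)

# Batch 3 starts from the final state of batch 1, not batch 2.
print(start3 == end1)  # True
```

Because the second call does not store its final state, the third call silently skips the contribution of batch 2, which is why running the keep op with every batch is the safer default.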

The second op will be used to reset the internal state of the RNN to zeros:

# Define an op to reset the hidden state to zeros
update_ops = []
for state_variable in rnn_tuple_state:
    # Reset the state variables on this layer to zeros
    update_ops.extend([state_variable[0].assign(tf.zeros_like(state_variable[0])),
                       state_variable[1].assign(tf.zeros_like(state_variable[1]))])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_state_zero_op = tf.tuple(update_ops)

You can call this op whenever you want to reset the internal state.
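
Put together, a training loop might look like the sketch below. It assumes each file is an independent conversation whose batches are consecutive. The names `train_on_files`, `train_op` and `make_feed_dict` are hypothetical and do not appear in the code above; `sess` is expected to be a tf.Session, of which only the run(fetches, feed_dict=...) method is used.

```python
def train_on_files(sess, files, train_op, rnn_keep_state_op,
                   rnn_state_zero_op, make_feed_dict):
    """Run every batch of every file, resetting the RNN state between files.

    `files` is an iterable of iterables of batches. `make_feed_dict` is a
    hypothetical helper that maps one batch to a feed_dict for the graph's
    placeholders.
    """
    for batches in files:
        # A new file is an unrelated conversation: start from a zero state.
        sess.run(rnn_state_zero_op)
        for batch in batches:
            # Fetch the keep op together with the training op so that the
            # next batch starts from this batch's final state.
            sess.run([train_op, rnn_keep_state_op],
                     feed_dict=make_feed_dict(batch))
```

Fetching rnn_keep_state_op in the same run call as the training op avoids a second forward pass just to store the state.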



Answer 2:

A simplified version of AMairesse's post, for a single LSTM layer:

zero_state = tf.zeros(shape=[1, units[-1]])
self.c_state = tf.Variable(zero_state, trainable=False)
self.h_state = tf.Variable(zero_state, trainable=False)
self.init_encoder = tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state)

self.output_encoder, self.state_encoder = tf.nn.dynamic_rnn(cell_encoder, layer, initial_state=self.init_encoder)

# Ops that save the final states into the state variables for the next run
self.update_ops += [self.c_state.assign(self.state_encoder.c, use_locking=True)]
self.update_ops += [self.h_state.assign(self.state_encoder.h, use_locking=True)]

Alternatively, you can replace init_encoder with a conditional that resets the states when step == 0 (you need to feed self.step_tf through feed_dict in session.run()):

self.step_tf = tf.placeholder_with_default(tf.constant(-1, dtype=tf.int64), shape=[], name="step")

self.init_encoder = tf.cond(tf.equal(self.step_tf, 0),
  true_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(zero_state, zero_state),
  false_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state))
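
Feeding the step counter could then look like the sketch below. The function `run_sequence` and the `inputs_ph` placeholder are hypothetical names, and `fetches` is assumed to include the update ops defined earlier so that the state is actually saved between steps. `sess` is expected to be a tf.Session.

```python
def run_sequence(sess, batches, step_tf, inputs_ph, fetches):
    """Run one sequence of batches, feeding the step index on every call.

    Step 0 makes the tf.cond above pick the zero state; any later step
    picks the state stored in the variables.
    """
    results = []
    for step, batch in enumerate(batches):
        feed_dict = {step_tf: step, inputs_ph: batch}
        results.append(sess.run(fetches, feed_dict=feed_dict))
    return results
```

Since the placeholder defaults to -1, a run call that does not feed the step at all keeps the stored state.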