I want to save the final state of my LSTM so that it's included when I restore the model and can be used for prediction. As explained below, the Saver only has knowledge of the final state when I use tf.assign. However, this throws an error (also explained below).
During training I always feed the final LSTM state back into the network, as explained in this post. Here are the important parts of the code:
When building the graph:
self.init_state = tf.placeholder(tf.float32, [
    self.n_layers, 2, self.batch_size, self.n_hidden
])

state_per_layer_list = tf.unstack(self.init_state, axis=0)

rnn_tuple_state = tuple([
    tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0],
                                  state_per_layer_list[idx][1])
    for idx in range(self.n_layers)
])

outputs, self.final_state = tf.nn.dynamic_rnn(
    cell, inputs=self.inputs, initial_state=rnn_tuple_state)
And during training:
_current_state = np.zeros((self.n_layers, 2, self.batch_size,
                           self.n_hidden))

_train_step, _current_state, _loss, _acc, summary = self.sess.run(
    [
        self.train_step, self.final_state, self.loss,
        self.accuracy, self.merged  # loss/accuracy op names assumed
    ],
    feed_dict={self.inputs: _inputs,
               self.labels: _labels,
               self.init_state: _current_state})
When I later restore my model from a checkpoint, the final state is not restored as well. As outlined in this post, the problem is that the Saver has no knowledge of the new state. The post also suggests a solution based on tf.assign. Regrettably, I cannot use the suggested

assign_op = tf.assign(self.init_state, _current_state)
self.sess.run(assign_op)

because self.init_state is not a Variable but a placeholder. I get the error

AttributeError: 'Tensor' object has no attribute 'assign'
I have tried to solve this problem for several hours now but I can't get it to work.
Any help is appreciated!
EDIT:
I have changed self.init_state to
self.init_state = tf.get_variable('saved_state', shape=[
    self.n_layers, 2, self.batch_size, self.n_hidden
])

state_per_layer_list = tf.unstack(self.init_state, axis=0)

rnn_tuple_state = tuple([
    tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0],
                                  state_per_layer_list[idx][1])
    for idx in range(self.n_layers)
])

outputs, self.final_state = tf.nn.dynamic_rnn(
    cell, inputs=self.inputs, initial_state=rnn_tuple_state)
And during training I don't feed a value for self.init_state:
_train_step, _current_state, _loss, _acc, summary = self.sess.run(
    [
        self.train_step, self.final_state, self.loss,
        self.accuracy, self.merged  # loss/accuracy op names assumed
    ],
    feed_dict={self.inputs: _inputs,
               self.labels: _labels})
However, I still can't run the assignment op. Now I get

TypeError: Expected float32 passed to parameter 'value' of op 'Assign', got (LSTMStateTuple(c=array([[ 0.07291573, -0.06366599, -0.23425588, ..., 0.05307654,
In order to save the final state, you can create a separate TF variable, then, before saving the graph, run an assign op to assign your latest state to that variable, and then save the graph. The only thing you need to keep in mind is to declare that variable BEFORE you declare the Saver; otherwise it won't be included in the graph. This is discussed in great detail here, including the working code: TF LSTM: Save State from training session for prediction session later
*** UPDATE: Answers to follow-up questions:
It looks like you are using BasicLSTMCell, with state_is_tuple=True. The prior discussion that I referred you to used GRUCell with state_is_tuple=False. The details between the two are somewhat different, but the overall approach could be similar, so hopefully this should work for you:

During training, you first feed zeros as initial_state into dynamic_rnn and then keep re-feeding its own output back as initial_state. So, the LAST output state of our dynamic_rnn call is what you want to save for later. Since it results from a sess.run() call, it's essentially a numpy array (not a tensor and not a placeholder). So the question amounts to "how do I save a numpy array as a TensorFlow variable along with the rest of the variables in the graph?" That's why you assign the final state to a variable whose only purpose is exactly that.

So, the code is something like this:
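(A minimal sketch, not a drop-in: cell, inputs, labels, train_step, sess, and the batches iterable come from your own setup, cell is assumed to be a MultiRNNCell of BasicLSTMCells, the sizes at the top are placeholders, and './model_checkpoint' is just an example path.)

import numpy as np
import tensorflow as tf

# Sizes are placeholders -- use your real values.
n_layers, batch_size, n_hidden = 2, 32, 128

# Declare the state variable BEFORE the Saver so it gets checkpointed.
# It is not trainable; its only job is to carry the state into the
# checkpoint.
saved_state = tf.get_variable(
    'saved_state',
    shape=[n_layers, 2, batch_size, n_hidden],
    initializer=tf.zeros_initializer(),
    trainable=False)

# Keep the placeholder for feeding the running state, as in your
# original (pre-EDIT) code.
init_state = tf.placeholder(
    tf.float32, [n_layers, 2, batch_size, n_hidden], name='init_state')

state_per_layer_list = tf.unstack(init_state, axis=0)
rnn_tuple_state = tuple(
    tf.contrib.rnn.LSTMStateTuple(s[0], s[1])
    for s in state_per_layer_list)

outputs, final_state = tf.nn.dynamic_rnn(
    cell, inputs=inputs, initial_state=rnn_tuple_state)

# Assign whatever is fed into init_state to the checkpointed variable.
assign_op = tf.assign(saved_state, init_state)

saver = tf.train.Saver()  # declared AFTER saved_state

# Training: start from zeros and keep re-feeding the output state.
_current_state = np.zeros((n_layers, 2, batch_size, n_hidden))
for _inputs, _labels in batches:
    _, _current_state = sess.run(
        [train_step, final_state],
        feed_dict={inputs: _inputs, labels: _labels,
                   init_state: _current_state})

# final_state comes back as a tuple of LSTMStateTuples; pack it into
# a plain float32 array first -- passing the tuple straight into
# tf.assign is exactly what produced your TypeError.
_current_state = np.array([[st.c, st.h] for st in _current_state])

# Push the last state into the variable, then save.
sess.run(assign_op, feed_dict={init_state: _current_state})
saver.save(sess, './model_checkpoint')

Note that the assignment goes through the init_state placeholder, so you never call .assign() on the placeholder itself (which is what raised your AttributeError), and the numpy conversion takes care of the TypeError from your EDIT.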
As mentioned, this is a modified version of the approach that works well for me with GRUCell, where state_is_tuple = False. I adapted it to try BasicLSTMCell with state_is_tuple=True. It works, but not as accurately as the original approach. I don't know yet whether that's just because GRU works better than LSTM for me, or for some other reason. See if this works for you...

Also keep in mind that, as you can see with the recovery and prediction code below, your predictions will likely be based on a different batch size than your training loop (I guess a batch of 1?). So you have to think through how to handle your recovered state -- just take the last batch? Or something else? This code takes only the last slice of the saved state along the batch dimension (i.e. the most recent observations from training) because that's what was relevant for me...
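For the recovery and prediction side, the idea is roughly this (again a sketch: the tensor names 'inputs:0' and 'prediction:0', plus seq_len and n_features, are assumptions about your graph, and your inputs/init_state placeholders need a batch dimension of None for a prediction batch size of 1 to work):

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    # Restore the graph definition and the variables, including
    # saved_state (checkpoint path as in the sketch above).
    saver = tf.train.import_meta_graph('./model_checkpoint.meta')
    saver.restore(sess, './model_checkpoint')
    graph = tf.get_default_graph()

    # Read the checkpointed state back as a numpy array of shape
    # (n_layers, 2, batch_size, n_hidden).
    _saved_state = sess.run(graph.get_tensor_by_name('saved_state:0'))

    # Training ran with batch_size rows but prediction uses a batch
    # of 1, so keep only the last training example's state (the most
    # recent observations); other choices are possible.
    _pred_state = _saved_state[:, :, -1:, :]

    # Dummy input batch of 1 just to make the sketch concrete;
    # replace with your real data (seq_len/n_features assumed).
    _pred_inputs = np.zeros((1, seq_len, n_features), dtype=np.float32)

    _prediction = sess.run(
        graph.get_tensor_by_name('prediction:0'),
        feed_dict={
            graph.get_tensor_by_name('inputs:0'): _pred_inputs,
            graph.get_tensor_by_name('init_state:0'): _pred_state,
        })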