Tensorflow save final state of LSTM in dynamic_rnn

Posted 2019-05-12 09:11

I want to save the final state of my LSTM such that it's included when I restore the model and can be used for prediction. As explained below, the Saver only has knowledge of the final state when I use tf.assign. However, this throws an error (also explained below).

During training I always feed the final LSTM state back into the network, as explained in this post. Here are the important parts of the code:

When building the graph:

            self.init_state = tf.placeholder(tf.float32, [
                self.n_layers, 2, self.batch_size, self.n_hidden
            ])

            state_per_layer_list = tf.unstack(self.init_state, axis=0)

            rnn_tuple_state = tuple([
                tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0],
                                              state_per_layer_list[idx][1])

                for idx in range(self.n_layers)
            ])

            outputs, self.final_state = tf.nn.dynamic_rnn(
                cell, inputs=self.inputs, initial_state=rnn_tuple_state)

And during training:

            _current_state = np.zeros((self.n_layers, 2, self.batch_size,
                                       self.n_hidden))

            _train_step, _current_state, _loss, _acc, summary = self.sess.run(
                [
                    self.train_step, self.final_state,
                    self.merged
                ],
                feed_dict={self.inputs: _inputs,
                           self.labels: _labels,
                           self.init_state: _current_state})

When I later restore my model from a checkpoint, the final state is not restored as well. As outlined in this post, the problem is that the Saver has no knowledge of the new state. The post also suggests a solution based on tf.assign. Regrettably, I cannot use the suggested

            assign_op = tf.assign(self.init_state, _current_state)
            self.sess.run(assign_op)

because self.init_state is not a Variable but a placeholder. I get the error

AttributeError: 'Tensor' object has no attribute 'assign'
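
For illustration, here is a tiny standalone sketch (the names are made up, not from my model) of the distinction: tf.assign accepts a tf.Variable, while a placeholder is just a Tensor, which is exactly the AttributeError above:

    import numpy as np
    import tensorflow as tf  # TF 1.x

    var = tf.get_variable('some_state', shape=[2, 3],
                          initializer=tf.zeros_initializer())
    ph = tf.placeholder(tf.float32, [2, 3])

    ok = tf.assign(var, np.ones((2, 3), dtype=np.float32))  # fine: var is a Variable
    # tf.assign(ph, np.ones((2, 3), dtype=np.float32))      # AttributeError: 'Tensor' object has no attribute 'assign'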

I have tried to solve this problem for several hours now but I can't get it to work.

Any help is appreciated!

EDIT:

I have changed self.init_state to

            self.init_state = tf.get_variable(
                'saved_state',
                shape=[self.n_layers, 2, self.batch_size, self.n_hidden])

            state_per_layer_list = tf.unstack(self.init_state, axis=0)

            rnn_tuple_state = tuple([
                tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0],
                                              state_per_layer_list[idx][1])

                for idx in range(self.n_layers)
            ])

            outputs, self.final_state = tf.nn.dynamic_rnn(
                cell, inputs=self.inputs, initial_state=rnn_tuple_state)

And during training I don't feed a value for self.init_state:

            _train_step, _current_state, _loss, _acc, summary = self.sess.run(
                [
                    self.train_step, self.final_state,
                    self.merged
                ],
                feed_dict={self.inputs: _inputs,
                           self.labels: _labels})

However, I still can't run the assignment op. Now I get

TypeError: Expected float32 passed to parameter 'value' of op 'Assign', got (LSTMStateTuple(c=array([[ 0.07291573, -0.06366599, -0.23425588, ..., 0.05307654,
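
For reference, here is a small self-contained sketch (toy sizes, tf.nn.rnn_cell instead of tf.contrib.rnn, nothing from my actual model) that shows why: with state_is_tuple=True, the fetched final state is a tuple of LSTMStateTuples rather than a single float array, which is exactly what the Assign op complains about. Packing it with np.asarray gives the [n_layers, 2, batch_size, n_hidden] shape the variable expects:

    import numpy as np
    import tensorflow as tf  # TF 1.x

    # toy sizes, only to inspect the structure of the fetched state
    n_layers, batch_size, n_hidden = 2, 4, 8

    cell = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(n_hidden, state_is_tuple=True)
         for _ in range(n_layers)])
    inputs = tf.placeholder(tf.float32, [batch_size, 5, 3])
    _, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        state = sess.run(final_state,
                         {inputs: np.zeros((batch_size, 5, 3), np.float32)})
        print(type(state), type(state[0]))   # tuple of LSTMStateTuple, not one array
        print(np.asarray(state).shape)       # (n_layers, 2, batch_size, n_hidden)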

1 Answer

时光不老，我们不散 · answered 2019-05-12 09:37

In order to save the final state, you can create a separate TF variable, then before saving the graph, run an assign op to assign your latest state to that variable, and then save the graph. The only thing you need to keep in mind is to declare that variable BEFORE you declare the Saver; otherwise it won't be included in the graph.

This is discussed in great detail here, including working code: TF LSTM: Save State from training session for prediction session later

*** UPDATE: answers to followup questions:

It looks like you are using BasicLSTMCell with state_is_tuple=True. The prior discussion I referred you to used GRUCell with state_is_tuple=False. The details differ somewhat between the two, but the overall approach is similar, so hopefully this will work for you:

During training, you first feed zeros into dynamic_rnn as the initial_state and then keep feeding its own output state back in as the initial_state of the next step. So the LAST output state of your dynamic_rnn call is what you want to save for later. Since it comes out of a sess.run() call, it is essentially a numpy array (not a tensor and not a placeholder). So the question amounts to "how do I save a numpy array as a Tensorflow variable along with the rest of the variables in the graph?" That's why you assign the final state to a variable whose only purpose is exactly that.

So, code is something like this:

    # GRAPH DEFINITIONS:
    state_in = tf.placeholder(tf.float32, [LAYERS, 2, None, CELL_SIZE], name='state_in')
    l = tf.unstack(state_in, axis=0)
    state_tup = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(LAYERS)])
    # multicell = your BasicLSTMCell / MultiRNNCell definition
    output, state_out = tf.nn.dynamic_rnn(multicell, X, dtype=tf.float32, initial_state=state_tup)

    savedState = tf.get_variable('savedState', shape=[LAYERS, 2, BATCHSIZE, CELL_SIZE])
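    # note: savedState is declared BEFORE the Saver below, so the Saver
    # picks it up and it ends up in the checkpoint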
    saver = tf.train.Saver(max_to_keep=1)

    in_state = np.zeros((LAYERS, 2, BATCHSIZE, CELL_SIZE))

    # TRAINING LOOP:
    feed_dict = {X: x, Y_: y_, batchsize: BATCHSIZE, state_in:in_state}
    _, out_state = sess.run([training_step, state_out], feed_dict=feed_dict)
    in_state = out_state

    # ONCE TRAINING IS OVER:
    # with state_is_tuple=True, out_state is a tuple of LSTMStateTuples,
    # so pack it into a plain [LAYERS, 2, BATCHSIZE, CELL_SIZE] array first
    assignOp = tf.assign(savedState, np.asarray(out_state))
    sess.run(assignOp)
    saver.save(sess, pathModel + '/my_model.ckpt')

    # RECOVERING IN A DIFFERENT PROGRAM:

    gInit = tf.global_variables_initializer().run()
    lInit = tf.local_variables_initializer().run()
    new_saver = tf.train.import_meta_graph(pathModel + '/my_model.ckpt.meta')
    new_saver.restore(sess, pathModel + '/my_model.ckpt')
    # retrieve the saved state and keep only its LAST batch entry (latest observations)
    savedState = sess.run('savedState:0') # this is FULL state from training
    state = savedState[:,:,-1,:]  # -1 gets only the LAST batch of the state (latest seen observations)
    state = np.reshape(state, [state.shape[0], 2, -1, state.shape[2]])  # [LAYERS, 2, 1 (BATCH), CELL_SIZE]
    #x = .... (YOUR INPUTS)
    feed_dict = {'X:0': x, 'state_in:0':state}
    #PREDICTION LOOP:
    preds, state = sess.run(['preds:0', 'state_out:0'], feed_dict = feed_dict)
    # so now state will be re-fed into feed_dict with the next loop iteration

As mentioned, this is a modified version of an approach that works well for me with GRUCell, where state_is_tuple=False. I adapted it to try BasicLSTMCell with state_is_tuple=True. It works, but not as accurately as the original approach. I don't know yet whether that's just because GRU works better than LSTM for me, or for some other reason. See if this works for you...

Also keep in mind that, as you can see in the recovery and prediction code, your predictions will likely use a different batch size than your training loop (I guess a batch of 1?). So you have to think through how to handle your recovered state -- just take the last batch entry? Or something else? This code keeps only the last entry along the batch dimension of the saved state (i.e. the most recent observations from training) because that's what was relevant for me... A couple of options are sketched below.
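
As a rough illustration of those options (reusing names from the code above; PRED_BATCH is a made-up name for whatever batch size your prediction graph expects, and which option is right depends on your data):

    import numpy as np

    full_state = sess.run('savedState:0')      # [LAYERS, 2, BATCHSIZE, CELL_SIZE]

    # Option 1: keep only the most recent example and predict with batch size 1
    state = full_state[:, :, -1:, :]           # [LAYERS, 2, 1, CELL_SIZE]

    # Option 2: repeat that last example across a larger prediction batch
    PRED_BATCH = 16
    state = np.tile(full_state[:, :, -1:, :], (1, 1, PRED_BATCH, 1))

    # Option 3: discard the training context and start from a zero state
    state = np.zeros((LAYERS, 2, PRED_BATCH, CELL_SIZE), dtype=np.float32)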
