TensorFlow dynamic_rnn state

Posted 2020-02-26 12:41

Question:

My question is about the TensorFlow method tf.nn.dynamic_rnn. It returns the output of every time step and the final state.

I would like to know if the returned final state is the state of the cell at the maximum sequence length or if it is determined individually by the sequence_length argument.

An example for clarity: I have 3 sequences with lengths [10, 20, 30] and get back a final state of shape [3, 512] (assuming the cell's hidden state has size 512).

Are the three returned hidden states the cell states at time step 30, or do I get back the states at time steps [10, 20, 30]?

Answer 1:

tf.nn.dynamic_rnn returns two tensors: outputs and states.

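outputs holds the cell output at every time step for every sequence in the batch. If a sequence is shorter than the maximum length, as declared through sequence_length, the outputs for the time steps past its length are zero.

states holds the final cell state of each sequence, taken at that sequence's own length rather than at the maximum length. With BasicRNNCell, where the output equals the state, this is the same as the last valid (non-zeroed) output of each sequence. For the question's example, that means the states at time steps [10, 20, 30].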

Here's an example (TensorFlow 1.x API):

import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
])
seq_length_batch = np.array([2, 1])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print('outputs:')
  print(outputs_val)
  print('\nstates:')
  print(states_val)

This prints something like:

outputs:
[[[-0.85381496 -0.19517037  0.36011398 -0.18617202  0.39162001]
  [-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]]

 [[-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]
  [ 0.          0.          0.          0.          0.        ]]]  # because len=1

states:
[[-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]
 [-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]]

Note that states holds the same vectors that appear in outputs: for each batch instance, it is the output at that instance's last valid time step.
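
The behaviour described above can also be mimicked without TensorFlow. The following is only a rough sketch in plain Python, not the library's actual implementation: rnn_step, masked_rnn, and the weights W_X and W_H are hypothetical names and values made up for illustration, and the cell is a bias-free tanh RNN. It shows the two effects of sequence_length: past a sequence's length the output is zeroed, and the state is carried through unchanged.

```python
import math


def rnn_step(x, h, w_x, w_h):
    """One basic RNN step, tanh(x @ w_x + h @ w_h), written with plain lists."""
    n_units = len(h)
    return [
        math.tanh(
            sum(x[i] * w_x[i][j] for i in range(len(x)))
            + sum(h[k] * w_h[k][j] for k in range(n_units))
        )
        for j in range(n_units)
    ]


def masked_rnn(batch, seq_lengths, w_x, w_h):
    """Run the cell over a padded batch, honoring per-sequence lengths."""
    n_units = len(w_h)
    all_outputs, final_states = [], []
    for sequence, length in zip(batch, seq_lengths):
        h = [0.0] * n_units  # zero initial state
        outputs = []
        for t, x in enumerate(sequence):
            if t < length:
                h = rnn_step(x, h, w_x, w_h)  # valid step: update the state
                outputs.append(h)
            else:
                # Padded step: emit zeros and leave the state untouched.
                outputs.append([0.0] * n_units)
        all_outputs.append(outputs)
        final_states.append(h)
    return all_outputs, final_states


# Hypothetical fixed weights (3 inputs, 2 units), chosen only for illustration.
W_X = [[0.1, -0.2], [0.0, 0.3], [-0.1, 0.1]]
W_H = [[0.5, 0.0], [0.0, 0.5]]

X_batch = [
    [[0, 1, 2], [9, 8, 7]],  # instance 0, length 2
    [[3, 4, 5], [0, 0, 0]],  # instance 1, length 1 (second step is padding)
]
outputs, states = masked_rnn(X_batch, [2, 1], W_X, W_H)

print(states[0] == outputs[0][1])  # True: state taken at t = 1
print(states[1] == outputs[1][0])  # True: state taken at t = 0, not at the padded step
print(outputs[1][1])               # [0.0, 0.0]
```

With lengths [10, 20, 30] the same logic would return the states at time steps 10, 20, and 30 respectively, which is the case asked about in the question.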