My question is about the TensorFlow method tf.nn.dynamic_rnn. It returns the output of every time step and the final state.
I would like to know whether the returned final state is the state of the cell at the maximum sequence length, or whether it is determined individually for each sequence by the sequence_length argument.
To make this concrete: I have 3 sequences with lengths [10, 20, 30] and get back a final state of shape [3, 512] (if the hidden state of the cell has size 512).
Are the three returned hidden states all the state of the cell at time step 30, or am I getting back the states at time steps [10, 20, 30]?
tf.nn.dynamic_rnn returns two tensors: outputs and states.
outputs holds the cell output at every time step for every sequence in the batch. If a sequence is shorter than the maximum length and you pass its length via sequence_length, its outputs past that length are zero.
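states holds the last cell state of each sequence, taken at that sequence's own length rather than at the maximum time step (provided you pass sequence_length). If you're using BasicRNNCell, whose state equals its output, this is the same as the last non-zero output of each sequence. In your example you would therefore get the states at time steps [10, 20, 30], not three states from time step 30.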
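To make the masking behaviour described above concrete, here is a minimal NumPy sketch of a toy tanh RNN that freezes each sequence's state once its length is reached. It is an illustration only, not TensorFlow's actual implementation: the function name masked_rnn and the weights w_x, w_h, and b are hypothetical and randomly initialized.

```python
import numpy as np

def masked_rnn(x, seq_length, w_x, w_h, b):
    """Toy tanh RNN that mimics per-sequence length masking (illustrative only).

    x: [batch, n_steps, n_inputs], seq_length: [batch].
    Returns (outputs [batch, n_steps, n_neurons], final state [batch, n_neurons]).
    """
    batch, n_steps, _ = x.shape
    n_neurons = w_h.shape[0]
    state = np.zeros((batch, n_neurons))
    outputs = np.zeros((batch, n_steps, n_neurons))
    for t in range(n_steps):
        new_state = np.tanh(x[:, t] @ w_x + state @ w_h + b)
        # A sequence is still active at step t only if t < its length.
        active = (t < seq_length)[:, None]
        # Finished sequences keep their old state and emit zeros.
        state = np.where(active, new_state, state)
        outputs[:, t] = np.where(active, new_state, 0.0)
    return outputs, state

# Hypothetical random weights, used purely for demonstration.
rng = np.random.default_rng(0)
n_inputs, n_neurons = 3, 5
w_x = rng.normal(size=(n_inputs, n_neurons))
w_h = rng.normal(size=(n_neurons, n_neurons))
b = np.zeros(n_neurons)

x_batch = np.array([
    [[0, 1, 2], [9, 8, 7]],  # instance 0, length 2
    [[3, 4, 5], [0, 0, 0]],  # instance 1, length 1 (second step is padding)
], dtype=np.float64)
seq_length_batch = np.array([2, 1])

outputs_val, states_val = masked_rnn(x_batch, seq_length_batch, w_x, w_h, b)
```

In this sketch the final state of instance 1 is its output at t = 0, because the padded step at t = 1 never updates it.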
Here's an example:
# Uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session).
import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X,
                                    sequence_length=seq_length,
                                    dtype=tf.float32)

X_batch = np.array([
    # t = 0      t = 1
    [[0, 1, 2], [9, 8, 7]],  # instance 0
    [[3, 4, 5], [0, 0, 0]],  # instance 1 (second step is padding)
])
seq_length_batch = np.array([2, 1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs_val, states_val = sess.run(
        [outputs, states],
        feed_dict={X: X_batch, seq_length: seq_length_batch})

print('outputs:')
print(outputs_val)
print('\nstates:')
print(states_val)
This prints something like:
outputs:
[[[-0.85381496 -0.19517037  0.36011398 -0.18617202  0.39162001]
  [-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]]

 [[-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]
  [ 0.          0.          0.          0.          0.        ]]]  # because len=1

states:
[[-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]
 [-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]]
Note that states holds the same vectors that appear in outputs: each row is the last non-zero output of the corresponding batch instance.
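As a quick check, the same rows can be pulled out of outputs by indexing each instance at its own last valid time step. The sketch below does this in NumPy on the values printed above. The variable names batch_idx, last_outputs, and matches are just illustrative, and it assumes every sequence has length at least 1.

```python
import numpy as np

# Values copied from the printed example output above.
outputs_val = np.array([
    [[-0.85381496, -0.19517037, 0.36011398, -0.18617202, 0.39162001],
     [-0.99998015, -0.99461144, -0.82241321, 0.93778896, 0.90737367]],
    [[-0.99849552, -0.88643843, 0.20635395, 0.157896, 0.76042926],
     [0.0, 0.0, 0.0, 0.0, 0.0]],
])
states_val = np.array([
    [-0.99998015, -0.99461144, -0.82241321, 0.93778896, 0.90737367],
    [-0.99849552, -0.88643843, 0.20635395, 0.157896, 0.76042926],
])
seq_length_batch = np.array([2, 1])

# Pick, for each instance i, the output at time step seq_length[i] - 1.
# Assumes all lengths are >= 1; a length of 0 would wrongly index step -1.
batch_idx = np.arange(outputs_val.shape[0])
last_outputs = outputs_val[batch_idx, seq_length_batch - 1]

# With BasicRNNCell these rows should coincide with the returned states.
matches = np.allclose(last_outputs, states_val)
```

This equivalence relies on the cell's state being equal to its output, as it is for BasicRNNCell.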