I have a 3-D tensor of shape [batch, None, dim], where the second dimension (the timesteps) is unknown.
I use dynamic_rnn to process such input, as in the following snippet:
import numpy as np
import tensorflow as tf
batch = 2
dim = 3
hidden = 4
lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, _ = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)
Running this snippet with some actual numbers gives reasonable results:
inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                      [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                     dtype=np.float32)
lengths_ = np.asarray([3, 1], dtype=np.int32)
with tf.Session() as sess:
    # The GRU cell creates variables, so they must be initialized first.
    sess.run(tf.global_variables_initializer())
    output_ = sess.run(output, {inputs: inputs_, lengths: lengths_})
    print(output_)
And the output is:
[[[ 0. 0. 0. 0. ]
[ 0.02188676 -0.01294564 0.05340237 -0.47148666]
[ 0.0343586 -0.02243731 0.0870839 -0.89869428]
[ 0. 0. 0. 0. ]]
[[ 0.00284752 -0.00315077 0.00108094 -0.99883419]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]
Is there a way to get a 3-D tensor of shape [batch, 1, hidden] with the last relevant output of the dynamic RNN? Thanks!
This is what gather_nd is for!
def extract_axis_1(data, ind):
    """
    Get specified elements along axis 1 of a tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """
    # Pair every batch index with the timestep index to pick for that sequence.
    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)
    return res
In your case:
output = extract_axis_1(output, lengths - 1)
Now output is a tensor of shape [batch_size, num_cells]. If you need [batch_size, 1, num_cells], apply tf.expand_dims(output, 1) to it.
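To make the indexing concrete, here is a minimal NumPy sketch of the same (batch index, time index) pairing that extract_axis_1 builds for gather_nd. It is an analogue for illustration, not TensorFlow code; the array is the output printed in the question, and the names output_, lengths_, indices and last are only used for this example.

```python
import numpy as np

# RNN output copied from the question: shape [batch=2, time=4, hidden=4].
output_ = np.asarray([
    [[0.0, 0.0, 0.0, 0.0],
     [0.02188676, -0.01294564, 0.05340237, -0.47148666],
     [0.0343586, -0.02243731, 0.0870839, -0.89869428],
     [0.0, 0.0, 0.0, 0.0]],
    [[0.00284752, -0.00315077, 0.00108094, -0.99883419],
     [0.0, 0.0, 0.0, 0.0],
     [0.0, 0.0, 0.0, 0.0],
     [0.0, 0.0, 0.0, 0.0]],
])
lengths_ = np.asarray([3, 1], dtype=np.int32)

# Same (batch index, time index) pairs that extract_axis_1 stacks together.
batch_range = np.arange(output_.shape[0])
indices = np.stack([batch_range, lengths_ - 1], axis=1)  # [[0, 2], [1, 0]]

# Integer-array indexing picks output_[0, 2] and output_[1, 0].
last = output_[indices[:, 0], indices[:, 1]]
print(indices)
print(last.shape)  # (2, 4)
print(last)
```

Each row of indices selects one vector of size hidden, which is why the result has shape [batch, hidden] rather than [batch, 1, hidden].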
From the following two sources,
outputs, last_states = tf.nn.dynamic_rnn(
Or https://github.com/ageron/handson-ml/blob/master/14_recurrent_neural_networks.ipynb,
It is clear that last_states can be taken directly from the SECOND output of the dynamic_rnn call. It gives you the final states across all layers (for an LSTM, each one is an LSTMStateTuple), while outputs contains the states at every timestep of the last layer.
Okay, so it looks like there actually is an easier solution. As @Shao Tang and @Rahul mentioned, the preferred way to do this is to access the final cell state. Here's why:
- If you look at the GRUCell source code (below), you'll see that the "state" the cell maintains is the hidden state itself, i.e. the same tensor it returns as its output. So, when dynamic_rnn returns the final state, it is actually returning the last relevant hidden output that you are interested in. To prove this, I just tweaked your setup and got the following results:
GRUCell Call (rnn_cell_impl.py):
def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
    if self._gate_linear is None:
        bias_ones = self._bias_initializer
        if self._bias_initializer is None:
            bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype)
        with vs.variable_scope("gates"):  # Reset gate and update gate.
            self._gate_linear = _Linear(
                [inputs, state],
                2 * self._num_units,
                ...)  # remaining arguments omitted in this excerpt
    value = math_ops.sigmoid(self._gate_linear([inputs, state]))
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
    r_state = r * state
    if self._candidate_linear is None:
        with vs.variable_scope("candidate"):
            self._candidate_linear = _Linear(
                [inputs, r_state],
                ...)  # remaining arguments omitted in this excerpt
    c = self._activation(self._candidate_linear([inputs, r_state]))
    new_h = u * state + (1 - u) * c
    # The cell output and the new state are the same tensor.
    return new_h, new_h
import numpy as np
import tensorflow as tf
batch = 2
dim = 3
hidden = 4
lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, state = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)
inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                      [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                     dtype=np.float32)
lengths_ = np.asarray([3, 1], dtype=np.int32)
with tf.Session() as sess:
    # The GRU cell creates variables, so they must be initialized first.
    sess.run(tf.global_variables_initializer())
    output_, state_ = sess.run([output, state], {inputs: inputs_, lengths: lengths_})
    print(output_)
    print(state_)
[[[ 0. 0. 0. 0. ]
[-0.24305521 -0.15512943 0.06614969 0.16873555]
[-0.62767833 -0.30741733 0.14819752 0.44313088]
[ 0. 0. 0. 0. ]]
[[-0.99152333 -0.1006391 0.28767768 0.76360202]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]
[[-0.62767833 -0.30741733 0.14819752 0.44313088]
[-0.99152333 -0.1006391 0.28767768 0.76360202]]
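The reasoning above can also be checked without TensorFlow. The following is a rough NumPy sketch, assuming a GRU step shaped like the quoted call method and a loop that zeroes outputs and copies the state through once a sequence has ended (which is how dynamic_rnn treats sequence_length). The weights are random placeholders and the helper names gru_step and masked_rnn are made up for this example, so the numbers will not match the output above; only the relationship between the final state and the last relevant output is the point.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, w_gates, b_gates, w_cand, b_cand):
    """One GRU step mirroring the structure of the quoted GRUCell.call."""
    gates = sigmoid(np.concatenate([x, h], axis=1) @ w_gates + b_gates)
    r, u = np.split(gates, 2, axis=1)           # reset gate, update gate
    c = np.tanh(np.concatenate([x, r * h], axis=1) @ w_cand + b_cand)
    new_h = u * h + (1.0 - u) * c
    return new_h, new_h                         # output and state are identical

def masked_rnn(inputs, lengths, hidden, seed=0):
    """Toy stand-in for dynamic_rnn with sequence lengths (illustrative only)."""
    rng = np.random.RandomState(seed)
    batch, steps, dim = inputs.shape
    # Hypothetical random weights; gate bias of ones as in the quoted source.
    w_gates = rng.uniform(-0.5, 0.5, size=(dim + hidden, 2 * hidden))
    b_gates = np.ones(2 * hidden)
    w_cand = rng.uniform(-0.5, 0.5, size=(dim + hidden, hidden))
    b_cand = np.zeros(hidden)
    state = np.zeros((batch, hidden))
    outputs = np.zeros((batch, steps, hidden))
    for t in range(steps):
        out, new_state = gru_step(inputs[:, t, :], state,
                                  w_gates, b_gates, w_cand, b_cand)
        active = (t < lengths)[:, np.newaxis]   # True while the sequence is running
        outputs[:, t, :] = np.where(active, out, 0.0)    # zero past the end
        state = np.where(active, new_state, state)       # copy state through
    return outputs, state

inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                      [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                     dtype=np.float32)
lengths_ = np.asarray([3, 1], dtype=np.int32)

outputs_, state_ = masked_rnn(inputs_, lengths_, hidden=4)
last_ = outputs_[np.arange(len(lengths_)), lengths_ - 1]
print(np.allclose(state_, last_))  # True
```

Because the state stops updating at the same step where the outputs start being zeroed, the final state is exactly the output at index lengths - 1 for every sequence.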
For other readers who are working with the LSTMCell (another popular option), things work a little differently. The LSTMCell maintains its state in a different way: the state is either a tuple or a concatenated version of the actual cell state and the hidden state. So, to access the final hidden state, you could set state_is_tuple to True during cell initialization, and the final state will be a tuple: (final cell state, final hidden state). So, in this case,
_, (_, h) = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)
will give you the final hidden state.
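To illustrate why the tuple has to be unpacked, here is a simplified NumPy sketch of an LSTM step whose state is a (c, h) pair. It uses the commonly cited LSTM equations; the gate ordering, the forget bias value, the random weights and the name lstm_step are assumptions for this example and are not copied from TensorFlow's source.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, state, kernel, bias, forget_bias=1.0):
    """One simplified LSTM step; state is a (cell state, hidden state) tuple."""
    c, h = state
    gates = np.concatenate([x, h], axis=1) @ kernel + bias
    i, j, f, o = np.split(gates, 4, axis=1)  # input, candidate, forget, output
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)             # output equals the h half only

rng = np.random.RandomState(0)
batch, dim, hidden = 2, 3, 4
kernel = rng.uniform(-0.5, 0.5, size=(dim + hidden, 4 * hidden))
bias = np.zeros(4 * hidden)

state = (np.zeros((batch, hidden)), np.zeros((batch, hidden)))
x = np.ones((batch, dim))
for _ in range(3):
    out, state = lstm_step(x, state, kernel, bias)

final_c, final_h = state
print(np.array_equal(out, final_h))  # True
print(final_c.shape, final_h.shape)
```

Unlike the GRU, only the second element of the state matches the cell output, so taking the whole final state (or its first element) would give the wrong tensor.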
See also: "c_state and m_state in Tensorflow LSTM".
Actually, the solution was not that hard. I implemented the following code:
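slices = []
for index, l in enumerate(tf.unstack(lengths)):
    # Take the [1, 1, hidden] slice at the last valid timestep of sequence `index`.
    last_slice = tf.slice(rnn_out, begin=[index, l - 1, 0], size=[1, 1, hidden])
    slices.append(last_slice)
last = tf.concat(slices, axis=0)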
So, the full snippet would be the following:
import numpy as np
import tensorflow as tf
batch = 2
dim = 3
hidden = 4
lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, _ = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)
inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                      [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                     dtype=np.float32)
lengths_ = np.asarray([3, 1], dtype=np.int32)
slices = []
for index, l in enumerate(tf.unstack(lengths)):
    # Slice over the full hidden size so no output unit is dropped.
    last_slice = tf.slice(output, begin=[index, l - 1, 0], size=[1, 1, hidden])
    slices.append(last_slice)
last = tf.concat(slices, axis=0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs = sess.run([output, last], {inputs: inputs_, lengths: lengths_})
    print('RNN output:')
    print(outputs[0])
    print('last relevant output:')
    print(outputs[1])
And the output:
RNN output:
[[[ 0. 0. 0. 0. ]
[-0.06667092 -0.09284072 0.01098599 -0.03676109]
[-0.09101103 -0.19828682 0.03546784 -0.08721405]
[ 0. 0. 0. 0. ]]
[[-0.00025157 -0.05704876 0.05527233 -0.03741353]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]
last relevant output:
[[[-0.09101103 -0.19828682 0.03546784 -0.08721405]]
[[-0.00025157 -0.05704876 0.05527233 -0.03741353]]]
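For comparison, here is a small NumPy sketch of the same slice-and-concatenate idea, run on the RNN output printed above. It is an analogue for illustration only (plain arrays instead of tensors, and the names last_loop and last_vec are made up here), and it also shows a loop-free variant that keeps the [batch, 1, hidden] shape.

```python
import numpy as np

# RNN output copied from the printout above: shape [batch=2, time=4, hidden=4].
output_ = np.asarray([
    [[0.0, 0.0, 0.0, 0.0],
     [-0.06667092, -0.09284072, 0.01098599, -0.03676109],
     [-0.09101103, -0.19828682, 0.03546784, -0.08721405],
     [0.0, 0.0, 0.0, 0.0]],
    [[-0.00025157, -0.05704876, 0.05527233, -0.03741353],
     [0.0, 0.0, 0.0, 0.0],
     [0.0, 0.0, 0.0, 0.0],
     [0.0, 0.0, 0.0, 0.0]],
])
lengths_ = np.asarray([3, 1], dtype=np.int32)

# Loop version: one [1, 1, hidden] slice per sequence, joined along the batch axis.
slices = [output_[i:i + 1, l - 1:l, :] for i, l in enumerate(lengths_)]
last_loop = np.concatenate(slices, axis=0)

# Loop-free version: index (batch, time) pairs, then restore the time axis.
last_vec = output_[np.arange(len(lengths_)), lengths_ - 1][:, np.newaxis, :]

print(last_loop.shape)                      # (2, 1, 4)
print(np.array_equal(last_loop, last_vec))  # True
```

The loop version needs the batch size to be known when the graph is built, which is the case here because lengths has a static shape of [batch]; the indexed version does not have that restriction.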