TensorFlow raw_rnn: retrieve tensor of shape BATCH x EMBEDDING_DIM

Published 2019-08-17 00:45

Question:

I am implementing an encoder-decoder LSTM, where I have to do custom computation at each step of the encoder, so I am using raw_rnn. However, I am having trouble accessing a slice of the embedded inputs, which have shape Batch x Time steps x Embedding dimensionality, at time step time.

Here is my setup:

import tensorflow as tf
import numpy as np

batch_size, max_time, input_embedding_size = 5, 10, 16
vocab_size, num_units = 50, 64

encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

embeddings = tf.Variable(tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
                         dtype=tf.float32, name='embeddings')
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)
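
To illustrate the difficulty: indexing the embedded inputs with a tensor-valued index works, but the result loses its static shape, since both placeholder dimensions are None:

# Illustration only: indexing with a tensor (like raw_rnn's time argument)
# yields a tensor whose static shape is (?, input_embedding_size)
t = tf.constant(3)  # stand-in for the dynamic time step
print(encoder_inputs_embedded[:, t, :].get_shape())  # prints (?, 16)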

The main part:

with tf.variable_scope('ReaderNetwork'):
    def loop_fn_initial():
        init_elements_finished = (0 >= encoder_inputs_length)
        # the first cell input and the initial cell state are separate slots:
        # zero_state belongs in the state slot, not the input slot
        init_input = tf.zeros([batch_size, input_embedding_size], dtype=tf.float32)
        init_cell_state = cell.zero_state(batch_size, tf.float32)
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)

    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        def get_next_input():
            # **TODO** read tensor of shape BATCH x EMBEDDING_DIM from encoder_inputs_embedded,
            # which has shape BATCH x TIME_STEPS x EMBEDDING_DIM
            pass

        elements_finished = (time >= encoder_inputs_length)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        input_val = tf.cond(finished,
                            true_fn=lambda: tf.zeros([batch_size, input_embedding_size]),
                            false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, input_val, state, output, loop_state

    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        if previous_state is None:  # time == 0
            assert previous_output is None
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

The running part:

reader_loop = loop_fn
encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
outputs = encoder_outputs_ta.stack()

def next_batch():
    return {
        # encoder_inputs is an int32 placeholder of token ids, so feed integers
        encoder_inputs: np.random.randint(0, vocab_size, (batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size
    }

init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    outs = s.run([outputs], feed_dict=next_batch())
    print(len(outs), outs[0].shape)

Question: How do I access the embeddings at a given time step and return a tensor of shape batch x embedding dim? See the function get_next_input within loop_fn_transition.

Thank you.

Answer 1:

I was able to fix the problem. Since the embedded inputs have shape Batch x Time steps x Embedding dimensionality, I slice along the time dimension; because time is a tensor, the resulting slice only has static shape (?, embedding dimensionality). It is therefore required to set the shape of the resulting tensor explicitly, so that every iteration of raw_rnn's internal while loop produces the same shape; otherwise it fails with:

ValueError: The shape for rnn/while/Merge_2:0 is not an invariant for the loop

Here is the relevant part:

def get_next_input():
    embedded_value = encoder_inputs_embedded[:, time, :]
    # restore the static shape lost by the dynamic slice,
    # so the while-loop shape invariant holds across iterations
    embedded_value.set_shape([batch_size, input_embedding_size])
    return embedded_value

Can anyone confirm if this is the right way to solve the problem?
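
For comparison, the example in the tf.nn.raw_rnn documentation avoids slicing entirely by keeping the inputs time-major in a TensorArray and reading one step per iteration. A rough sketch of that variant for this setup (untested here; time refers to the loop_fn argument):

# Sketch of the TensorArray alternative: transpose to time-major,
# then read a single step inside the loop instead of slicing
inputs_time_major = tf.transpose(encoder_inputs_embedded, [1, 0, 2])  # T x B x E
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs_time_major)

def get_next_input():
    value = inputs_ta.read(time)  # shape: batch x embedding dim
    value.set_shape([batch_size, input_embedding_size])
    return value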

Here is the complete code for reference:

import tensorflow as tf
import numpy as np

batch_size, max_time, input_embedding_size = 5, 10, 16
vocab_size, num_units = 50, 64

encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

embeddings = tf.Variable(tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
                         dtype=tf.float32, name='embeddings')
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)
W = tf.Variable(tf.random_uniform([num_units, vocab_size], -1, 1), dtype=tf.float32, name='W_reader')
b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32, name='b_reader')
go_time_slice = tf.ones([batch_size], dtype=tf.int32, name='GO')  # a batch of GO token ids (= 1)
go_step_embedded = tf.nn.embedding_lookup(embeddings, go_time_slice)


with tf.variable_scope('ReaderNetwork'):
    def loop_fn_initial():
        init_elements_finished = (0 >= encoder_inputs_length)
        init_input = go_step_embedded
        init_cell_state = cell.zero_state(batch_size, tf.float32)
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)

    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        def get_next_input():
            embedded_value = encoder_inputs_embedded[:, time, :]
            # set_shape restores the static shape so the while-loop invariant holds
            embedded_value.set_shape([batch_size, input_embedding_size])
            return embedded_value

        elements_finished = (time >= encoder_inputs_length)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        next_input = tf.cond(finished,
                             true_fn=lambda: tf.zeros([batch_size, input_embedding_size], dtype=tf.float32),
                             false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, next_input, state, output, loop_state


    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        if previous_state is None:  # time == 0
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

reader_loop = loop_fn
encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
outputs = encoder_outputs_ta.stack()


def next_batch():
    return {
        encoder_inputs: np.random.randint(0, vocab_size, (batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size
    }


init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    outs = s.run([outputs], feed_dict=next_batch())
    print(len(outs), outs[0].shape)
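
A note on the output shape: the emit TensorArray of raw_rnn is time-major, so outputs = encoder_outputs_ta.stack() has shape (max_time, batch_size, num_units), here (10, 5, 64). If a batch-major result is needed, transpose it:

outputs_batch_major = tf.transpose(outputs, [1, 0, 2])  # batch x time x units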