What does the “source hidden state” refer to in the attention mechanism?

Posted 2019-07-03 18:22

The attention weights are computed as:

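$$\alpha_{ts} = \frac{\exp\big(\mathrm{score}(h_t, \bar{h}_s)\big)}{\sum_{s'=1}^{S} \exp\big(\mathrm{score}(h_t, \bar{h}_{s'})\big)}$$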
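Whatever the exact notation, the weights are a softmax of score(h_t, h̄_s) taken over the source positions s. A minimal plain-Python sketch (no TensorFlow), where the three score values are made-up numbers for one decoder step:

```python
import math

def attention_weights(scores):
    """Softmax over source positions: alpha_s = exp(score_s) / sum_s' exp(score_s')."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores score(h_t, h_bar_s) for 3 source positions
alphas = attention_weights([2.0, 1.0, 0.1])
print([round(a, 3) for a in alphas])
```

The weights are non-negative and sum to 1, so they say how much each source position contributes at decoder step t.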

I want to know what h_s refers to.

In the TensorFlow code, the encoder RNN returns a tuple:

encoder_outputs, encoder_state = tf.nn.dynamic_rnn(...)

As I understand it, h_s should be the encoder_state, but github/nmt seems to give a different answer:

# encoder_outputs: [max_time, batch_size, num_units] (time-major)
# attention_states: [batch_size, max_time, num_units]
attention_states = tf.transpose(encoder_outputs, [1, 0, 2])
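The transpose only reorders axes: the comment says attention_states is [batch_size, max_time, num_units], so encoder_outputs must be time-major. A plain-Python equivalent of tf.transpose(x, [1, 0, 2]) on nested lists shows the effect; the toy values are made up, and the zeros at t=2 stand for a padded step of the second sequence.

```python
def time_major_to_batch_major(x):
    """Plain-list equivalent of tf.transpose(x, [1, 0, 2]):
    [max_time][batch_size][num_units] -> [batch_size][max_time][num_units]."""
    max_time, batch_size = len(x), len(x[0])
    return [[x[t][b] for t in range(max_time)] for b in range(batch_size)]

# Toy encoder_outputs: max_time=3, batch_size=2, num_units=2
encoder_outputs = [
    [[1, 1], [10, 10]],  # t=0: one output vector per batch entry
    [[2, 2], [20, 20]],  # t=1
    [[3, 3], [0, 0]],    # t=2: the second sequence is padded here, so its output is zero
]
attention_states = time_major_to_batch_major(encoder_outputs)
# attention_states[b] lists one encoder output per source position for sequence b
print(attention_states)
```

Each row of attention_states is the full set of per-position encoder outputs for one sentence, which is the form an attention memory needs.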

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, attention_states,
    memory_sequence_length=source_sequence_length)

Did I misunderstand the code? Or does h_s actually mean the encoder_outputs?

1 answer
成全新的幸福
Reply #2 · 2019-07-03 18:32

The formula is probably from this post, so I'll use a picture of the network from the same post:

[Figure: attention-based encoder-decoder network from the same post]

Here, the h̄_s are all the blue hidden states from the encoder (the last layer), and h_t is the current red hidden state from the decoder (also the last layer). In the picture t=0, and the dotted arrows show which blocks are wired to the attention weights. The score function is usually one of these:

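$$\mathrm{score}(h_t, \bar{h}_s) = \begin{cases} h_t^\top W \bar{h}_s & \text{(Luong's multiplicative style)} \\ v_a^\top \tanh\left(W_1 h_t + W_2 \bar{h}_s\right) & \text{(Bahdanau's additive style)} \end{cases}$$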
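The two common score functions, Luong's multiplicative form h_t^T W h̄_s and Bahdanau's additive form v^T tanh(W1 h_t + W2 h̄_s), can be sketched in plain Python as below. The 2-dimensional states and the identity weights are hypothetical, untrained values chosen only so the arithmetic is easy to follow; in TensorFlow these roughly correspond to tf.contrib.seq2seq.LuongAttention and tf.contrib.seq2seq.BahdanauAttention, with learned weights.

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def luong_score(h_t, h_bar_s, W):
    """Multiplicative (Luong) score: h_t^T W h_bar_s."""
    return dot(h_t, matvec(W, h_bar_s))

def bahdanau_score(h_t, h_bar_s, W1, W2, v):
    """Additive (Bahdanau) score: v^T tanh(W1 h_t + W2 h_bar_s)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W1, h_t), matvec(W2, h_bar_s))]
    return dot(v, hidden)

# Toy states and hypothetical weights
h_t = [1.0, 0.0]        # decoder state at step t
h_bar_s = [0.5, 2.0]    # one encoder output (source position s)
identity = [[1.0, 0.0], [0.0, 1.0]]
print(luong_score(h_t, h_bar_s, identity))  # with W = I this is the plain dot product
print(bahdanau_score(h_t, h_bar_s, identity, identity, [1.0, 1.0]))
```

Either score is computed once per source position s, and the resulting list is what goes into the softmax that produces α_ts.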


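TensorFlow's attention mechanism matches this picture. In theory, a cell's output is in most cases its hidden state (one exception is the LSTM cell, whose output is the short-term part h of its (c, h) state; even in that case, the output is the better fit for the attention mechanism). In practice, TensorFlow's encoder_state differs from encoder_outputs when the input is padded with zeros and sequence_length is passed to dynamic_rnn: past each sequence's length, the state is copied through from the previous step while the output is zero. Obviously, you don't want to attend to trailing padding, so it makes sense to take the outputs, which are zero at these padded steps, as the h̄_s. Note also that encoder_state holds only the final state of each sequence, whereas attention needs one h̄_s per source position, which is exactly what encoder_outputs provides.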
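To make the padding argument concrete, here is a toy plain-Python sketch (not TensorFlow). toy_rnn is a hypothetical scalar tanh RNN that imitates how dynamic_rnn treats steps past sequence_length, and masked_attention_weights is a softmax that ignores padded positions, which is roughly what the memory_sequence_length argument of LuongAttention is for. The weights w and u and the decoder state h_t are made up.

```python
import math

def toy_rnn(inputs, length, w=0.5, u=1.0):
    """Hypothetical scalar tanh RNN imitating dynamic_rnn(..., sequence_length=...):
    past `length`, the output is zero and the state is carried through unchanged."""
    state, outputs = 0.0, []
    for t, x in enumerate(inputs):
        if t < length:
            state = math.tanh(w * state + u * x)
            outputs.append(state)  # for a plain RNN cell the output equals the new state
        else:
            outputs.append(0.0)    # padded step: zero output, state kept as-is
    return outputs, state

def masked_attention_weights(scores, length):
    """Softmax over source positions s < length; padded positions get weight 0."""
    m = max(scores[:length])
    exps = [math.exp(x - m) if s < length else 0.0 for s, x in enumerate(scores)]
    total = sum(exps)
    return [e / total for e in exps]

# One source sequence of true length 2, zero-padded to max_time = 4
encoder_outputs, encoder_state = toy_rnn([1.0, 2.0, 0.0, 0.0], length=2)
print(encoder_outputs)  # one h_bar_s per source position, zeros at the padding
print(encoder_state)    # a single final state, equal to the last real output

h_t = 0.8  # hypothetical scalar decoder state; the dot score is then a product
alphas = masked_attention_weights([h_t * h for h in encoder_outputs], length=2)
print(alphas)
```

Attending over encoder_state alone would hand the decoder a single vector, leaving nothing to weight; the per-position encoder_outputs are what give the softmax something to choose between.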

So encoder_outputs are exactly the arrows that go upward from the blue blocks. Later in the code, attention_mechanism is connected to the decoder_cell through AttentionWrapper, so that its output goes through the context vector to the yellow block in the picture.

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=num_units)
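The remaining step in the picture, from attention weights to the context vector and on to the attentional vector, can be sketched with the textbook Luong formulas c_t = Σ_s α_ts h̄_s and a_t = tanh(W_c [c_t; h_t]). This is a hedged plain-Python illustration, not AttentionWrapper's code: the dot score, the weights W_c and all values are hypothetical, and TensorFlow's attention layer (sized by attention_layer_size) may differ in details such as the activation.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def luong_attention_step(h_t, memory, W_c):
    """One decoder step of Luong-style attention with a dot score:
    alpha = softmax_s(h_t . h_bar_s), c_t = sum_s alpha_s * h_bar_s,
    a_t = tanh(W_c [c_t; h_t])."""
    scores = [sum(a * b for a, b in zip(h_t, h_bar)) for h_bar in memory]
    alphas = softmax(scores)
    dim = len(memory[0])
    context = [sum(alpha * h_bar[i] for alpha, h_bar in zip(alphas, memory))
               for i in range(dim)]
    concat = context + h_t  # list concatenation: [c_t; h_t]
    attention_vector = [math.tanh(sum(w * x for w, x in zip(row, concat)))
                        for row in W_c]
    return alphas, context, attention_vector

memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # h_bar_s for 3 source positions
h_t = [2.0, 0.0]                               # hypothetical decoder hidden state
W_c = [[0.5, 0.0, 0.5, 0.0],                   # hypothetical weights, shape [2, 4]
       [0.0, 0.5, 0.0, 0.5]]
alphas, context, a_t = luong_attention_step(h_t, memory, W_c)
print(alphas, context, a_t)
```

Here memory plays the role of one row of attention_states, and a_t is the kind of vector the yellow block produces for the next layer.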