How to interpret weights in an LSTM layer in Keras

Posted 2019-02-09 06:21

Question:

I'm currently training a recurrent neural network for weather forecasting, using an LSTM layer. The network itself is pretty simple and looks roughly like this:

model = Sequential()  
model.add(LSTM(hidden_neurons, input_shape=(time_steps, feature_count), return_sequences=False))  
model.add(Dense(feature_count))  
model.add(Activation("linear"))  

The weights of the LSTM layer have the following shapes:

for weight in model.get_weights(): # weights from Dense layer omitted
    print(weight.shape)

> (feature_count, hidden_neurons)
> (hidden_neurons, hidden_neurons)
> (hidden_neurons,)
> (feature_count, hidden_neurons)
> (hidden_neurons, hidden_neurons)
> (hidden_neurons,)
> (feature_count, hidden_neurons)
> (hidden_neurons, hidden_neurons)
> (hidden_neurons,)
> (feature_count, hidden_neurons)
> (hidden_neurons, hidden_neurons)
> (hidden_neurons,)

In short, it looks like there are four "elements" in this LSTM layer. I'm now wondering how to interpret them:

  • Where is the time_steps parameter in this representation? How does it influence the weights?

  • I've read that an LSTM consists of several blocks, such as an input gate and a forget gate. If those are represented in these weight matrices, which matrix belongs to which gate?

  • Is there any way to see what the network has learned? For example, how much does it take from the last time step (t-1 if we want to forecast t), how much from t-2, and so on? It would be interesting to know whether we could read from the weights that the input at t-5, for example, is completely irrelevant.

Clarifications and hints would be greatly appreciated.

Answer 1:

If you are using Keras 2.2.0

When you print

print(model.layers[0].trainable_weights)

you should see three tensors: lstm_1/kernel, lstm_1/recurrent_kernel and lstm_1/bias:0. The last dimension of each tensor is

4 * number_of_units

where number_of_units is your number of neurons. Try:

units = int(int(model.layers[0].trainable_weights[0].shape[1])/4)
print("No units: ", units)

That is because each tensor contains the weights of the four LSTM gates, concatenated in this order:

i (input), f (forget), c (cell state) and o (output)

Therefore, to extract the per-gate weights you can simply slice the arrays:

W = model.layers[0].get_weights()[0]  # kernel: input-to-hidden weights
U = model.layers[0].get_weights()[1]  # recurrent_kernel: hidden-to-hidden weights
b = model.layers[0].get_weights()[2]  # bias

W_i = W[:, :units]
W_f = W[:, units: units * 2]
W_c = W[:, units * 2: units * 3]
W_o = W[:, units * 3:]

U_i = U[:, :units]
U_f = U[:, units: units * 2]
U_c = U[:, units * 2: units * 3]
U_o = U[:, units * 3:]

b_i = b[:units]
b_f = b[units: units * 2]
b_c = b[units * 2: units * 3]
b_o = b[units * 3:]

Source: keras code
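
As a sanity check, here is a minimal NumPy sketch (toy sizes picked only for illustration) that replays the forward pass with the sliced matrices and compares the result against the layer's own output. It assumes the i, f, c, o ordering above and the standalone Keras package from the question; recurrent_activation is set to 'sigmoid' explicitly because some Keras versions default to hard_sigmoid, which would make the comparison fail:

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM

# Toy sizes, only for illustration.
time_steps, feature_count, hidden_neurons = 4, 3, 5

model = Sequential()
model.add(LSTM(hidden_neurons, input_shape=(time_steps, feature_count),
               recurrent_activation='sigmoid'))

W, U, b = model.layers[0].get_weights()
u = hidden_neurons

# Slice the per-gate blocks in the i, f, c, o order.
W_i, W_f, W_c, W_o = W[:, :u], W[:, u:2*u], W[:, 2*u:3*u], W[:, 3*u:]
U_i, U_f, U_c, U_o = U[:, :u], U[:, u:2*u], U[:, 2*u:3*u], U[:, 3*u:]
b_i, b_f, b_c, b_o = b[:u], b[u:2*u], b[2*u:3*u], b[3*u:]

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x_seq = np.random.rand(1, time_steps, feature_count).astype('float32')
h = np.zeros((1, u))
C = np.zeros((1, u))
for t in range(time_steps):
    x = x_seq[:, t, :]
    i = sigmoid(x @ W_i + h @ U_i + b_i)      # input gate
    f = sigmoid(x @ W_f + h @ U_f + b_f)      # forget gate
    c_hat = np.tanh(x @ W_c + h @ U_c + b_c)  # candidate cell state
    C = f * C + i * c_hat                     # new cell state
    o = sigmoid(x @ W_o + h @ U_o + b_o)      # output gate
    h = o * np.tanh(C)                        # new hidden state

print(np.allclose(h, model.predict(x_seq), atol=1e-5))  # should print True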



Answer 2:

I probably won't be able to answer all of your questions, but what I can do is provide more information on the LSTM cell and the different components it is made of.

This post on GitHub proposes a way to see the parameters' names when printing them:

# Note: this uses the old Keras 1.x API (input_dim / input_length and
# per-gate weight names); N is the sequence length.
model = Sequential()
model.add(LSTM(4, input_dim=5, input_length=N, return_sequences=True))
for e in zip(model.layers[0].trainable_weights, model.layers[0].get_weights()):
    print('Param %s:\n%s' % (e[0], e[1]))

The output looks like:

Param lstm_3_W_i:
[[ 0.00069305, ...]]
Param lstm_3_U_i:
[[ 1.10000002, ...]]
Param lstm_3_b_i:
[ 0., ...]
Param lstm_3_W_c:
[[-1.38370085, ...]]
...

More information about those different weights follows. They are named W, U, V and b, with a different index for each gate.

  • W matrices transform the input into some internal value. They have the shape [input_dim, output_dim].
  • U matrices transform the previous hidden state into another internal value. They have the shape [output_dim, output_dim].
  • b vectors are the biases of each block. They have the shape [output_dim].
  • V is used only in the output gate; it chooses which values to output from the new internal state. It has the shape [output_dim, output_dim].

In short, you do indeed have four different "blocks" (or internal layers):

  • Forget gate: it decides, based on the previous hidden state (h_{t-1}) and the input (x), which values to forget from the previous internal state of the cell (C_{t-1}):

    f_t = sigmoid( W_f * x + U_f * h_{t-1} + b_f )

    f_t is a vector of values between 0 and 1 that will encode what to keep (=1) and what to forget (=0) from the previous cell state.

  • Input gate: it decides, based on the previous hidden state (h_{t-1}) and the input (x), which values from the input (x) to use:

    i_t = sigmoid( W_i * x + U_i * h_{t-1} + b_i )

    i_t is a vector of values between 0 and 1 that will encode which values to use to update the new cell state.

  • Candidate value: we build new candidate values to update the internal cell state, using the input (x) and the previous hidden state (h_{t-1}):

    Ct_t = tanh( W_c * x + U_c * h_{t-1} + b_c )

    Ct_t is a vector containing potential values to update the Cell state (C_{t-1}).

We use those three values to build a new internal cell state (C_t):

C_t = f_t * C_{t-1} + i_t * Ct_t

As you can see, the new internal cell state is composed of two things: the part we didn't forget from the last state, and what we want to learn from the input.

  • Output gate: we don't output the cell state directly, as it is more of an internal abstraction of what we want to output (h_t). Instead we build h_t, the output for this step, from all the information we have (a literal code transcription of these equations follows below):

    o_t = sigmoid( W_o * x + U_o * h_{t-1} + V_o * C_t + b_o )

    h_t = o_t * tanh( C_t )
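
A literal NumPy transcription of these equations for a single step, with randomly initialised toy matrices, might look as follows. Note that the V_o term is a peephole-style connection that the default Keras LSTM does not actually use, so it has no counterpart in the sliced weights from the first answer; everything below is only an illustration of the math:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, C_prev, p):
    """One LSTM step following the equations above; p holds the weights."""
    f = sigmoid(x @ p['W_f'] + h_prev @ p['U_f'] + p['b_f'])                 # forget gate
    i = sigmoid(x @ p['W_i'] + h_prev @ p['U_i'] + p['b_i'])                 # input gate
    Ct = np.tanh(x @ p['W_c'] + h_prev @ p['U_c'] + p['b_c'])                # candidate values
    C = f * C_prev + i * Ct                                                  # new cell state
    o = sigmoid(x @ p['W_o'] + h_prev @ p['U_o'] + C @ p['V_o'] + p['b_o'])  # output gate
    h = o * np.tanh(C)                                                       # new hidden state
    return h, C

# Toy dimensions, for illustration only.
input_dim, output_dim = 3, 4
rng = np.random.RandomState(0)
p = {}
for g in 'ifco':
    p['W_' + g] = rng.randn(input_dim, output_dim)
    p['U_' + g] = rng.randn(output_dim, output_dim)
    p['b_' + g] = np.zeros(output_dim)
p['V_o'] = rng.randn(output_dim, output_dim)

h, C = lstm_step(rng.randn(1, input_dim), np.zeros((1, output_dim)),
                 np.zeros((1, output_dim)), p)
print(h.shape, C.shape)  # (1, 4) (1, 4)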

I hope this clarifies how an LSTM cell works. I invite you to read tutorials on LSTMs, as they use nice diagrams, step-by-step examples and so on. It is a relatively complex layer.

Regarding your questions, I have no idea how to track what has been used from the input to modify the state. You could eventually look at the different W matrices, as they are the ones processing the input: W_c will give you information about what is potentially used to update the cell state, and W_o might give you some information about what is used to produce the output. But all of this is relative to the other weights, as the previous states also have an influence (a very rough sketch of the idea follows below).
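
For example, a very crude heuristic (and nothing more than that) would be to look at the average absolute weight each input feature receives in W_c and W_o of the trained model from the question; this ignores the recurrent weights and the gating interactions entirely:

import numpy as np

# Crude heuristic only: average absolute kernel weight per input feature.
# It ignores the recurrent weights, the gate activations and input scaling,
# so treat it as a rough hint rather than a real attribution method.
W = model.layers[0].get_weights()[0]      # kernel, shape (feature_count, 4 * units)
units = W.shape[1] // 4
W_c = W[:, 2 * units:3 * units]           # candidate-value block
W_o = W[:, 3 * units:]                    # output-gate block

for name, block in [('W_c', W_c), ('W_o', W_o)]:
    salience = np.abs(block).mean(axis=1)
    print(name, np.round(salience, 4))    # one value per input feature

Even then, this only says something about the input features, not about how far back in time the network looks: the time dimension is folded into the recurrent weights U and cannot be read off directly.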

If you see some strong weights in W_c, however, it might not mean anything, because the input gate (i_t) may be completely closed, annihilating the update of the cell state... It is complex; the mathematics of tracing back what happens inside a neural net is really involved.

Neural nets really are black boxes in the general case. You can find cases in the literature where information is traced back from the output to the input, but from what I have read this works only in very special cases.

I hope this helps :-)