I'm using keras-self-attention to implement an attention LSTM in Keras. How can I visualize the attention part after training the model? This is a time-series forecasting case.
```python
from keras.models import Sequential
from keras_self_attention import SeqSelfAttention
from keras.layers import LSTM, Dense, Flatten

model = Sequential()
model.add(LSTM(units=200, activation='tanh', return_sequences=True,
               input_shape=(TrainD[0].shape[1], TrainD[0].shape[2])))
model.add(SeqSelfAttention())
model.add(Flatten())
model.add(Dense(1, activation='relu'))
model.compile(optimizer='adam', loss='mse')
```
One approach is to fetch the outputs of `SeqSelfAttention` for a given input, and organize them so as to display predictions per-channel (see below). For something more advanced, have a look at the iNNvestigate library (usage examples included).

**Explanation**: `show_features_1D` fetches the outputs of the layer matching `layer_name` (can be a substring) and shows predictions per-channel (labeled), with timesteps along the x-axis and output values along the y-axis. Its parameters (a fetching sketch follows the list):

- `input_data` = single batch of data of shape `(1, input_shape)`
- `prefetched_outputs` = already-acquired layer outputs; overrides `input_data`
- `max_timesteps` = max # of timesteps to show
- `max_col_subplots` = max # of subplots along the horizontal
- `equate_axes` = force all x- and y-axes to be equal (recommended for fair comparison)
- `show_y_zero` = whether to show y=0 as a red line
- `channel_axis` = layer features dimension (e.g. `units` for LSTM, which is last)
- `scale_width, scale_height` = scale displayed image width & height
- `dpi` = image quality (dots per inch)
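The full `show_features_1D` helper isn't reproduced here, but its fetching step is the standard Keras sub-model pattern. A minimal sketch, assuming the model above and a hypothetical single-batch input `x` (building the attention layer with `return_attention=True` is what yields the two outputs, `outs_1` and `outs_2`, discussed below):

```python
from keras.models import Model

# Expose the attention layer's output(s) for a given input; the result
# can be handed to show_features_1D via `prefetched_outputs`.
attn_layer = model.get_layer(index=1)  # the SeqSelfAttention layer
attn_model = Model(model.input, attn_layer.output)
outs = attn_model.predict(x)  # x: single batch of shape (1, timesteps, features)
```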
**Visuals (below) explanation**:

- `print(outs_1)` reveals that all magnitudes are very small and don't vary much, so including the y=0 point and equating axes yields a line-like visual, which can be interpreted as self-attention being bias-oriented.
- Using `batch_shape` instead of `input_shape` removes all `?` in printed shapes, and we can see that the first output's shape is `(10, 60, 240)`, the second's `(10, 240, 240)`. In other words, the first output returns LSTM channel attention, and the second a "timesteps attention". The heatmap result below can be interpreted as showing attention "cooling down" w.r.t. timesteps; a plotting sketch follows this list.
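A minimal sketch of how such a timesteps-attention heatmap could be rendered with plain matplotlib, assuming `outs_2` is the `(10, 240, 240)` attention output fetched above (sample index `0` is an arbitrary choice):

```python
import matplotlib.pyplot as plt

# Heatmap of one sample's timesteps-attention matrix:
# rows/columns are timesteps, color encodes attention weight
plt.imshow(outs_2[0], cmap='bwr', aspect='auto')
plt.xlabel('timestep (attended to)')
plt.ylabel('timestep (attending)')
plt.colorbar()
plt.show()
```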
`SeqWeightedAttention` is a lot easier to visualize, but there isn't much to visualize; you'll need to get rid of `Flatten` above to make it work. The attention's output shapes then become `(10, 60)` and `(10, 240)`, for which you can use a simple histogram, `plt.hist` (just make sure you exclude the batch dimension, i.e. feed `(60,)` or `(240,)`).

`SeqWeightedAttention` example per request:
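The original example isn't included here, so below is a minimal self-contained sketch. It assumes `SeqWeightedAttention` accepts `return_attention=True` (as in recent keras-self-attention versions) and uses hypothetical toy shapes (10 samples, 60 timesteps, 8 features); note there's no `Flatten`, since the layer's output is already 2D:

```python
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, LSTM, Dense
from keras.models import Model
from keras_self_attention import SeqWeightedAttention

# Hypothetical toy data
x = np.random.randn(10, 60, 8)
y = np.random.randn(10, 1)

ipt = Input(shape=(60, 8))
seq = LSTM(240, activation='tanh', return_sequences=True)(ipt)
# return_attention=True also returns the per-timestep weights
out, attn = SeqWeightedAttention(return_attention=True)(seq)
pred = Dense(1)(out)

model = Model(ipt, pred)
model.compile(optimizer='adam', loss='mse')
model.fit(x, y, epochs=2, verbose=0)

# Sub-model exposing the attention layer's two outputs
attn_model = Model(ipt, [out, attn])
outs, weights = attn_model.predict(x)  # shapes (10, 240) and (10, 60)

plt.hist(weights[0], bins=30)  # strip the batch dim: feed (60,)
plt.show()
```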