上一章我们聊了聊quick-thought通过干掉decoder加快训练, CNN—LSTM用CNN作为Encoder并行计算来提速等方法,这一章看看抛开CNN和RNN,transformer是如何只基于attention对不定长的序列信息进行提取的。虽然Attention is All you need论文本身是针对NMT翻译任务的,但transformer作为后续USE/Bert的重要组件,放在embedding里也没啥问题。以下基于WMT英翻中的任务实现了transfromer,完整的模型代码详见DSXiangLi-Embedding-transformer
模型组件
让我们先过一遍Transformer的基础组件,以文本场景为例,encoder和decoder的输入是文本序列,每个batch都pad到相同长度然后对每个词做embedding得到batch * padding_len * emb_size的输入向量
假设batch=1,Word Embedding维度为512,Encoder的输入是'Fox hunt rabbit at night', 经过Embedding之后得到1 * 5 * 512的向量,以下的模型组件都服务于如何从这条文本里提取出更多的信息
Attention
序列信息提取的一个要点在于如何让每个词都考虑到它所在的上下文语境
- RNN:上下文信息靠向后/前传递,从前往后传rabbit就能考虑到fox,从后往前传rabbit就能考虑到night
- CNN:靠不同kernel_size定义的局部窗口来获取context信息, kernel_size>=3,rabbit就能考虑到所有其他token的信息
- Attention:通过计算词和上下文之间的相关性(广义),来决定如何把周围信息(value)融合(weighted-average)进当前信息(query),下图来源Reference5
Transformer在attention的基础上有两点改良, 分别是Scaled-dot product attention和multi-head attention。
Scaled-dot product attention
Attention的输入是三要素query,key和value,通过计算query和Key的相关性,这里是广义的相关,可以通过加法/乘法得到权重向量,用权重对value做加权平均作为输出。‘fox hunt rabbit at night’会计算每个词对所有词的相关性,得到[5, 5]的相似度矩阵/权重向量,来对输入[5, 512]进行加权,得到每个词在考虑上下文语义后新的向量表达[5, 512]
Transformer在常规的乘法attention的基础上加入\(d_k\)维度的正则化。这里\(d_k\)是query和key的特征维度,在我们的文本场景下是embedding_size[512] 。正则化的原因是避免高维embedding的内积出现超级大的值,导致softmax的gradient非常小。
直观解释,假设query和key的每个元素都独立服从\(\mu=0 \, \sigma^2=1\)的分布, 那内积\(\sum_{d_k}q_ik_i\)就服从\(\mu=0 \, \sigma^2=d_k\)的分布,因此需要用\(\sqrt{d_k}\)做正则化,保证内积依旧服从\(\mu=0 \, \sigma=1\)的分布。
def scaled_dot_product_attention(key, value, query, mask):
with tf.variable_scope('scaled_dot_product_attention', reuse=tf.AUTO_REUSE):
# scalaed weight matrix : batch_size * query_len * key_len
dk = tf.cast(key.shape.as_list()[-1], tf.float32)# emb_size
weight = tf.matmul(query, key, transpose_b=True)/(dk**0.5)
# apply mask: large negative will become 0 in softmax[mask=0 ignore]
weight += (1-mask) * (-2**32+1)
# normalize on axis key_len so that score add up to 1
weight = tf.nn.softmax(weight, axis=-1)
tf.summary.image("attention", tf.expand_dims(weight[:1], -1)) # add channel dim
add_layer_summary('attention', weight)
# weighted value: batch_size * query_len * emb_size
weighted_value = tf.matmul(weight, value )
return weighted_value
Mask
上面代码中的mask是做什么的呢?mask决定了Attention对哪些特征计算权重,transformer的mask有两种【以下mask=1是保留的部分,0是drop的部分】
其一是padding mask, 让attention的权重只针对真实文本计算其余为0。padding mask的dimension是[batch, 1, key_len], 1是预留给query,会在attention中被broadcast成[batch, query_len, key_len]
def seq_mask_gen(input_, params):
mask = tf.sequence_mask(lengths=tf.to_int32(input_['seq_len']), maxlen=tf.shape(input_['tokens'])[1],
dtype=params['dtype'])
mask = tf.expand_dims(mask, axis=1)
return mask
如果输入文本长度分别为3,4,5,都padding到5,padding mask维度是[3,1,5] 如下
其二是future mask只用于decoder,mask每个token后面的序列,保证在预测T+1的token时只会用到T及T以前的信息,如果不加future mask,预测T+1时就会用到T+1的文本本身,出现feature leakage。
def future_mask_gen(input_, params):
seq_mask = seq_mask_gen(input_, params) # batch_size * 1 * key_len
mask = tf.matmul(seq_mask, seq_mask, transpose_a=True) # batch_size * key_len * key_len
mask = tf.matrix_band_part(mask, num_lower=-1, num_upper=0)
return mask
还是上面的例子,future mask的维度是[3,5,5] 如下
multi-head attention
这些年对multi-head为啥有效的讨论有很多,下面Reference3~8都从不同方面给出了不同的Insight。最开始看multi-head的设计,第一反应是你莫不是在逗我?!你把