12.4 基于点乘的多头注意力机制 417
12.4 基于点乘的多头注意力机制
Transformer 模型摒弃了循环单元和卷积等结构,完全基于注意力机制来构造模
型,其中包含着大量的注意力计算。比如,可以通过自注意力机制对源语言和目
语言行信取,并通-码注双语间的行建模。
12.9中红色方框部分是 Transformer 中使用注意力机制的模块。而这些模块都是
基于点乘的多头注意力机制实现的。
Self-Attention
Add & LayerNorm
Feed Forward Network
Add & LayerNorm
Embedding
+
Position
编码器输入:
编码器
Self-Attention
Add & LayerNorm
Encoder-Decoder Attention
Add & LayerNorm
Feed Forward Network
Add & LayerNorm
Output layer
Embedding
+
Position
解码器输入: <sos> I am fine
解码器
解码器输出: I am fine <eos>
12.9 自注意力机制在模型中的位置
12.4.1 点乘注意力机制
12.1节中绍,自注中至是获系数,
重。Transformer
乘的方法来计算相关性系数。这种方法也称缩放的点乘注意Scaled Dot-product
Attention)机制。它的运算并行度高,同时并不消耗太多的存储空间。
具体来看,在注意力机制的计算过程中,包含三个重要的参数,分别是 query
key value。在下面的描述中,分别用 QKV 它们进行表示,其 Q K
维度为 L ×d
k
V 的维度为 L ×d
v
。这里,L 为序列的长度,d
k
d
v
分别表示每个
key value 的大小,通常设置为 d
k
= d
v
= d
model
在自注意力机制中,QKV 都是相同的,对应着源语言或目标语言序列的表
示。而在编码-解码注意力机制中,由于要对双语之间的信息进行建模,因此,将目
标语言每个位置的表示视为编码-解码注意力机制的 Q源语言句子的表示视为 K
V
418 Chapter 12. 基于自注意力的模型 肖桐 朱靖波
在得到 QK V 后,便可以进行注意力的运算,这个过程可以被形式化为:
Attention(Q,K,V) = Softmax(
QK
T
d
k
+ Mask)V (12.9)
首先,通过对 Q K 的转置进行矩阵乘法操作,计算得到一个维度大小为 L ×L
相关性矩阵, QK
T
它表示一个序列上任意两个位置的相关性。再通过系数 1/
d
k
进行放缩操作,放缩可以减少相关性矩阵的方差,具体体现在运算过程中实数矩
中的数值不会过大,有利于模型训练。
在此基础上,通过对相关性矩阵累加一个掩码矩阵 Mask来屏蔽掉矩阵中的无
用信息。比如,在编码器端,如果需要对多个句子同时处理,由于这些句子长度不统
一,需要对句子补齐。再比如,在解码器端,训练的时候需要屏蔽掉当前目标语言位
置右侧的单词,因此这些单词在推断的时候是看不到的。
随后,使用 Softmax 函数对相关性矩阵在行的维度上进行归一化操作,这可以理
解为对第 i 行进行归一化,结果对应了 V 中不同位置上向量的注意力权重。对于 value
的加权求和,可以直接用相关性系数和 V 进行矩阵乘法得到, Softmax(
QK
T
d
k
+Mask)
V 行矩乘。最得到自注力的出,它输入 V 小是一模样的。
12.10展示了点乘注意力的计算过程。
MatMul
Q
K
Scale
Mask(opt.)
SoftMax
MatMul
V
自注意力机制的 Query
Key Value 均来自同一句
子,编码-解码注意力机制
与前面讲的一样
Query Key 的转置进
行点积, 得到句子内部
各个位置的相关性
相关性矩阵在训练中
方差变大,不利于训练
所以对其进行缩放
在编码器端,对句子补齐
填充的部分进行屏蔽
解码时看不到未来的信息
需要对未来的信息进行屏蔽
用归一化的相关性打分
Value 进行加权求和
12.10 点乘注意力的计算过程
下面举个简单的例子介绍点乘注意力的具体计算过程。如图12.11所示,用黄色、
蓝色和橙色的矩阵分别表示 QK VQK V 中的每一个小格都对应一个单词
在模型中的表示(即一个向量)首先,通过点乘、放缩、掩码等操作得到相关性矩
阵,即粉色部分。其次,将得到的中间结果矩阵(粉色)的每一行使用 Softmax 激活
函数进行归一化操作,得到最终的权重矩阵,也就是图中的红色矩阵。红色矩阵
的每一行都对应一个注意力分布。最后,按行 V 进行加权求和,便得到了每个单
词通过点乘注意力计算得到的表示。这里面,主要的计算消耗是两次矩阵乘法,
Q K
T
的乘法、相关性矩阵和 V 的乘法。这两个操作都可以在 GPU 上高效地完成,
因此可以一次性计算出序列中所有单词之间的注意力权重,并完成所有位置表示
12.4 基于点乘的多头注意力机制 419
加权求和过程,这样大大提高了模型计算的并行度。
Attention(
Q
,
K
,
V
)
=
Softmax(
Q
×
K
T
d
k
Mask
+
)
V
=
Softmax( )
按行进行 Softmax
V
=
×
V
=
12.11 公式(12.9)的执行过程示例
12.4.2 多头注意力机制
Transformer 使Multi-head Attention
“多头”可以理解成将原来的 QKV 照隐藏层维度平均切分成多份。假设
h 份,那么最终会得到 Q = {Q
1
,...,Q
h
}K = {K
1
,...,K
h
}V = {V
1
,...,V
h
}多头
注意力就是用每一个切分得到的 QKV 独立的进行注意力计算,即第 i 个头的注
意力计算结果 head
i
= Attention(Q
i
,K
i
,V
i
)
下面根据图12.12详细介绍多头注意力的计算过程:
首先, QKV 分别线性Linear换的映射 h 个子集。
Q
i
= QW
Q
i
K
i
= KW
K
i
V
i
= VW
V
i
,其 i 表示 i 个头,W
Q
i
R
d
model
×d
k
,
W
K
i
R
d
model
×d
k
, W
V
i
R
d
model
×d
v
是参数矩阵; d
k
= d
v
= d
model
/h,对于不同
头采用不同的变换矩阵,这里 d
model
表示每个隐藏层向量的维度;
其次,对每个头分别执行点乘注意力操作,并得到每个头的注意力操作的输出
head
i
最后, h 个头的注意力输出在最后一维 d
v
进行拼接Concat重新得到维度
hd
v
的输出,并通过对其右乘一个权重矩阵 W
o
进行线性变换,从而对多头
计算得到的信息进行融合,且将多头注意力输出的维度映射为模型的隐藏层大
小(即 d
model
,这里参数矩阵 W
o
R
hd
v
×d
model
420 Chapter 12. 基于自注意力的模型 肖桐 朱靖波
Linear
Linear
Linear
Q
Linear
Linear
Linear
K
Linear
Linear
Linear
V
Scaled Dot-Product Attention
Concat
Linear
12.12 多头注意力模型
多头注意力机制可以被形式化描述为公式(12.10)(12.11)
MultiHead(Q,K,V) = Concat(head
1
,...,head
h
)W
o
(12.10)
head
i
= Attention(QW
Q
i
,KW
K
i
,VW
V
i
) (12.11)
多头注意力机制的好处是允许模型在不同的表示子空间里学习。在很多实验
发现,不同表示空间的头捕获的信息是不同的,比如,在使用 Transformer 处理自然
语言时,有的头可以捕捉句法信息,有的头可以捕捉词法信息。
12.4.3 掩码操作
在公式(12.9)中提到了掩码Mask它的目的是对向量中某些值进行掩盖,避免
无关位置的数值对运算造成影响。Transformer 中的掩码主要应用在注意力机制中
相关性系数计算,具体方式是在相关性系数矩阵上累加一个掩码矩阵。该矩阵在
要掩码的位置的值为负无穷 inf(具体实现时是一个非常小的数,比如 1e9,其
余位置为 0这样在进行了 Softmax 归一化操作之后,被掩码掉的位置计算得到的权
重便近似为 0也就是说对无用信息分配的权重为 0从而避免了其对结果产生影响。
Transformer 包含两种掩码:
句长补全掩码Padding Mask。在批量处理多个样本时(训练或解码),由于
要对源语言和目标语言的输入进行批次化处理,而每个批次内序列的长度不一
样,为了方便对批次内序列进行矩阵表示,需要进行对齐操作,即在较短的序
列后面填充 0 来占位(padding 操作)。而这些填充 0 的位置没有实际意义,
参与注意力机制的计算,因此,需要进行掩码操作,屏蔽其影响。
未来信息掩码Future Mask对于解码器来说,由于在预测的时候是自左向右
进行的,即第 t 时刻解码器的输出只能依赖于 t 时刻之前的输出。且为了保证训
练解码一致,避免在训练过程中观测到目标语言端每个位置未来的信息,因此
需要对未来信息进行屏蔽。具体的做法是:构造一个上三角值全为-inf Mask
12.4 基于点乘的多头注意力机制 421
矩阵,也就是说,在解码器计算中,在当前位置,通过未来信息掩码把序列之后
的信息屏蔽掉了,避免了 t 时刻之后的位置对当前的计算产生影响。12.13
出了一个具体的实例。
Have
you
learned
nothing
?
eos
Have
you
learned
nothing
?
eos
Masked
Have
you
learned
nothing
?
eos
Have
you
learned
nothing
?
eos
12.13 Transformer 模型对未来位置进行屏蔽的掩码实例