Building the Transformer Model, Step by Step


The Transformer model was introduced in the paper "Attention is All You Need". Since its publication it has swept the field of natural language processing and set new state-of-the-art results on mainstream tasks such as machine translation, language modeling, sequence labeling, and text classification. The NiuTensor framework contains an efficient implementation of the Transformer (NiuTensor/source/sample/transformer). Taking machine translation as the example task, this article decomposes the architecture top-down and builds the complete model step by step alongside the code.

Transformer Overview

The Transformer consists of an Encoder and a Decoder, as shown in Figure 1.1. Machine translation with an encoder-decoder model can be summarized as: the source sentence is fed into the model, which computes the target sentence as output.


Figure 1.1 An encoder-decoder based machine translation model

In the Transformer, both the encoder and the decoder are built by stacking several identical layers, as shown in Figure 1.2 (the "6x" means the same structure is stacked six times).


Figure 1.2 The stacked encoder and decoder layers in the Transformer

Each encoder layer consists of two sub-layers (leaving residual connections and layer normalization aside for the moment): a self-attention layer and a feed-forward neural network (FNN). The input to the bottom encoder layer is the source sentence; it first passes through the self-attention layer and is then fed into the FNN layer.

Each decoder layer consists of three sub-layers (again leaving residuals and layer normalization aside). The input to its bottom self-attention layer is the target sentence. Compared with the encoder, the decoder adds an encoder-decoder attention layer, whose input includes the output of the top encoder layer. The output of the top decoder layer then goes through a linear transformation and a softmax normalization to produce the probabilities of the target words.

The Transformer in Detail, with Code

The figures above sketch the overall structure of the Transformer. We now walk through the computation of the whole model, starting from the input.

The first step is to turn the input tokens into word vectors. In the Transformer, a word vector is the sum of two parts: the original word embedding and a positional encoding. The fixed positional encoding is given by the formula below, where $pos$ is the position of the current token in the input sentence and $i$ indexes the dimensions of the vector, whose size $d_{\mathrm{model}}$ is 512 in the original paper. (The fixed encoding can also be replaced with learnable position vectors that are updated as the model trains.)

Word embeddings and positional encoding

$$PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i/d_{\mathrm{model}}}\right),\qquad PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i/d_{\mathrm{model}}}\right)$$
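As a quick worked example (values rounded), with $d_{\mathrm{model}} = 512$: at position 0 every sine entry is $\sin(0)=0$ and every cosine entry is $\cos(0)=1$, while at position 1 the first pair of dimensions is

$$PE_{(1,0)} = \sin(1) \approx 0.841,\qquad PE_{(1,1)} = \cos(1) \approx 0.540$$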

The NiuTensor implementation of the Transformer positional encoding:

/* 
make positional embeddings (of size eSize * length)
>> eSize - embedding size
>> d - dimension size of the hidden layers
>> length - length of the sequence
*/
void T2TEmbedder::MakePosEmbedding(int eSize, int d, int length)
{
    InitTensor2D(&posEmbeddingBase, length, eSize, X_FLOAT, devID);

    float * data = new float[posEmbeddingBase.unitNum];

    for(int pos = 0; pos < length; pos++){
        float * dp = data + pos * eSize;
        
        int channelSize = eSize / 2;
        int offset = 0;
        for(int i = 0; i < channelSize; i++){
            dp[offset++] = (float)sin(pos/pow(10000.0F, 2.0F*i/(d - 2)));
        }
        for(int i = 0; i < channelSize; i++){
            dp[offset++] = (float)cos(pos/pow(10000.0F, 2.0F*i/(d - 2)));
        }
    }

    posEmbeddingBase.SetData(data, posEmbeddingBase.unitNum);

    delete[] data;
}
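Note that this implementation places all sine components in the first eSize/2 channels of each row and all cosine components in the remaining channels, rather than interleaving them as in the original paper. As an illustration, here is a small standalone C++ sketch (independent of NiuTensor, with toy sizes and the exponent $2i/d$ from the formula above) that fills a table in the same layout:

#include <cmath>
#include <cstdio>

/* standalone sketch: fill a tiny positional-encoding table in the same
   layout as MakePosEmbedding (sines in the first half of each row,
   cosines in the second half) */
int main()
{
    const int length = 3;   /* number of positions (toy value) */
    const int eSize  = 8;   /* embedding size (toy value)      */
    float table[length][eSize];

    for (int pos = 0; pos < length; pos++) {
        int half = eSize / 2;
        for (int i = 0; i < half; i++) {
            float angle = pos / std::pow(10000.0F, 2.0F * i / eSize);
            table[pos][i]        = std::sin(angle);   /* first half: sin  */
            table[pos][half + i] = std::cos(angle);   /* second half: cos */
        }
    }

    /* row pos = 0 prints as four zeros followed by four ones */
    for (int pos = 0; pos < length; pos++) {
        for (int j = 0; j < eSize; j++)
            std::printf("%6.3f ", table[pos][j]);
        std::printf("\n");
    }
    return 0;
}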

Encoder layer computation

The word vectors are then fed into the self-attention layer, followed by a residual connection and layer normalization; the result goes into the FNN layer, again followed by a residual connection and layer normalization, as shown in Figure 2.1.


Figure 2.1 Computation inside an encoder layer
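In equation form (dropout omitted), one encoder layer computes, in the same order as the code below:

$$h = \mathrm{LN}\big(x + \mathrm{SelfAtt}(x)\big),\qquad y = \mathrm{LN}\big(h + \mathrm{FNN}(h)\big)$$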

The NiuTensor implementation of the Transformer encoder:

/* 
make the encoding network
>> input - the input tensor of the encoder
>> mask - the mask that indicates which positions are valid
>> maskEncDec - no use
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the encoder
*/
XTensor AttEncoder::Make(XTensor &input, XTensor &mask, XTensor &maskEncDec, bool isTraining)
{
    XTensor x;

    x = embedder.Make(input);

    /* dropout */
    if(isTraining && dropoutP > 0)
        x = Dropout(x, dropoutP);

    for(int i = 0; i < nlayer; i++){
        XTensor att;
        XTensor ln;
        XTensor fnn;
        XTensor res;

        /* self attention */
        att = attentions[i].MakeBig(x, mask, isTraining);
        
        /* dropout */
        if(isTraining && dropoutP > 0)
            att = Dropout(att, dropoutP);

        /* residual connection */
        res = Sum(att, x);

        /* layer normalization */
        x = attLayerNorms[i].Make(res);

        /* fnn */
        fnn = fnns[i].Make(x, isTraining);

        /* dropout */
        if(isTraining && dropoutP > 0)
            fnn = Dropout(fnn, dropoutP);

        /* residual connection */
        res = Sum(fnn, x);

        /* layer normalization */
        x = fnnLayerNorms[i].Make(res);
    }
    
    x.SetName(ENCODING_NAME);
    input.SetName(ENCODING_INPUT_NAME);

    return x;
}

Attention layer computation

The following shows how attention is computed. In NiuTensor, the implementation has two parts:

  1. Linearly transform the input $k$, $q$, $v$ vectors.
  2. Compute the attention output from the transformed $k$, $q$, $v$.

As shown in Figure 2.2, in the self-attention layer $k$, $q$ and $v$ are the same tensor: at the bottom encoder layer they are the source-language word vectors, and at the bottom decoder layer they are the target-language word vectors. In the encoder-decoder attention layer, $k$ and $v$ come from the output of the top encoder layer, while $q$ at the bottom layer is the target-language word vectors.


Figure 2.2 Linear transformation of the inputs in the self-attention layer

The NiuTensor implementation of the linear transformation of the input $k$, $q$, $v$:

/*
make the network given a big tensor that keeps keys, queries and values
>> kqv - the big tensor
>> mask - as it is
>> isTraining - indicates whether the model is used for training
*/
XTensor T2TAttention::MakeBig(XTensor &kqv, XTensor &mask, bool isTraining)
{
    XTensor k2;
    XTensor q2;
    XTensor v2;
    XTensor kqv2;
    TensorList split;
    
    kqv2 = MMul(kqv, wbig);
    
    int d1 = kqv2.GetDim(0);
    int d2 = kqv2.GetDim(1);
    int d3 = kqv2.GetDim(2) / 3;
    
    InitTensor3D(&k2, d1, d2, d3, X_FLOAT, devID);
    InitTensor3D(&q2, d1, d2, d3, X_FLOAT, devID);
    InitTensor3D(&v2, d1, d2, d3, X_FLOAT, devID);
    
    split.Add(&q2);
    split.Add(&k2);
    split.Add(&v2);
    
    Split(kqv2, split, 2, 3);
    
    return MakeAttention(k2, q2, v2, mask, isTraining);
}
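Note the design choice here: instead of three separate projection matrices, a single fused weight wbig produces queries, keys and values in one large matrix multiplication, and the result is then split in three. A minimal sketch of the equivalent computation with separate weights (wq, wk and wv are hypothetical names, not the ones used in NiuTensor; shown only as a fragment in the style of the surrounding code):

/* illustrative sketch only: the same projection done with three separate weight matrices */
XTensor q2 = MMul(kqv, wq);    /* query projection */
XTensor k2 = MMul(kqv, wk);    /* key projection   */
XTensor v2 = MMul(kqv, wv);    /* value projection */

The fused version typically runs faster, since one big matrix multiplication uses the hardware better than three smaller ones.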

Multi-head attention is computed according to the following formulas:

$$\text{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V$$

$$\begin{aligned} \text{MultiHead}(Q, K, V) &= \text{Concat}\left(\text{head}_{1}, \ldots, \text{head}_{h}\right) W^{O} \\ \text{where head}_{i} &= \text{Attention}\left(Q W_{i}^{Q},\ K W_{i}^{K},\ V W_{i}^{V}\right) \end{aligned}$$

Here the input $k$, $q$, $v$ (the linearly transformed matrices) are first split evenly into several smaller matrices, one per head (8 heads in the original paper). Attention is computed for each head separately, the results are concatenated back into one large matrix, and a final linear transformation produces the input to the next layer, as shown in Figure 2.3.


Figure 2.3 Multi-head attention: splitting into heads, per-head attention, and concatenation
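With the sizes used in the original paper, $d_{\mathrm{model}} = 512$ and $h = 8$ heads, each head therefore works with vectors of size

$$d_k = d_v = d_{\mathrm{model}} / h = 512 / 8 = 64$$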

The NiuTensor implementation of the attention operation:

/*
make the attention network given keys, queries and values (after linear transformation)
>> k - keys. It might be of size B * L * H
    where B = batch size, L = sequence length,
    and H = vector size of each position
>> q - queries
>> v - values
>> mask - as it is
>> isTraining - indicates whether the model is used for training
*/
XTensor T2TAttention::MakeAttention(XTensor &k, XTensor &q, XTensor &v, XTensor &mask, bool isTraining)
{
    XTensor kheads;
    XTensor qheads;
    XTensor vheads;
    
    /* multi head */
    kheads = Split(k, k.order - 1, nhead);
    qheads = Split(q, q.order - 1, nhead);
    vheads = Split(v, v.order - 1, nhead);
    
    XTensor att;
    XTensor dot;
    XTensor scalar;
    
    /* scalar = softmax(Q * K^T / sqrt(dk)) * V */
    dot = BMMul(qheads, X_NOTRANS, kheads, X_TRANS);
    
    if(isMasked)
        dot = dot + mask;
    
    dot = Linear(dot, 1.0F/(float)sqrt((float)dk/nhead));
    
    scalar = Softmax(dot, -1);

    if(isTraining && dropoutP > 0)
        scalar = Dropout(scalar, dropoutP);
    
    att = BMMul(scalar, vheads);
    
    /* concatenate the heads */
    return MMul(Merge(att, att.order - 1), wa);
}
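Note that the scaling term 1.0F/sqrt(dk/nhead) in the code divides dk by the number of heads, which suggests that dk here is the total key dimension and dk/nhead the per-head size, matching $1/\sqrt{d_k}$ in the attention formula. For example, with dk = 512 and nhead = 8 (the values from the original paper) the factor would be

$$\frac{1}{\sqrt{512/8}} = \frac{1}{\sqrt{64}} = \frac{1}{8}$$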

Residual connection and layer normalization

After the attention output is obtained, it is added to the input of the layer through a residual connection, and layer normalization is then applied. The NiuTensor implementation of layer normalization:

/*
make the network
for each layer representation x, we have
y = (x - \mu) / \sigma * w + b
>> input - the input tensor
>> return - layer normalization output
*/
XTensor T2TLN::Make(XTensor &input)
{
    XTensor &x = input;
    XTensor xn;
    XTensor mean;
    XTensor variance;
    XTensor standard;
    XTensor meanFilled;
    XTensor standardFilled;

    /* \mu = (sum_i x_i)/m */
    mean = ReduceMean(x, x.order - 1);

    /* \sigma = (sum_i (x_i - \mu)^2)/m */
    variance = ReduceVariance(x, x.order - 1, mean);

    /* standard = sqrt(variance) */
    standard = Power(variance, 0.5F);

    /* unsqueeze mean and standard deviation to fit them into
    the same shape of x */
    meanFilled = Unsqueeze(mean, x.order - 1, x.GetDim(-1));
    standardFilled = Unsqueeze(standard, x.order - 1, x.GetDim(-1));

    /* x' = (x - \mu)/standard */
    xn = (x - meanFilled) / standardFilled;

    /* result = x' * w + b   */
    return xn * w + b;
}
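As a small numerical example of the normalization step (before the affine transform with w and b), take a single position with $x = (1, 2, 3)$:

$$\mu = 2,\qquad \sigma^2 = \frac{(1-2)^2 + (2-2)^2 + (3-2)^2}{3} = \frac{2}{3},\qquad x' = \frac{x - \mu}{\sqrt{\sigma^2}} \approx (-1.225,\ 0,\ 1.225)$$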

FNN layer computation

Next comes the feed-forward neural network layer (FNN). Its NiuTensor implementation:

/* 
make the network 
y = max(0, x * w1 + b1) * w2 + b2
>> input - the input tensor
>> return - the output tensor 
*/
XTensor T2TFNN::Make(XTensor &input, bool isTraining)
{
    XTensor t1;

    /* t1 = max(0, x * w1 + b1) */
    //t1 = Rectify(MMul(input, w1) + b1);
    t1 = Rectify(MulAndShift(input, w1, b1));
    
    if(isTraining && dropoutP > 0)
        t1 = Dropout(t1, dropoutP);

    /* result = t1 * w2 + b2 */
    //return MMul(t1, w2) + b2;
    return MulAndShift(t1, w2, b2);
}
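For reference, in the original paper the hidden layer of the FNN is wider than the model dimension:

$$w_1 \in \mathbb{R}^{d_{\mathrm{model}} \times d_{ff}},\qquad w_2 \in \mathbb{R}^{d_{ff} \times d_{\mathrm{model}}},\qquad d_{\mathrm{model}} = 512,\quad d_{ff} = 2048$$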

Decoder layer computation

We have now completed the computation on the encoder side. The decoder differs in that it additionally computes encoder-decoder attention, as shown in Figure 2.4.


Figure 2.4 Computation flow of the decoder

The encoder-decoder attention sits between self-attention and the FNN. Like self-attention, it takes the three inputs $k$, $q$, $v$: $k$ and $v$ are both the output of the top encoder layer, while $q$ comes from the decoder side. During decoding, the computation keeps running until the end-of-sentence symbol is predicted. Finally, the decoder output is passed through a linear transformation and a softmax to obtain a probability for every word in the target vocabulary.
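In equation form, writing $z_t$ for the top decoder output at target position $t$ and $W_o$, $b_o$ for the output projection parameters (names used here only for illustration), the distribution over the target vocabulary is

$$P(y_t \mid y_{<t}, x) = \mathrm{softmax}(z_t W_o + b_o)$$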

The corresponding decoder implementation in NiuTensor:

/* 
make the decoding network
>> inputDec - the input tensor of the decoder
>> outputEnc - the output tensor of the encoder
>> mask - mask that indicates which positions are valid
>> maskEncDec - mask for the encoder-decoder attention
>> isTraining - indicates whether the model is used for training
<< return - the output tensor of the decoder
*/
XTensor AttDecoder::Make(XTensor &inputDec, XTensor &outputEnc, XTensor &mask, XTensor &maskEncDec, bool isTraining)
{
    XTensor x;

    x = embedder.Make(inputDec);

    /* dropout */
    if(isTraining && dropoutP > 0)
        x = Dropout(x, dropoutP);

    for(int i = 0; i < nlayer; i++){
        XTensor att;
        XTensor ende;
        XTensor ln;
        XTensor fnn;
        XTensor res;

        /******************/
        /* self attention */
        att = attentions[i].MakeBig(x, mask, isTraining);

        /* dropout */
        if(isTraining && dropoutP > 0)
            att = Dropout(att, dropoutP);

        /* residual connection */
        res = Sum(att, x);

        /* layer normalization */
        x = attLayerNorms[i].Make(res);

        /*****************************/
        /* encoder-decoder attention */
        ende = attentionsEnde[i].Make(outputEnc, x, outputEnc, maskEncDec, isTraining);

        /* dropout */
        if(isTraining && dropoutP > 0)
            ende = Dropout(ende, dropoutP);

        /* residual connection */
        res = Sum(ende, x);

        /* layer normalization */
        x = attEndeLayerNorms[i].Make(res);

        /*******/
        /* fnn */
        fnn = fnns[i].Make(x, isTraining);

        /* dropout */
        if(isTraining && dropoutP > 0)
            fnn = Dropout(fnn, dropoutP);

        /* residual connection */
        res = Sum(fnn, x);

        /* layer normalization */
        x = fnnLayerNorms[i].Make(res);
    }
    
    x.SetName(DECODING_NAME);

    return x;
}

Finally, here is the complete NiuTensor implementation of the Transformer machine translation model:

/* 
make the network for machine translation (with the output softmax layer) 
>> inputEnc - input tensor of the encoder
>> inputDec - input tensor of the decoder
>> output - output tensor (distribution)
>> paddingEnc - padding of the sequences (on the encoder side)
>> paddingDec - padding of the sequences (on the decoder side)
>> isTraining - indicates whether the model is for training
*/
void T2TModel::MakeMT(XTensor &inputEnc, XTensor &inputDec, XTensor &output, XTensor &paddingEnc, XTensor &paddingDec, bool isTraining)
{
    XTensor encoding;
    XTensor decoding;
    XTensor maskEnc;
    XTensor maskDec;
    XTensor maskEncDec;

    /* encoder mask */
    MakeMTMaskEnc(inputEnc, paddingEnc, maskEnc);
    
    /* decoder mask */
    MakeMTMaskDec(inputEnc, inputDec, paddingEnc, paddingDec, maskDec, maskEncDec);

    encoding = MakeEncoder(inputEnc, maskEnc, isTraining);

    decoding = MakeDecoder(inputDec, encoding, maskDec, maskEncDec, isTraining);

    outputLayer->Make(decoding, output);
}
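One detail worth noting: MakeMTMaskDec prepares the decoder-side masks. In the standard Transformer, the decoder self-attention mask is causal, so that position $i$ can only attend to target positions $j \le i$; conceptually it has the form below, with the $-\infty$ entries realized in practice as a large negative constant added to the attention scores before the softmax.

$$M_{ij} = \begin{cases} 0, & j \le i \\ -\infty, & j > i \end{cases}$$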