While reading the paper "Attention Is All You Need", I found that I did not fully understand some of its descriptions, so I walked through the whole flow in detail alongside the code, summarized a few questions, and recorded them here.
Why does attention divide by $\sqrt{d_k}$ after computing $QK^T$?

There are two common kinds of attention: additive attention (NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE, 2014) and dot-product attention, and dot-product attention is the more efficient of the two. When $d_k$ is small the two perform similarly; when $d_k$ is large, additive attention outperforms (unscaled) dot-product attention. The authors suspect that for large $d_k$ the dot products grow large in magnitude, so the softmax outputs fall toward the two extremes and, during back-propagation, most gradients land in a region of very small values, which hurts the gradient updates. They therefore divide by $\sqrt{d_k}$ to remove this effect.
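To make the scaling argument concrete, here is a small standalone sketch (not from the paper or from the code discussed below, purely an illustration): for random query/key vectors with unit-variance components, the dot product has variance roughly $d_k$, so for large $d_k$ the unscaled softmax collapses onto a few positions, while dividing by $\sqrt{d_k}$ keeps it smooth.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k, n_keys = 512, 9   # illustrative sizes; the Transformer uses d_k = 64 per head

q = torch.randn(d_k)             # one query
K = torch.randn(n_keys, d_k)     # a few keys

scores = K @ q                   # unscaled dot products, std roughly sqrt(d_k)
print(scores.std())              # on the order of sqrt(512) ~ 22.6

p_raw    = F.softmax(scores, dim=-1)                # close to one-hot
p_scaled = F.softmax(scores / d_k ** 0.5, dim=-1)   # much smoother

print(p_raw.max().item(), p_scaled.max().item())
```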
How does multi-head attention actually implement the multiple heads?

The multi-head operation is in fact carried out inside a single batched matrix multiplication. The full forward code:

```python
def forward(self, query, key, value, mask=None):
    "Implements Figure 2"
    if mask is not None:
        # Same mask applied to all h heads.
        mask = mask.unsqueeze(1)
    nbatches = query.size(0)

    # 1) Do all the linear projections in batch from d_model => h x d_k
    query, key, value = \
        [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
         for l, x in zip(self.linears, (query, key, value))]

    # 2) Apply attention on all the projected vectors in batch.
    x, self.attn = attention(query, key, value, mask=mask,
                             dropout=self.dropout)

    # 3) "Concat" using a view and apply a final linear.
    x = x.transpose(1, 2).contiguous() \
         .view(nbatches, -1, self.h * self.d_k)
    return self.linears[-1](x)
```

Here the inputs satisfy query.shape = key.shape = value.shape = [30, 9, 512] (in self-attention, query, key and value are actually the same input), self.h = 8 (the number of heads), and self.d_k = 64 (the dimension of each word vector within one head). The splitting of the word vectors into heads is done entirely inside the multi-head attention module.
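For context, the forward above assumes a constructor roughly like the following. This is a sketch following the layout of the Annotated Transformer implementation; the `clones` helper is reproduced here as an assumption, since it is not shown elsewhere in this post:

```python
import copy
import torch.nn as nn

def clones(module, N):
    "Produce N independently initialized copies of a module."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in number of heads and model size."
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h   # 512 // 8 = 64, per-head dimension
        self.h = h                # 8 heads
        # Four d_model x d_model linears: the Q, K, V projections plus the final output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None          # stores p_attn for later visualization
        self.dropout = nn.Dropout(p=dropout)
```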
The forward pass has three main steps.

Step 1: linear projection plus reshaping into the multi-head layout. After the linear projection the shape is still (30, 9, 512); the view changes it to (30, 9, 8, 64), cutting each 512-dim word vector into 8 pieces of size 64; transposing dimensions 1 and 2 then gives (30, 8, 9, 64). Moving the head dimension in front of the word dimension ensures that each head does its computation using every word but only its own slice of the word vector.

Step 2: run attention on the (9, 64) matrices of each head. The attention code:

```python
import math

import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
```

Inside attention, a single matrix multiplication handles all 8 heads at once: torch.matmul only contracts over the last two dimensions, i.e. the (9, 64) matrices. The p_attn computed after the softmax is the attention matrix, describing where the current query and key place their focus on value; it can be visualized to watch how the attention over value changes. Here the inputs query, key, value have shape (30, 8, 9, 64); scores.shape = p_attn.shape = (30, 8, 9, 9) (word-to-word attention, where 8 is the number of heads and each head's attention is different, similar to CNN feature maps); the returned result has shape (30, 8, 9, 64), the same as the input.
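As a quick standalone sanity check of these shapes (random tensors, sizes taken from the example above):

```python
import torch

q = k = v = torch.randn(30, 8, 9, 64)   # (batch, heads, words, d_k)

# torch.matmul broadcasts over the leading (30, 8) dimensions and only
# contracts the last two, so all 8 heads are computed in one call.
scores = torch.matmul(q, k.transpose(-2, -1))
print(scores.shape)                      # torch.Size([30, 8, 9, 9])

out = torch.matmul(scores.softmax(dim=-1), v)
print(out.shape)                         # torch.Size([30, 8, 9, 64])
```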
Step 3: restore the shape of x after the multi-head attention. At the end of step 2, x.shape = (30, 8, 9, 64); swapping the head and word dimensions back and merging the heads restores the shape to (30, 9, 512). (A standalone round-trip sketch covering steps 1 and 3 appears at the end of this section.)

At the data level, let's check whether the whole sequence of shape changes behaves as expected. After the step-1 linear projection and view, we take the 512-dim vector of word 0 in batch 0:

```
>>> self.linears[0](query).view(nbatches, -1, self.h, self.d_k)[0,0,:] tensor([[-5.5849e-01, -1.1345e+00, -1.8808e+00, -1.2825e+00, -2.6026e-01, -2.7234e-01, 7.9375e-01, -1.8540e+00, 7.1830e-01, 8.3990e-01, 2.3629e+00, 1.2487e+00, -1.3728e+00, 1.4039e+00, -3.7715e-02, 8.8309e-01, -1.5008e+00, 2.2860e-02, -1.6180e-01, -2.2477e-01, 1.1788e+00, 5.5054e-01, -4.0445e-01, -1.7874e+00, 7.3571e-01, -2.4591e+00, -1.0152e+00, 3.0960e-01, 6.2231e-01, -2.6062e-01, -1.6349e+00, 2.9838e-01, 2.9148e-01, -9.8843e-01, -2.2181e-01, -2.4378e+00, -1.1140e+00, -1.6335e+00, -9.5546e-01, 8.9785e-01, -5.6034e-01, 7.9584e-01, -2.4349e-01, 8.4613e-01, -1.0000e+00, -2.3458e+00, -1.2131e+00, 1.4844e-01, -1.0343e+00, -1.1337e+00, -5.0574e-01, -1.2185e+00, -1.0080e-01, -1.7573e+00, 1.1949e-01, 4.3815e-01, -5.1875e-01, -9.6025e-01, 4.9408e-01, -3.1510e-01, -9.1880e-01, 7.8462e-01, -7.0183e-01, 8.5226e-01], [ 2.2000e+00, 2.1699e+00, -1.0479e+00, 9.9278e-01, 1.6606e+00, 1.2315e+00, -5.6103e-01, 1.6627e-01, 8.1271e-01, -8.0781e-01, -9.2971e-01, -2.3849e-01, -7.4068e-01, -1.6442e+00, 6.0797e-01, 1.9765e+00, -6.0977e-01, 1.9683e-01, -1.0796e+00, 9.9682e-01, -3.4797e-01, 7.0690e-01, 2.8246e-02, -5.7680e-02, -1.3622e+00, -3.3336e-01, 1.8560e+00, -8.4779e-01, 3.7205e-01, -1.0761e+00, 3.1989e-01, 3.3015e-01, 6.2160e-01, -6.2217e-01, -6.0230e-01, -1.1111e+00, 8.1014e-02, -1.5747e+00, 7.7337e-01, 2.0376e-01, -4.4982e-01, 1.1864e+00, 1.3114e+00, 1.3731e-03, 6.6610e-01, -7.4975e-01, 8.4594e-01, 4.6093e-01, -1.2744e+00, -2.4648e-01, -2.2596e+00, -1.4428e-01, 3.4453e-01, -1.2783e+00, 3.3285e-01, -1.2139e+00, 1.1876e+00, 4.6992e-01, 1.1853e+00, -5.9884e-01, -7.2868e-01, -3.0325e-01, 5.0907e-01, 1.7765e+00], [ 1.3094e+00, 8.8078e-01, 6.2083e-01, 6.3161e-01, -4.4594e-01, 5.4613e-01, -8.5386e-01, 6.5830e-01, 2.0159e+00, -3.7242e-04, 1.1343e+00, 3.5900e-01, -7.5025e-01, 5.2789e-01, -4.8372e-01, 1.0444e-01, -2.8277e-01, -1.5276e+00, -1.8052e+00, 4.9059e-01, -5.7909e-01, -2.6016e+00, -9.1482e-01, -2.0798e+00, 4.3872e-01, -4.5775e-01, 4.2443e-01, 9.5000e-01, 1.2905e+00, 7.7594e-01, 1.3758e+00, 8.0112e-01, 2.6953e-01, 8.5166e-01, -1.9281e+00, 5.2717e-01, 3.5669e-01, -7.1228e-01, -1.5550e+00, -1.4948e+00, -6.4031e-01, 1.3307e-01, -2.0452e+00, -6.8747e-01, 6.8062e-01, 7.0740e-01, 2.3751e-01, -7.5977e-01, 1.1818e+00, -3.9343e-02, -1.2286e+00, -1.6509e-01, -4.7595e-01, -5.6454e-01, 6.8856e-01, -1.0079e+00, 2.0817e-02, 4.7811e-01, 1.1828e+00, 1.9962e-01, 1.3741e-01, -8.6575e-01, -1.4432e+00, -3.1736e-01], [-1.8337e-01, -5.4589e-01, -1.7105e+00, 1.4152e+00, 9.9637e-01, 6.4223e-01, -5.7173e-01, -3.1380e-01, 4.1919e-01, -6.4363e-01, 2.2370e-01, -3.3537e-01, -1.9758e-01, 5.0873e-02, -2.6714e-01, 3.7421e-01, 6.9361e-01, -8.7237e-01, -4.4104e-01, 9.9614e-01, -2.4303e-01, -1.1117e+00, 3.1751e-01, 1.5587e+00, -2.2545e-01, -2.0448e+00, -1.7236e+00, -4.7648e-01, -4.9185e-02, -4.4533e-01, 1.2060e+00, 8.1677e-02, 3.4257e-01, -3.1663e-01, 7.4717e-01, -1.4853e+00, -1.8689e+00, -1.8789e+00, 6.2522e-01, -4.7919e-01, 6.8370e-01, -1.1312e+00, -2.7003e-01, 4.0155e-01, -6.8372e-01, -4.4206e-01, 4.6632e-01, -7.0613e-01, 3.6363e-01, 1.3394e+00, -3.9506e-01, -1.9903e-01, -5.2489e-01, -5.5750e-01, -4.9276e-01, 1.8087e+00, -1.8388e-01, 1.4550e+00, -1.3432e+00, -5.5853e-01, 8.4526e-01, 3.7190e-01, 8.4095e-01, 1.4572e+00], [-6.5964e-01, -1.0118e+00, -3.3534e-01, 5.9344e-01, 1.5215e+00, 4.1303e-01, -6.6831e-01, -1.2634e+00, 8.9694e-01, -8.4749e-01, 9.8277e-02, 1.0111e+00, -8.0979e-01, -6.9347e-01, -6.3610e-01, -8.6293e-01, 6.9013e-01, 2.8582e-01, 5.4475e-01, 2.0056e+00, 
-3.9130e-01, 6.3051e-01, -6.1732e-01, 6.3364e-01, 9.0227e-01, 1.5292e+00, -1.1140e+00, -2.4449e-01, -2.2558e+00, 5.5462e-01, 7.9631e-01, 3.5409e-02, -1.1239e+00, 1.3949e-01, -3.3107e-01, -1.6686e+00, 5.6108e-01, 2.0968e+00, 1.4876e+00, -1.6770e+00, 1.4968e+00, -2.1912e+00, -5.0450e-01, -3.2774e-01, -4.1009e-01, -6.8031e-01, -8.7829e-01, 1.3209e+00, 5.4851e-01, -3.5644e-01, 2.0661e-01, -7.9915e-02, -8.7987e-02, -3.5153e-01, 5.7297e-01, -1.2522e+00, -3.1542e-01, 3.5657e-01, 8.3919e-02, -1.1384e+00, 6.5040e-01, 5.2673e-01, -3.7128e-02, -5.6037e-01], [ 3.7929e-01, 1.0985e+00, -9.5383e-01, -1.5716e+00, 1.0732e+00, 7.6176e-01, 4.9928e-01, -4.6279e-01, 1.1763e+00, 6.4880e-01, -2.3453e-01, 1.4645e+00, 1.1831e+00, 1.1062e+00, -6.6781e-01, 1.3249e-01, 4.5694e-01, -8.2162e-01, -5.3298e-01, 2.3301e-01, 1.4458e+00, -7.2901e-01, -3.7504e-01, -1.1897e-01, 3.9508e-01, 1.6462e+00, -9.2095e-01, 7.2290e-01, -1.0090e+00, -1.1971e+00, 9.3543e-01, -1.0169e+00, 2.2645e+00, -9.0345e-01, -3.9019e-01, 2.5347e+00, -4.9218e-01, 3.2157e-01, -8.0402e-01, -6.7765e-02, -1.2106e+00, -7.1914e-01, 1.5650e-01, 2.2837e+00, 4.2052e-01, -9.9606e-01, 9.0366e-01, 1.2647e-01, -6.7506e-01, -2.5810e-01, 9.9852e-01, 4.7259e-02, -1.5180e+00, -1.4949e-02, -9.2633e-01, 2.5982e-01, 1.1767e+00, -9.1571e-01, 2.6524e-01, -1.8769e+00, 1.8396e+00, 1.7582e-01, 7.0390e-02, -6.0531e-01], [-7.4824e-01, 1.3541e+00, 8.1317e-02, -1.5941e+00, -9.1559e-01, 3.4743e-02, 4.0053e-01, -2.2267e+00, 2.2170e+00, -1.8526e+00, -3.9096e-01, 4.7746e-01, 6.3766e-01, -1.3066e+00, 1.8898e+00, 9.0138e-01, -6.2393e-01, 2.6871e+00, 2.2450e+00, -4.9805e-01, -1.0026e+00, -8.2982e-01, -6.0404e-01, 8.6061e-01, -1.7860e+00, -3.3738e-01, -1.2521e+00, 1.8993e-01, 6.6285e-01, 1.8712e+00, -3.8850e-01, -5.3199e-02, -2.0127e+00, -9.6038e-01, -6.9656e-01, 8.9531e-01, -1.5444e+00, -9.8656e-01, 1.8569e-01, 5.3418e-01, 1.8387e+00, 2.6336e+00, -8.5806e-01, -2.1681e+00, 1.1165e+00, 6.5348e-01, -1.0500e+00, 9.5539e-01, 1.1112e+00, -7.5706e-02, -2.3024e-01, 2.8802e-01, -7.9896e-01, -6.1654e-01, 9.9108e-01, -4.2569e-01, 1.1832e+00, 1.0911e+00, 1.5790e-01, 6.5871e-01, -1.6219e+00, 4.1190e-02, -1.8713e-01, -3.0448e-02], [ 1.0487e+00, -2.2669e-01, -1.0977e+00, -2.2179e-01, -3.0139e-01, 2.1485e+00, -3.3073e-02, 8.8904e-01, 3.2826e-01, -4.4065e-01, 8.8462e-01, 2.1771e+00, -1.2034e+00, 6.3978e-01, -3.5032e-03, 2.7921e-01, 9.1261e-01, -5.9244e-01, -5.8784e-01, 7.9417e-02, 6.8320e-01, 1.8481e+00, 4.0412e-01, -1.6823e+00, 1.0622e-01, -2.0569e-01, 1.6752e+00, 5.2697e-01, 4.1997e-01, -7.7281e-01, 6.7034e-01, 1.7379e+00, 1.2911e+00, -2.9531e-01, -1.1818e+00, -2.6956e-01, 1.9731e+00, 1.4166e+00, 4.2653e-01, -1.6993e+00, 5.3233e-01, -1.8343e-01, -1.0087e+00, -1.2781e-02, 1.1633e-01, 5.0240e-01, 3.5670e-01, -1.1860e-01, -7.4418e-01, -7.2034e-01, 8.5559e-01, 5.1946e-01, -7.7406e-01, 3.8031e-01, -1.5740e-01, 4.9406e-01, -1.1626e+00, 1.8839e+00, -2.3393e-01, -2.0924e-01, -1.4733e+00, 1.0634e+00, -6.0484e-01, -1.2338e-01]], grad_fn=<SliceBackward>)
```

After the transpose, we take the 64-dim vector of head 1, word 0 in batch 0; it should be identical to dimensions 64 to 128 of the 512-dim vector of word 0 in batch 0 shown above:

```
>>> self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1,2)[0,1,0,:]
tensor([ 2.2000e+00, 2.1699e+00, -1.0479e+00, 9.9278e-01, 1.6606e+00, 1.2315e+00, -5.6103e-01, 1.6627e-01, 8.1271e-01, -8.0781e-01, -9.2971e-01, -2.3849e-01, -7.4068e-01, -1.6442e+00, 6.0797e-01, 1.9765e+00, -6.0977e-01, 1.9683e-01, -1.0796e+00, 9.9682e-01, -3.4797e-01, 7.0690e-01, 2.8246e-02, -5.7680e-02, -1.3622e+00, -3.3336e-01, 1.8560e+00, -8.4779e-01, 3.7205e-01, -1.0761e+00, 3.1989e-01, 3.3015e-01, 6.2160e-01, -6.2217e-01, -6.0230e-01, -1.1111e+00, 8.1014e-02, -1.5747e+00, 7.7337e-01, 2.0376e-01, -4.4982e-01, 1.1864e+00, 1.3114e+00, 1.3731e-03, 6.6610e-01, -7.4975e-01, 8.4594e-01, 4.6093e-01, -1.2744e+00, -2.4648e-01, -2.2596e+00, -1.4428e-01, 3.4453e-01, -1.2783e+00, 3.3285e-01, -1.2139e+00, 1.1876e+00, 4.6992e-01, 1.1853e+00, -5.9884e-01, -7.2868e-01, -3.0325e-01, 5.0907e-01, 1.7765e+00], grad_fn=<SliceBackward>)
>>> assert (self.linears[0](query).view(nbatches, -1, self.h, self.d_k)[0,0,1,:] == self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1,2)[0,1,0,:]).all()  # every element matches
```
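Finally, to tie steps 1 and 3 together outside of the class, here is a small standalone round-trip sketch (random input, shapes taken from the example above), showing that splitting into heads and then merging them back recovers the original projected tensor:

```python
import torch

nbatches, n_words, h, d_k = 30, 9, 8, 64
x = torch.randn(nbatches, n_words, h * d_k)   # (30, 9, 512), e.g. the output of one of the linear layers

# Step 1: split into heads -> (30, 8, 9, 64)
heads = x.view(nbatches, -1, h, d_k).transpose(1, 2)

# Head 1, word 0 is exactly dimensions 64..127 of word 0's 512-dim vector.
assert torch.equal(heads[0, 1, 0, :], x[0, 0, 64:128])

# Step 3: merge the heads back -> (30, 9, 512), identical to the original x.
merged = heads.transpose(1, 2).contiguous().view(nbatches, -1, h * d_k)
assert torch.equal(merged, x)
```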