RNN BPTT算法详细推导

    技术2024-11-07  7

    BPTT算法推导

    BPTT全称:back-propagation through time。这里以RNN为基础,进行BPTT的推导。

    BPTT的推导比BP算法更难,同时所涉及的数学知识更多,主要用到了向量矩阵求导、向量矩阵微分、向量矩阵的链式求导法则,想要完全理解掌握BPTT的推导,这些是基础工具。

    向量矩阵求导主要参考刘建平的相关博客:https://www.cnblogs.com/pinard/p/10750718.html

    RNN的BPTT推导主要参考刘建平的相关博客:https://www.cnblogs.com/pinard/p/6509630.html

    上图是RNN的经典图示。

    RNN的BPTT推导:

    在刘的博客中,损失函数为交叉熵损失函数,输出的激活函数为softmax函数,隐藏层的激活函数为tanh函数;

    但是按照这种配置,无法推导出后续的表达式;经过思考,我认为应该是以下的配置:

    损失函数为交叉熵损失函数(二元交叉熵损失函数),输出的激活函数应该为sigmoid函数,隐藏层的激活函数为tanh函数。(二分类问题)

    对于RNN,由于在序列的每个位置都有损失函数,因此最终的损失 L L L为: L = ∑ t = 1 τ L ( t ) = − ∑ t = 1 τ y t log ⁡ y ^ t + ( 1 − y t ) log ⁡ ( 1 − y ^ t ) L = \sum_{t=1}^{\tau}L^{(t)}=-\sum_{t=1}^{\tau}y^t\log\hat{y}^{t}+(1-y^t)\log(1-\hat{y}^t) L=t=1τL(t)=t=1τytlogy^t+(1yt)log(1y^t)

    ∂ L ∂ c = ∑ t = 1 τ ∂ L t ∂ c = ∑ t = 1 τ ( y ^ t − y t ) \frac{\partial L}{\partial c} = \sum_{t=1}^{\tau}\frac{\partial L^t}{\partial c} = \sum_{t=1}^{\tau}(\hat{y}^t-y^t) cL=t=1τcLt=t=1τ(y^tyt)

    按照刘的说法,如果是softmax的激活函数,那么这里的 c c c应该是向量,但是在他的文章中, c c c的符号是标量符号。主要是如果按照softmax来进行推导,得不到后续的公式,这里暂且先按照sigmoid函数来。 ∂ L t ∂ c = ∂ L t ∂ y ^ t ⋅ ∂ y ^ t ∂ o t ⋅ ∂ o t ∂ c ∂ L t ∂ y ^ t = − ∂ ∂ y ^ t ( y t log ⁡ y ^ t + ( 1 − y t ) log ⁡ ( 1 − y ^ t ) ) = − y t y ^ t + 1 − y t 1 − y ^ t ∂ y ^ t ∂ o t = ∂ ∂ o t ( s i g m o i d ( o t ) ) = s i g m o i d ( o t ) ( 1 − s i g m o i d ( o t ) ) = y ^ t ( 1 − y ^ t ) ∂ o t ∂ c = ∂ ∂ c ( V h t + c ) = 1 \frac{\partial L^t}{\partial c} = \frac{\partial L^t}{\partial \hat{y}^t}\cdot \frac{\partial \hat{y}^t}{\partial o^t}\cdot\frac{\partial o^t}{\partial c}\\\frac{\partial L^t}{\partial \hat{y}^t} =-\frac{\partial }{\partial \hat{y}^t}(y^t\log\hat{y}^{t}+(1-y^t)\log(1-\hat{y}^t))\\=-\frac{y^t}{\hat{y}^t}+\frac{1-y^t}{1-\hat{y}^t}\\\frac{\partial \hat{y}^t}{\partial o^t}=\frac{\partial}{\partial o^t}(sigmoid(o^t))\\=sigmoid(o^t)(1-sigmoid(o^t))\\=\hat{y}^t(1-\hat{y}^t)\\\frac{\partial o^t}{\partial c} = \frac{\partial}{\partial c}(Vh^t+c)\\=1 cLt=y^tLtoty^tcoty^tLt=y^t(ytlogy^t+(1yt)log(1y^t))=y^tyt+1y^t1ytoty^t=ot(sigmoid(ot))=sigmoid(ot)(1sigmoid(ot))=y^t(1y^t)cot=c(Vht+c)=1 检查一下第一个表达式,由于每个变量都是标量,所以可以按照标量的链式求导法则来求导。把每个表达式的值代入,发现的确如此。由上面的推导,还可以得到: ∂ L t ∂ o t = ∂ L t ∂ c (1) \frac{\partial L^t}{\partial o^t} = \frac{\partial L^t}{\partial c}\tag{1} otLt=cLt(1)

    ∂ L ∂ V = ∑ t = 1 τ ∂ L t ∂ V = ∑ t = 1 τ ( y ^ t − y t ) ( h t ) T (2) \frac{\partial L}{\partial V} = \sum_{t=1}^\tau \frac{\partial L^t}{\partial V} = \sum_{t=1}^\tau(\hat{y}^t-y^t)(h^t)^T\tag{2} VL=t=1τVLt=t=1τ(y^tyt)(ht)T(2)

    其中, L ∈ R , V ∈ R 1 × m , h t ∈ R m L\in \bold{R}, V\in\bold{R}^{1\times m}, h^t\in \bold{R}^m LR,VR1×m,htRm,注意到,这里涉及到标量对向量的求导,采用分母布局,注意检查等号两边的维度是否相同,参与运算的变量保证能够进行矩阵相乘,必要的时候需要调整位置以便能完成相应的矩阵乘法。公式2的推导很简单,因为 ∂ L t ∂ V = ∂ L t ∂ o t ⋅ ∂ o t ∂ V \frac{\partial L^t}{\partial V} = \frac{\partial L^t}{\partial o^t}\cdot\frac{\partial o^t}{\partial V} VLt=otLtVot

    接下来就是 W , U , b W,U,b W,U,b的梯度计算了,这三者的梯度计算是相对复杂的。从RNN的结构可以知道,反向传播时,在某个时刻t的梯度损失由当前位置的输出对应的梯度损失和 t + 1 t+1 t+1时刻的梯度损失两部分共同决定,而 t + 1 t+1 t+1时刻的梯度损失有相同的结构,可以看出是循环嵌套的。因此 W W W在某一位置t的梯度损失需要一步步计算。我们定义序列索引 t t t的隐藏状态的梯度为: δ t = ∂ L ∂ h t (3) \delta^t = \frac{\partial L}{\partial h^t}\tag{3} δt=htL(3) 注意到公式3也是标量对向量的导数

    这样我们可以像DNN一样从 δ t + 1 \delta^{t+1} δt+1递推 δ t \delta^t δt δ t = ∂ L ∂ o t ⋅ ∂ o t ∂ h t + ( ∂ h t + 1 ∂ h t ) T ⋅ ∂ L ∂ h t + 1 = V T ∑ t = 1 τ ( y ^ t − y t ) + W T d i a g ( 1 − ( h t + 1 ) 2 ) δ t + 1 (4) \delta^t = \frac{\partial L}{\partial o^t}\cdot\frac{\partial o^t}{\partial h^t}+(\frac{\partial h^{t+1}}{\partial h^t})^T\cdot\frac{\partial L}{\partial h^{t+1}}\\=V^T\sum_{t=1}^{\tau}(\hat{y}^t-y^t)+W^Tdiag(1-(h^{t+1})^2)\delta^{t+1}\tag{4} δt=otLhtot+(htht+1)Tht+1L=VTt=1τ(y^tyt)+WTdiag(1(ht+1)2)δt+1(4) 公式4和刘的表达式不一样,个人认为我的应该是对的,刘的公式按照向量的求导法则,表达式中的维度不一致。第一步参考刘的矩阵微分系列博客。第二步中的第二部分,第一次看的时候没有明白,也花了挺多时间推导,这里记录一下。 h t + 1 = t a n h ( W h t + U x t + 1 + b ) h^{t+1} = tanh(Wh^t+Ux^{t+1}+b) ht+1=tanh(Wht+Uxt+1+b) 其中, W ∈ R m × m , x t ∈ R n , U ∈ R m × n W\in \bold{R}^{m\times m}, x^t\in \bold{R}^n,U\in\bold{R}^{m\times n} WRm×m,xtRn,URm×n t a n h ′ ( x ) = 1 − ( t a n h ( x ) ) 2 tanh^{'}(x)=1-(tanh(x))^2 tanh(x)=1(tanh(x))2

    ∂ h t + 1 ∂ h t \frac{\partial h^{t+1}}{\partial h^t} htht+1,这是向量对向量的求导,按照分子布局求导结果的维度是 m × m m\times m m×m。这里我们按照定义来求: ∂ h i t + 1 ∂ h t \frac{\partial h_i^{t+1}}{\partial h^t} hthit+1 此时变成了标量对向量的求导,按照分母布局,结果维度应该和 h t h^t ht相同,此时 h i t + 1 = t a n h ( W i , : h t ) h_i^{t+1} = tanh(W_{i,:}h^t) hit+1=tanh(Wi,:ht) 省略了与 h h h无关的项。那么: ∂ h i t + 1 ∂ h t = ( 1 − ( h i t + 1 ) 2 ) ∂ ∂ h t ( W i , : h t ) = ( 1 − ( h i t + 1 ) 2 ) W i , : T \frac{\partial h_i^{t+1}}{\partial h^t} = (1-(h_i^{t+1})^2)\frac{\partial }{\partial h^t}(W_{i,:}h^t)\\=(1-(h_i^{t+1})^2)W_{i,:}^T hthit+1=(1(hit+1)2)ht(Wi,:ht)=(1(hit+1)2)Wi,:T 此时结果的维度是 m × 1 m\times 1 m×1,由于是按照分子布局, i i i对应最后矩阵的第 i i i行,所以这里应该在转置一下,变成: ( 1 − ( h i t + 1 ) 2 ) W i , : (1-(h_i^{t+1})^2)W_{i,:} (1(hit+1)2)Wi,:

    所以: ∂ h t + 1 ∂ h t = d i a g ( 1 − ( h t + 1 ) 2 ) W \frac{\partial h^{t+1}}{\partial h^t}=diag(1-(h^{t+1})^2)W htht+1=diag(1(ht+1)2)W

    其中, d i a g ( 1 − ( h t + 1 ) 2 ) diag(1-(h^{t+1})^2) diag(1(ht+1)2) indicates the diagonal matrix containing the elements 1 − ( h i t + 1 ) 2 1-(h_i^{t+1})^2 1(hit+1)2(来自花书英文版385页)。

    将其代入公式4,就明白为什么是那样的表达式了。

    这里在记录一下另一个点, t a n h ( W h t ) tanh(Wh^t) tanh(Wht) 这是一个向量,有: ∂ ∂ W h t ( t a n h ( W h t ) ) = d i a g ( 1 − ( t a n h ( W h t ) ) 2 ) \frac{\partial }{\partial Wh^t}(tanh(Wh^t))=diag(1-(tanh(Wh^t))^2) Wht(tanh(Wht))=diag(1(tanh(Wht))2) 其实这就是向量对向量的求导,按照分子布局求导结果为 m × m m\times m m×m的矩阵,刚好对角矩阵是一个 m × m m\times m m×m的矩阵,这间接说明了等式的正确性。

    在刘的https://www.cnblogs.com/pinard/p/10825264.html第三节的最后部分,给出了四个非常重要的表达式,这里记录下接下来会用到的一个表达式: z = f ( y ) , y = X a + b − > ∂ z ∂ X = ∂ z ∂ y a T (5) z = f(y), y = Xa+b \quad-> \frac{\partial z}{\partial X}=\frac{\partial z}{\partial y}a^T\tag{5} z=f(y),y=Xa+b>Xz=yzaT(5) 其中, z z z为标量, y , a , b y,a,b y,a,b为向量, X X X为矩阵。不过发现好像用不到。。。

    有了 δ t \delta^t δt的表达式后,我们求 W , U , b W,U,b W,U,b就方便很多了,有 ∂ L ∂ W = ∑ t = 1 τ d i a g ( 1 − ( h t ) 2 ) δ t ( h t − 1 ) T ∂ L ∂ W = ∂ h τ ∂ W ∂ L ∂ h τ + . . . + ∂ h 1 ∂ W ∂ L ∂ h 1 = ∑ t = 1 τ ∂ h t ∂ W ∂ L ∂ h t (6) \frac{\partial L}{\partial W} = \sum_{t=1}^{\tau}diag(1-(h^t)^2)\delta^t(h^{t-1})^T\\\frac{\partial L}{\partial W} = \frac{\partial h^\tau}{\partial W}\frac{\partial L}{\partial h^\tau}+...+\frac{\partial h^1}{\partial W}\frac{\partial L}{\partial h^1}\\=\sum_{t=1}^\tau\frac{\partial h^t}{\partial W}\frac{\partial L}{\partial h^t}\\\tag{6} WL=t=1τdiag(1(ht)2)δt(ht1)TWL=WhτhτL+...+Wh1h1L=t=1τWhthtL(6) 老实说,这一步我不确定是否正确,因为这涉及到了向量对矩阵的导数,这玩意儿我还不会。不过这一步解释了为什么公式6里面有个累加符号。后续推导就更不会了,会的大佬请不吝赐教。不过感觉上好像是这么回事。只要这一步搞懂了,对 U , b U,b U,b的求导是类似的,可以参考刘的博客,这里就不写了。(太菜了,关键步骤不会)

    从表达式4可以看出RNN梯度消失和梯度爆炸的根本原因(展开后就能知道为什么了)。 后续补充LSTM的BPTT。(毕竟面试的时候会要求公式层面的对LSTM防止梯度消失和梯度爆炸的理解)

    Processed: 0.044, SQL: 9