
Finally found the backpropagation formulas for Batch Normalization!

最编程 2024-05-06 16:55:51
...

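I have recently been working on CS231n Assignment 2, which asks you to derive the backpropagation formulas for Batch Normalization and implement them in code. I tried applying the chain rule step by step on my own, but the expression I ended up with was hugely complicated and contained no summation signs at all, which was clearly wrong. The formulas in the original paper left me just as confused: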

[Figure: the backpropagation formulas from the original paper]

I consulted
https://www.adityaagrawal.net/blog/deep_learning/bprop_batch_norm
https://kevinzakka.github.io/2016/09/14/batch_normalization/
but both of them write the derivative with respect to $x_i$ directly as
$$\frac{\partial l}{\partial x_i}=\frac{\partial l}{\partial \hat{x_i}}\cdot\frac{\partial \hat{x_i}}{\partial x_i} + \frac{\partial l}{\partial \mu}\cdot\frac{\partial \mu}{\partial x_i} + \frac{\partial l}{\partial \sigma}\cdot\frac{\partial \sigma}{\partial x_i}$$
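Drawing the computational graph makes this expression easy to understand, but how can it be derived from the chain rule itself? And where did my own derivation go wrong? After two days of digging, I finally figured it out! Below is a detailed record of the derivation.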

Review: the chain rule for multivariable composite functions

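According to Advanced Mathematics (《高等数学》, Tongji University, 7th edition), Volume II, Chapter 9, Section 4, the case of a multivariable function composed with multivariable functions is covered by the following theorem: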

[Figure: the chain rule for multivariable composite functions]

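In plain terms: apply the chain rule along every composite function that involves $x$, then add up the results. With that in mind, let's roll up our sleeves and get to work.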
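To see the "apply the chain rule along every path and sum" rule in action, here is a small Python sketch on a toy function made up purely for illustration ($z = u \sin v$ with $u = xy$, $v = x + y$; nothing here comes from the textbook). It computes $\frac{\partial z}{\partial x}$ by summing over both paths, compares the result with a central finite difference, and shows that following only one path (that is, leaving out the sum) gives a wrong value.

```python
import math

# Toy composite function, chosen only for illustration:
#   u(x, y) = x * y,   v(x, y) = x + y,   z = f(u, v) = u * sin(v)
def z(x, y):
    u, v = x * y, x + y
    return u * math.sin(v)

def dz_dx_all_paths(x, y):
    """Chain rule summed over BOTH paths x -> u -> z and x -> v -> z."""
    u, v = x * y, x + y
    dz_du, dz_dv = math.sin(v), u * math.cos(v)  # partials of f
    du_dx, dv_dx = y, 1.0                        # partials of u and v w.r.t. x
    return dz_du * du_dx + dz_dv * dv_dx

def dz_dx_one_path(x, y):
    """Deliberately wrong: follows only the path through u and forgets v."""
    v = x + y
    return math.sin(v) * y

def dz_dx_numeric(x, y, h=1e-6):
    """Central finite difference, used as an independent reference value."""
    return (z(x + h, y) - z(x - h, y)) / (2 * h)

x0, y0 = 0.7, -1.3
print(dz_dx_all_paths(x0, y0), dz_dx_one_path(x0, y0), dz_dx_numeric(x0, y0))
```

The summed version matches the finite difference, while the single-path version is off by the missing $\frac{\partial z}{\partial v}\frac{\partial v}{\partial x}$ term.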

The problem to solve

In the forward pass, Batch Normalization is defined by:
\begin{align}
\tag{1}\mu &= \frac{1}{m}\sum_{i=1}^{m}x_i \\
\tag{2}\sigma^2 &= \frac{1}{m}\sum_{i=1}^{m}(x_i-\mu)^2\\
\tag{3}\hat{x_i} &=\frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\\
\tag{4}y_i &= \gamma\hat{x_i} + \beta
\end{align}
During backpropagation, the upstream gradient $\frac{\partial l}{\partial y_i}$ is given, and we need to compute $\frac{\partial l}{\partial x_i}$, $\frac{\partial l}{\partial \gamma}$ and $\frac{\partial l}{\partial \beta}$.
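Equations (1)-(4) map directly onto a few lines of NumPy. The sketch below assumes a mini-batch `x` of shape `(m, D)` that is normalized per feature along axis 0 (the usual layout for a fully connected layer); the function name `bn_forward` and the contents of `cache` are choices made for illustration, not the CS231n starter-code API.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """Batch-norm forward pass following Eqs. (1)-(4).

    x: shape (m, D), a mini-batch of m samples with D features; the
       statistics are computed per feature, i.e. along axis 0.
    Returns y and a cache of intermediates for the backward pass.
    """
    m = x.shape[0]
    mu = x.sum(axis=0) / m                    # Eq. (1)
    var = ((x - mu) ** 2).sum(axis=0) / m     # Eq. (2): biased variance sigma^2
    x_hat = (x - mu) / np.sqrt(var + eps)     # Eq. (3)
    y = gamma * x_hat + beta                  # Eq. (4)
    cache = (x, mu, var, x_hat, gamma, eps)
    return y, cache

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4))
y, cache = bn_forward(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0), y.std(axis=0))  # roughly 0 and 1 for every feature
```

Note that Eq. (2) divides by $m$, not $m-1$, which matches NumPy's default `ddof=0` if you prefer `x.var(axis=0)`.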

Derivation

Start with the easier ones:
\begin{align} \frac{\partial l}{\partial \gamma} &= \frac{\partial l}{\partial y_1} \frac{\partial y_1}{\partial \gamma} + \frac{\partial l}{\partial y_2} \frac{\partial y_2}{\partial \gamma} + ... +\frac{\partial l}{\partial y_m} \frac{\partial y_m}{\partial \gamma}\\ \tag{5} &=\sum_{i=1}^m\frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial \gamma}\\ &=\sum_{i=1}^m\frac{\partial l}{\partial y_i}\hat{x_i}\\ \\ \frac{\partial l}{\partial \beta} &= \frac{\partial l}{\partial y_1} \frac{\partial y_1}{\partial \beta} + \frac{\partial l}{\partial y_2} \frac{\partial y_2}{\partial \beta} + ... +\frac{\partial l}{\partial y_m} \frac{\partial y_m}{\partial \beta}\\ \tag{6} &=\sum_{i=1}^m\frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial \beta}\\ &=\sum_{i=1}^m\frac{\partial l}{\partial y_i}\\ \\ \tag{7}\frac{\partial l}{\partial \hat{x_i}} &= \frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial \hat{x_i}}\\ &=\frac{\partial l}{\partial y_i} \gamma\\ \end{align}
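As a hedged sketch (the helper name `bn_param_grads` is hypothetical), Eqs. (5)-(7) become two reductions over the batch axis and one elementwise product. To check them, the code uses the surrogate loss $l = \sum_i \text{dout}_i \cdot y_i$ with a fixed random array `dout`, whose gradient with respect to $y$ is exactly `dout`, and compares against central finite differences.

```python
import numpy as np

def bn_param_grads(dout, x_hat, gamma):
    """Gradients from Eqs. (5)-(7); dout is dl/dy with shape (m, D)."""
    dgamma = np.sum(dout * x_hat, axis=0)  # Eq. (5): sum over every y_i in the batch
    dbeta = np.sum(dout, axis=0)           # Eq. (6): sum over every y_i in the batch
    dx_hat = dout * gamma                  # Eq. (7): only y_i depends on x_hat_i, no sum
    return dgamma, dbeta, dx_hat

rng = np.random.default_rng(1)
m, D, eps = 6, 3, 1e-5
x = rng.normal(size=(m, D))
gamma, beta = rng.normal(size=D), rng.normal(size=D)
dout = rng.normal(size=(m, D))                               # plays the role of dl/dy
x_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)  # Eqs. (1)-(3)

def loss(g, b):
    """Surrogate loss l = sum(dout * y), whose gradient w.r.t. y is exactly dout."""
    return np.sum(dout * (g * x_hat + b))

dgamma, dbeta, dx_hat = bn_param_grads(dout, x_hat, gamma)

h, E = 1e-6, np.eye(D)
num_dgamma = np.array([(loss(gamma + h * e, beta) - loss(gamma - h * e, beta)) / (2 * h) for e in E])
num_dbeta = np.array([(loss(gamma, beta + h * e) - loss(gamma, beta - h * e)) / (2 * h) for e in E])
print(np.max(np.abs(dgamma - num_dgamma)), np.max(np.abs(dbeta - num_dbeta)))
```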
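By Eq. (4), every $y_i$ contains $\gamma$ and $\beta$, so the chain rule must be applied through each $y_i$ and the results summed. Likewise, every $\hat{x_i}$ contains $\mu$ and $\sigma^2$, and $\sigma^2$ is itself a function of $\mu$ (Eq. (2)), so by the chain rule for multivariable composite functions: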
\begin{align}
\frac{\partial l}{\partial \sigma^2} &= \frac{\partial l}{\partial y_1}\frac{\partial y_1}{\partial \hat{x_1}} \frac{\partial \hat{x_1}}{\partial \sigma^2} +\frac{\partial l}{\partial y_2}\frac{\partial y_2}{\partial \hat{x_2}} \frac{\partial \hat{x_2}}{\partial \sigma^2} + ... +\frac{\partial l}{\partial y_m}\frac{\partial y_m}{\partial \hat{x_m}} \frac{\partial \hat{x_m}}{\partial \sigma^2} \\
\tag{8} &=\sum_{i=1}^m\frac{\partial l}{\partial y_i}\frac{\partial y_i}{\partial \hat{x_i}} \frac{\partial \hat{x_i}}{\partial \sigma^2}\\
&\xlongequal{\text{Eq. (7)}}\sum_{i=1}^m\frac{\partial l}{\partial \hat{x_i}}\frac{\partial \hat{x_i}}{\partial \sigma^2}\\
&=\sum_{i=1}^m\frac{\partial l}{\partial \hat{x_i}}(x_i - \mu)\left(-\frac{1}{2}\right)(\sigma^2+\epsilon)^{-\frac{3}{2}}
\end{align}
\begin{align}
\frac{\partial l}{\partial \mu} &= \frac{\partial l}{\partial y_1}\frac{\partial y_1}{\partial \hat{x_1}} \frac{\partial \hat{x_1}}{\partial \mu} +\frac{\partial l}{\partial y_2}\frac{\partial y_2}{\partial \hat{x_2}} \frac{\partial \hat{x_2}}{\partial \mu} + ... +\frac{\partial l}{\partial y_m}\frac{\partial y_m}{\partial \hat{x_m}} \frac{\partial \hat{x_m}}{\partial \mu} \\
&+ \frac{\partial l}{\partial y_1}\frac{\partial y_1}{\partial \hat{x_1}} \frac{\partial \hat{x_1}}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial \mu} +\frac{\partial l}{\partial y_2}\frac{\partial y_2}{\partial \hat{x_2}} \frac{\partial \hat{x_2}}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial \mu} + ... +\frac{\partial l}{\partial y_m}\frac{\partial y_m}{\partial \hat{x_m}} \frac{\partial \hat{x_m}}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial \mu} \\
\tag{9} &=\sum_{i=1}^m\frac{\partial l}{\partial y_i} \frac{\partial y_i}{\partial \hat{x_i}} \left(\frac{\partial \hat{x_i}}{\partial \mu} + \frac{\partial \hat{x_i}}{\partial \sigma^2} \boxed{\frac{\partial \sigma^2}{\partial \mu}} \right)\\
&\xlongequal{\text{Eq. (7), (10)}} \sum_{i=1}^m \frac{\partial l}{\partial \hat{x_i}} \frac{\partial \hat{x_i}}{\partial \mu} \\
&= \sum_{i=1}^m \frac{\partial l}{\partial \hat{x_i}} \frac{-1}{\sqrt{\sigma^2 + \epsilon}}
\end{align}
where
\begin{align} \tag{10} \frac{\partial \sigma^2}{\partial \mu} &= -\frac{2}{m} \sum_{i=1}^m(x_i - \mu) = -\frac{2}{m}\left(\sum_{i=1}^m x_i - m \mu\right) \xlongequal{\text{Eq. (1)}} 0 \end{align}
The figure below may help with intuition: every relationship drawn in color needs a summation sign.

[Figure: dependency graph for the chain-rule derivation]
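Eqs. (8)-(10) can be checked numerically in the same way. In this sketch (the function `bn_stat_grads` and the surrogate loss are assumptions made for illustration), a random array stands in for $\frac{\partial l}{\partial \hat{x_i}}$. `num_dmu` perturbs $\mu$ while recomputing $\sigma^2$ from it, so it measures the full derivative including the boxed path in Eq. (9); it still agrees with the simplified sum because $\frac{\partial \sigma^2}{\partial \mu}$ vanishes at the batch mean.

```python
import numpy as np

def bn_stat_grads(dx_hat, x, mu, var, eps=1e-5):
    """dl/d(sigma^2) and dl/dmu following Eqs. (8)-(10).

    dx_hat (= dl/dx_hat) and x have shape (m, D); mu and var have shape (D,).
    """
    m = x.shape[0]
    inv_std = 1.0 / np.sqrt(var + eps)
    dvar = np.sum(dx_hat * (x - mu) * -0.5 * inv_std ** 3, axis=0)  # Eq. (8)
    dvar_dmu = -2.0 / m * np.sum(x - mu, axis=0)                    # Eq. (10): ~0
    direct = np.sum(-dx_hat * inv_std, axis=0)                      # path mu -> x_hat_i
    dmu_full = direct + dvar * dvar_dmu  # Eq. (9) before dropping the boxed term
    dmu = direct                         # Eq. (9) after applying Eq. (10)
    return dvar, dmu, dmu_full, dvar_dmu

rng = np.random.default_rng(2)
m, D, eps = 5, 3, 1e-5
x = rng.normal(size=(m, D))
dx_hat = rng.normal(size=(m, D))         # stands in for dl/dx_hat from Eq. (7)
mu, var = x.mean(axis=0), x.var(axis=0)  # Eqs. (1) and (2)
dvar, dmu, dmu_full, dvar_dmu = bn_stat_grads(dx_hat, x, mu, var, eps)

def l_of(mu_, var_):
    """Surrogate loss sum(dx_hat * x_hat(mu_, var_)), so that dl/dx_hat = dx_hat."""
    return np.sum(dx_hat * (x - mu_) / np.sqrt(var_ + eps))

def var_of(mu_):
    """sigma^2 written as a function of mu, as in Eq. (2)."""
    return np.mean((x - mu_) ** 2, axis=0)

h, E = 1e-6, np.eye(D)
num_dvar = np.array([(l_of(mu, var + h * e) - l_of(mu, var - h * e)) / (2 * h) for e in E])
# Total derivative in mu: sigma^2 is recomputed from mu, so the path mu -> sigma^2 is included
num_dmu = np.array([(l_of(mu + h * e, var_of(mu + h * e))
                     - l_of(mu - h * e, var_of(mu - h * e))) / (2 * h) for e in E])
print(dvar_dmu, np.max(np.abs(dvar - num_dvar)), np.max(np.abs(dmu - num_dmu)))
```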

Finally, regard Eq. (3) as a composite function of $u(x_j)=x_j$, $\mu$ and $\sigma^2$, all three of which depend on $x_j$. (The path $x_j \to \mu \to \sigma^2 \to \hat{x_i}$ is left out because, by Eq. (10), it contributes $\frac{\partial l}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial \mu}\frac{\partial \mu}{\partial x_j} = 0$.) Then:
\begin{align}
\frac{\partial l}{\partial x_j} = & \frac{\partial l}{\partial y_j}\frac{\partial y_j}{\partial \hat{x_j}} \frac{\partial \hat{x_j}}{\partial x_j} \\
&+ \frac{\partial l}{\partial y_1}\frac{\partial y_1}{\partial \hat{x_1}} \frac{\partial \hat{x_1}}{\partial \mu} \frac{\partial \mu}{\partial x_j}+\frac{\partial l}{\partial y_2}\frac{\partial y_2}{\partial \hat{x_2}} \frac{\partial \hat{x_2}}{\partial \mu} \frac{\partial \mu}{\partial x_j} + ... +\frac{\partial l}{\partial y_m}\frac{\partial y_m}{\partial \hat{x_m}} \frac{\partial \hat{x_m}}{\partial \mu} \frac{\partial \mu}{\partial x_j}\\
&+ \frac{\partial l}{\partial y_1}\frac{\partial y_1}{\partial \hat{x_1}} \frac{\partial \hat{x_1}}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_j}+\frac{\partial l}{\partial y_2}\frac{\partial y_2}{\partial \hat{x_2}} \frac{\partial \hat{x_2}}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_j} + ... +\frac{\partial l}{\partial y_m}\frac{\partial y_m}{\partial \hat{x_m}} \frac{\partial \hat{x_m}}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_j}\\
=& \frac{\partial l}{\partial y_j}\frac{\partial y_j}{\partial \hat{x_j}} \frac{\partial \hat{x_j}}{\partial x_j} \\
&+ \left(\sum_{i=1}^m \frac{\partial l}{\partial y_i}\frac{\partial y_i}{\partial \hat{x_i}} \frac{\partial \hat{x_i}}{\partial \mu}\right) \frac{\partial \mu}{\partial x_j} \\
&+ \left(\sum_{i=1}^m \frac{\partial l}{\partial y_i}\frac{\partial y_i}{\partial \hat{x_i}} \frac{\partial \hat{x_i}}{\partial \sigma^2}\right) \frac{\partial \sigma^2}{\partial x_j}\\
\xlongequal{\text{Eq. (7), (8), (9)}}& \frac{\partial l}{\partial \hat{x_j}} \frac{\partial \hat{x_j}}{\partial x_j} + \frac{\partial l}{\partial \mu} \frac{\partial \mu}{\partial x_j} + \frac{\partial l}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_j} \\
=& \frac{\partial l}{\partial \hat{x_j}} \frac{1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial l}{\partial \mu} \frac{1}{m} + \frac{\partial l}{\partial \sigma^2} \frac{2(x_j - \mu)}{m}
\end{align}
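Combining the last line with Eqs. (5)-(9) gives a complete backward pass. This is a minimal NumPy sketch under the same per-feature, shape `(m, D)` assumption; `bn_forward` and `bn_backward_staged` are hypothetical names. It checks $\frac{\partial l}{\partial x}$ against a finite-difference gradient of the entire forward pass, again with the surrogate loss $l = \sum \text{dout} \cdot y$.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """Forward pass, Eqs. (1)-(4); statistics are taken per feature (axis 0)."""
    mu, var = x.mean(axis=0), x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta, (x, mu, var, x_hat, gamma, eps)

def bn_backward_staged(dout, cache):
    """Backward pass built from Eqs. (5)-(9) and the last line of the derivation above."""
    x, mu, var, x_hat, gamma, eps = cache
    m = x.shape[0]
    inv_std = 1.0 / np.sqrt(var + eps)
    dgamma = np.sum(dout * x_hat, axis=0)                           # Eq. (5)
    dbeta = np.sum(dout, axis=0)                                    # Eq. (6)
    dx_hat = dout * gamma                                           # Eq. (7)
    dvar = np.sum(dx_hat * (x - mu) * -0.5 * inv_std ** 3, axis=0)  # Eq. (8)
    dmu = np.sum(-dx_hat * inv_std, axis=0)                         # Eq. (9)
    # dl/dx_j = dl/dx_hat_j / sqrt(var + eps) + dl/dmu / m + dl/dvar * 2 (x_j - mu) / m
    dx = dx_hat * inv_std + dmu / m + dvar * 2.0 * (x - mu) / m
    return dx, dgamma, dbeta

# Gradient check with the surrogate loss l = sum(dout * y), so that dl/dy = dout
rng = np.random.default_rng(3)
m, D = 6, 3
x = rng.normal(size=(m, D))
gamma, beta = rng.normal(size=D), rng.normal(size=D)
dout = rng.normal(size=(m, D))
_, cache = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward_staged(dout, cache)

def loss(x_):
    return np.sum(dout * bn_forward(x_, gamma, beta)[0])

h = 1e-6
num_dx = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    num_dx[idx] = (loss(xp) - loss(xm)) / (2 * h)
print(np.max(np.abs(dx - num_dx)))  # small, at the level of finite-difference error
```

A useful side check: for every feature, the entries of `dx` sum to zero over the batch. That is expected, since adding the same constant to all $x_i$ of a feature shifts $\mu$ by that constant and leaves every $y_i$ unchanged.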
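Substituting the final results of Eqs. (7), (8) and (9) into the last line of the expression for $\frac{\partial l}{\partial x_j}$ gives: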
$$\frac{\partial l}{\partial x_j} = \frac{1}{m\sqrt{\sigma^2+\epsilon}}\left(m\frac{\partial l}{\partial \hat{x_j}} - \sum_{i=1}^m\frac{\partial l}{\partial \hat{x_i}} - \hat{x_j}\sum_{i=1}^m\frac{\partial l}{\partial \hat{x_i}} \hat{x_i}\right)$$
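The closed-form result can be vectorized over all features at once. In this sketch (`bn_backward_compact` and `bn_backward_staged` are illustrative names, not library functions), the compact expression is compared with the term-by-term assembly; the two should agree up to floating-point roundoff because they are the same expression after the algebra above.

```python
import numpy as np

def bn_backward_compact(dout, x_hat, gamma, var, eps=1e-5):
    """dl/dx from the closed-form result above, vectorized over the D features.

    dout, x_hat: shape (m, D); gamma, var: shape (D,).
    """
    m = dout.shape[0]
    dx_hat = dout * gamma                               # Eq. (7)
    return (m * dx_hat
            - np.sum(dx_hat, axis=0)                    # sum_i dl/dx_hat_i
            - x_hat * np.sum(dx_hat * x_hat, axis=0)    # x_hat_j * sum_i dl/dx_hat_i * x_hat_i
            ) / (m * np.sqrt(var + eps))

def bn_backward_staged(dout, x, gamma, eps=1e-5):
    """Reference: dl/dx assembled term by term from Eqs. (7)-(9)."""
    m = x.shape[0]
    mu, var = x.mean(axis=0), x.var(axis=0)
    inv_std = 1.0 / np.sqrt(var + eps)
    dx_hat = dout * gamma
    dvar = np.sum(dx_hat * (x - mu) * -0.5 * inv_std ** 3, axis=0)
    dmu = np.sum(-dx_hat * inv_std, axis=0)
    return dx_hat * inv_std + dmu / m + dvar * 2.0 * (x - mu) / m

rng = np.random.default_rng(4)
x = rng.normal(loc=1.0, scale=3.0, size=(16, 5))
gamma, dout = rng.normal(size=5), rng.normal(size=(16, 5))
eps = 1e-5
var = x.var(axis=0)
x_hat = (x - x.mean(axis=0)) / np.sqrt(var + eps)

dx_compact = bn_backward_compact(dout, x_hat, gamma, var, eps)
dx_staged = bn_backward_staged(dout, x, gamma, eps)
print(np.max(np.abs(dx_compact - dx_staged)))  # agreement up to floating-point roundoff
```

One practical difference between the two: the compact form only needs `x_hat`, `gamma` and `var` from the forward pass (not `x` or `mu`), and it uses just two reductions per feature.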