透彻掌握Batch-Normalization的前向传播(Forward Pass)、反向传播(Backward Pass)及其实现代码详解
深入理解 Batch-Normalization
BN 能显著提升神经网络模型的训练速度(论文),自2015年被推出以来,已经成为神经网络模型的标准层。
现代深度学习框架(如 TF、Pytorch 等)均内置了 BN 层,使得我们在搭建网络轻而易举。但这也间接造成很多人对于 BN 的理解只停留在 概念 层面,而没有深入公式,详细推导其行为 (前向传播+反向传播)。
本文的主旨则是从数学公式层面,详细推导 BN,并通过代码手动实现BN 层。
一、BN 的 前向传播
让我们从原论文中最出名的一张图开始吧:
(图1: BN 的前向传播)
BN的前向传播过程分别在不同阶段的行为可以概述如下:
训练阶段:
- 对每个批次的输入 x,[ ‼️重要:在batch 方向上‼️],计算 均值
μ
B
{\mu}_B
μB 和 方差
σ
B
2
{\sigma}^2_B
σB2:
- μ B = 1 m ∑ i m x i {\mu}_B = \frac{1}{m} \sum_i^m{x_i} μB=m1∑imxi
- σ B 2 = 1 m ∑ i m ( x i − μ B ) 2 {\sigma}^2_B = \frac{1}{m} \sum_i^m{{(x_i - {\mu}_B)}^2} σB2=m1∑im(xi−μB)2
- 利用
μ
B
{\mu}_B
μB 和
σ
B
2
{\sigma}^2_B
σB2 对输入 x 进行标准化:
- x i ^ = x i − μ B σ B 2 + ϵ \hat{x_i} = \frac{x_i - \mu_B}{\sqrt{{\sigma}^2_B + \epsilon}} xi^=σB2+ϵxi−μB
- 引入可学习参数
γ
\gamma
γ 和
β
\beta
β, 对标准化后的
x
i
^
\hat{x_i}
xi^ 进行 缩放 和 平移,作为 BN 层的最终输出值:
- y i = γ x i ^ + β y_i=\gamma\hat{x_i}+\beta yi=γxi^+β
注意:
训练过程 中会以指数平均的方式计算整个训练集的 平均均值(running mean) 和 平均方差(running_var),这两个值将在 测试阶段 代替 μ B {\mu}_B μB 和 σ B 2 {\sigma}^2_B σB2 对 x 进行归一化:
- r u n n i n g _ m e a n = m o m e n t u m ∗ r u n n i n g _ m e a n + ( 1 − m o m e n t u m ) ∗ μ B running\_mean=momentum * running\_mean + (1-momentum)*\mu_B running_mean=momentum∗running_mean+(1−momentum)∗μB
- r u n n i n g _ v a r = m o m e n u t m ∗ r u n n i n g _ v a r + ( 1 − m o m e n t u m ) ∗ σ B 2 running\_var=momenutm * running\_var + (1-momentum)*\sigma^2_B running_var=momenutm∗running_var+(1−momentum)∗σB2
测试阶段
在这个阶段的计算流程大体与训练阶段相同,但不会计算
μ
B
{\mu}_B
μB 和
σ
B
2
{\sigma}^2_B
σB2,而是分别以 running_mean 和 running_var 代替。
说明:
- 对于 Linear 层,设 x 的维度为 [N, D];那么上面那些公式中的值都是什么维度?
- μ B {\mu}_B μB 和 σ B 2 {\sigma}^2_B σB2: [D]
- x i ^ \hat{x_i} xi^ 和 y i y_i yi: [N,D]
- running_mean 和 running_var: [D]
- γ \gamma γ 和 β \beta β: [D]
- 如果是Conv 层,设 x 的维度为 [N, C, H, W]; 那么上面那些公式中的值都是什么维度?
- 这种情况要特别注意⚠️,对于卷基层,BN 计算均值和方差将会考虑 H 和 W 的维度,在 Pytorch 中称为 BatchNorm2D,如下图所示:
(图2: BatchNorm2D)
- 这种情况要特别注意⚠️,对于卷基层,BN 计算均值和方差将会考虑 H 和 W 的维度,在 Pytorch 中称为 BatchNorm2D,如下图所示:
二、BN 的 反向传播
反向传播的要点是找到 Loss 对当前节点中所有参数的梯度以及对节点的输入张量 x 的梯度,即 ∂ L ∂ γ \frac{\partial L}{\partial \gamma} ∂γ∂L、 ∂ L ∂ β \frac{\partial L}{\partial \beta} ∂β∂L 以及 ∂ L ∂ x \frac{\partial L}{\partial x} ∂x∂L。
由 链式法则可知,这些梯度均等于 上游梯度 * 局部梯度:
- ∂ L ∂ γ = ∂ L ∂ o u t ∗ ∂ o u t ∂ γ \frac {\partial L}{\partial \gamma}=\frac {\partial L}{\partial out}*\frac {\partial out}{\partial \gamma} ∂γ∂L=∂out∂L∗