欢迎您访问 最编程 本站为您分享编程语言代码,编程技术文章!
您现在的位置是: 首页

透彻掌握Batch-Normalization的前向传播(Forward Pass)、反向传播(Backward Pass)及其实现代码详解

最编程 2024-07-24 21:57:45
...

深入理解 Batch-Normalization

BN 能显著提升神经网络模型的训练速度(论文),自2015年被推出以来,已经成为神经网络模型的标准层。

现代深度学习框架(如 TF、Pytorch 等)均内置了 BN 层,使得我们在搭建网络轻而易举。但这也间接造成很多人对于 BN 的理解只停留在 概念 层面,而没有深入公式,详细推导其行为 (前向传播+反向传播)。

本文的主旨则是从数学公式层面,详细推导 BN,并通过代码手动实现BN 层。


一、BN 的 前向传播

让我们从原论文中最出名的一张图开始吧:

BN 的前向传播
图1BN 的前向传播)

BN的前向传播过程分别在不同阶段的行为可以概述如下:

训练阶段:

  • 对每个批次的输入 x[ ‼️重要:在batch 方向上‼️],计算 均值 μ B {\mu}_B μB方差 σ B 2 {\sigma}^2_B σB2:
    • μ B = 1 m ∑ i m x i {\mu}_B = \frac{1}{m} \sum_i^m{x_i} μB=m1imxi
    • σ B 2 = 1 m ∑ i m ( x i − μ B ) 2 {\sigma}^2_B = \frac{1}{m} \sum_i^m{{(x_i - {\mu}_B)}^2} σB2=m1im(xiμB)2
  • 利用 μ B {\mu}_B μB σ B 2 {\sigma}^2_B σB2 对输入 x 进行标准化:
    • x i ^ = x i − μ B σ B 2 + ϵ \hat{x_i} = \frac{x_i - \mu_B}{\sqrt{{\sigma}^2_B + \epsilon}} xi^=σB2+ϵ xiμB
  • 引入可学习参数 γ \gamma γ β \beta β, 对标准化后的 x i ^ \hat{x_i} xi^ 进行 缩放平移,作为 BN 层的最终输出值:
    • y i = γ x i ^ + β y_i=\gamma\hat{x_i}+\beta yi=γxi^+β

注意:
训练过程 中会以指数平均的方式计算整个训练集的 平均均值(running mean)平均方差(running_var),这两个值将在 测试阶段 代替 μ B {\mu}_B μB σ B 2 {\sigma}^2_B σB2x 进行归一化:

  • r u n n i n g _ m e a n = m o m e n t u m ∗ r u n n i n g _ m e a n + ( 1 − m o m e n t u m ) ∗ μ B running\_mean=momentum * running\_mean + (1-momentum)*\mu_B running_mean=momentumrunning_mean+(1momentum)μB
  • r u n n i n g _ v a r = m o m e n u t m ∗ r u n n i n g _ v a r + ( 1 − m o m e n t u m ) ∗ σ B 2 running\_var=momenutm * running\_var + (1-momentum)*\sigma^2_B running_var=momenutmrunning_var+(1momentum)σB2

测试阶段
在这个阶段的计算流程大体与训练阶段相同,但不会计算 μ B {\mu}_B μB σ B 2 {\sigma}^2_B σB2,而是分别以 running_meanrunning_var 代替。

说明:

  • 对于 Linear 层,设 x 的维度为 [N, D];那么上面那些公式中的值都是什么维度?
    • μ B {\mu}_B μB σ B 2 {\sigma}^2_B σB2[D]
    • x i ^ \hat{x_i} xi^ y i y_i yi: [N,D]
    • running_meanrunning_var: [D]
    • γ \gamma γ β \beta β: [D]
  • 如果是Conv 层,设 x 的维度为 [N, C, H, W]; 那么上面那些公式中的值都是什么维度?
    • 这种情况要特别注意⚠️,对于卷基层,BN 计算均值和方差将会考虑 HW 的维度,在 Pytorch 中称为 BatchNorm2D,如下图所示:
      Conv BN
      图2: BatchNorm2D)

二、BN 的 反向传播

反向传播的要点是找到 Loss 对当前节点中所有参数的梯度以及对节点的输入张量 x 的梯度,即 ∂ L ∂ γ \frac{\partial L}{\partial \gamma} γL ∂ L ∂ β \frac{\partial L}{\partial \beta} βL 以及 ∂ L ∂ x \frac{\partial L}{\partial x} xL

链式法则可知,这些梯度均等于 上游梯度 * 局部梯度