[Reading] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (2015/2)

1. 概述

  文章提出了一种通过对网络各层输入的每个小批量(mini-batch)进行规范化,来解决内部协变量偏移(internal covariate shift)的方法,称为批量规范化(Batch Normalization,BN)。通过 BN,可以在训练过程中使用更高的学习率来加速训练,且能够降低模型参数初始化对性能的影响,还能起到一定的正则化效果。

  对于随机优化方法,BN 以小批量为单位进行规范化。为了保持网络的表达能力,BN 为每个激活值引入了两个额外的参数,在训练中学习。使用 BN 后,网络使用会饱和的激活函数也能够进行训练,可以使用更高的学习率,且不再需要 dropout 等正则化方法。在当时的 SOTA 图像分类模型中加入 BN,使用更少的训练步数就可以获得更好的性能。

2. 内部协变量偏移

  对于一个学习系统,如果其输入的分布发生了变化,就称其出现了协变量偏移(covariate shift)问题。此时系统的训练数据和测试数据来自不同的分布,从而影响系统在测试时的性能。类似的问题也存在于在系统内部,如在子网络或层中,我们不仅希望训练数据和测试数据来自相同的分布,还希望训练过程中输入数据本身的分布保持不变,这样学习的参数就不需要去不断适应输入数据分布的变化。

  此外,对于如 sigmoid 这类有饱和区的激活函数,我们也希望输入的数据在适当的范围之内,避免因激活函数饱和出现梯度消失。虽然此类问题可以通过使用如 ReLU 之类的激活函数、适当的初始化和较小的学习率来应对,但如果能保证训练过程中非线性输入的稳定性,优化器就可以更容易地进行优化,从而加速训练。

  文章将深度网络训练过程中,因网络参数变化而导致的激活值分布的变化,称为内部协变量偏移(internal covariate shift)。文章希望通过降低内部协变量偏移,来提升训练性能。一种方法是使输入具有固定的分布,例如将输入白化(whitening),即通过线性变换让输入具有零均值和单位方差,且去除相关。

  如果将规范化和训练作为两个独立的步骤,并不能达到期望的效果。例如对于某层,输入为 $u$,只有一个偏置参数 $b$,输出为 $x = u + b$,使用减去均值的方式进行规范化:

$$
\hat{x} = x – \mathrm{E}[x] = u + b – \mathrm{E}[u + b]
$$

其中 $\mathrm{E}[x] = \frac{1}{N} \sum\limits_{i=1}^N x_i$。之后通过优化更新参数 $b \leftarrow b + \Delta b$,在下一个迭代的输出为 $u + (b + \Delta b)$,进行归一化得到:

$$
u + (b + \Delta b) – \mathrm{E}[u + (b + \Delta b)] = u + b – \mathrm{E}[u + b]
$$

其中 $\Delta b$ 在本轮迭代是常数,有 $\mathrm{E}[\Delta b] = \Delta b$。可见这个结果与优化前的 $\hat{x}$ 相同,输出和损失都没有变化,而 $b$ 由于进行了更新,会越来越大。文章也在实验中验证了这一现象。

  上面方法的问题是,在优化时并没有将规范化考虑进去。为了解决此问题,需要让网络能直接输出具有期望分布的激活值,无论参数如何变化。记网络中某层的输入为 $\mathrm{x}$,训练集中所有样本的输入为 $\mathcal{X}$,此时归一化可以写成如下的变换:

$$
\hat{\mathrm{x}} = \mathrm{Norm}(\mathrm{x}, \mathcal{X})
$$

其中 $\hat{\mathrm{x}}$ 不仅依赖于当前训练样本 $\mathrm{x}$,还依赖于所有样本 $\mathcal{X}$。对于非输入层,$\hat{\mathrm{x}}$ 还会依赖于之前层的参数。反向传播时,就需要计算

$$
\frac{\partial\mathrm{Norm}(\mathrm{x}, \mathcal{X})}{\partial{\mathrm{x}}} \; 和 \; \frac{\partial\mathrm{Norm}(\mathrm{x}, \mathcal{X})}{\partial{\mathcal{X}}}
$$

上述计算是比较复杂的,而且涉及到整个训练集 $\mathcal{X}$,计算量也很大。

3. 批量规范化

3.1. 通过小批量进行规范化

  由于对每层的输入进行完整的白化计算量很大,且不是处处可微,为了得到更实用的方法,文章进行了两点简化。

  第一点简化是,不再对每层的输入或输出的各特征进行联合白化,而是独立地对每个标量特征进行规范化,使其具有零均值和单位方差。例如对于 $d$ 维的输入 $\mathrm{x} = (x^{(1)} \dots x^{(d)})$,对每个维度进行如下的规范化:

$$
\hat{x}^{(k)} = \frac{x^{(k)} – \mathrm{E}[x^{(k)}]}{\sqrt{\mathrm{Var}[x^{(k)}]}}
$$

上式中的期望和方差是在训练上计算得到的,虽然没有去除相关,但仍然可以加速收敛。

  需要注意的是,规范化会影响层的表达能力,例如对于 sigmoid 激活函数,规范化的输入会落在线性区间,无法起到引入非线性的作用。因此,需要保证引入的“规范化”变换可以表示恒等变换,网络在训练过程中,可以学习合适的变换方式来进行“规范化”。为此,为每一个激活值 $x^{(k)}$ 引入一组参数 $\gamma^{(k)}$ 和 $\beta^{(k)}$,对规范化的结果进行缩放和偏移的线性变换:

$$
y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}
$$

其中 $\gamma^{(k)}$ 和 $\beta^{(k)}$ 作为额外的参数,在原模型训练过程中习得。注意到当 $\gamma^{(k)} = \sqrt{\mathrm{Var}[x^{(k)}]}$ 且 $\beta^{(k)} = \mathrm{E}[x^{(k)}]$ 时,有 $y^{(k)} = x^{(k)}$,也就恢复了规范化前的激活值。

  第二点简化是,不再在完整的训练集上进行规范化。对于如随机梯度下降之类的随机优化方法,每次只在一小部分数据上计算梯度并更新参数,并不会接触到完整的训练集。如果训练使用了小批量的方法,则在每个小批量上计算规范化。对于尺寸为 $m$ 的小批量

$$
\mathcal{B} = \{x_{1 \dots m}\}
$$

记规范化后的结果为 $\hat{x}_{1 \dots m}$,线性变换后的结果为 $y_{1 \dots m}$,则将 BN 变换记为

$$
\mathrm{BN}_{\gamma, \beta}: x_{1 \dots m} \rightarrow y_{1 \dots m}
$$

具体算法为:

Algorithm 1
Algorithm 1

其中 $\epsilon$ 是一个小常数,用来保证数值稳定性。

3.2. 训练时的 BN

  要在网络中使用 BN,只需将层的输入 $x$ 替换为 $\mathrm{BN}(x)$ 即可,之后就可以使用批量梯度下降,或者小批量的随机梯度下降,或者如 Adagrad 的变种来进行训练。训练算法为:

Algorithm 2
Algorithm 2

3.3. 推断时的 BN

  BN 主要是用来提高训练效率,在推断时,使用训练样本总体的期望和方差来对测试数据进行规范化:

$$
\hat{x} = \frac{x – \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}
$$

其中的方差使用无偏估计 $\mathrm{Var}[x] = \frac{m}{m-1} \mathrm{E}_{\mathcal{B}}[\sigma_{\mathcal{B}}^2]$,$\mathrm{E}_{\mathcal{B}}$ 和 $\sigma_{\mathcal{B}}^2$ 是小批量的期望和方差。在典型实现中,会使用各个小批量的均值和方差的指数加权平均来估计测试中使用的均值和方差。

  在推断时,期望和方差都是固定的,因此 BN 只是对激活值的线性变换。

3.4. BN 的位置

  对于仿射变换加非线性的层结构,记其输入为 $\mathrm{u}$,权重为 $W$,偏置为 $\mathrm{b}$,激活函数为 $g(\cdot)$,则输出可以表示为

$$
\mathrm{z} = g(W \mathrm{u} + \mathrm{b})
$$

一种方式是对输入 $\mathrm{u}$ 进行 BN,但由于 $\mathrm{u}$ 一般是上一层非线性的输出,其分布在训练过程中可能会改变,仅通过前面减去均值再除以标准差的方式,并不能消除内部协变量偏移的问题。

  相比之下,$W \mathrm{u} + \mathrm{b}$ 更容易具有对称、非稀疏的分布,更加的“高斯”,对其进行规范化更容易获得稳定分布的激活值,因此文章将 BN 应用在此处。

  由于规范化时要减去均值,全连接层中的偏置会被减掉,而 BN 的参数 $\beta$ 扮演了偏置的角色,因此可以将全连接层中的偏置去掉,此时全连接层的输出为

$$
\mathrm{z} = g\big(\mathrm{BN}(W \mathrm{u})\big)
$$

  对于卷积层,相同的过滤器会作用在输入的各个位置上,因此文章以特征图为单位进行规范化,在学习时为每个特征图学习参数 $\gamma^{k}$ 和 $\beta^{k}$,在推断时为每个特征图进行相同的线性变换。

3.5. BN 的效果

  在传统的深度网络中,过高的学习率可能会导致梯度消失或爆炸,或者收敛到局部最优。BN 通过对激活值进行规范化,避免参数的微小变化随着层数的增加而被不断放大。此外 BN 还可以适应不同的参数范围,过高的学习率通常会让参数变得更大,而使用 BN 后,对于标量 $a$,有

$$
\mathrm{BN}(W \mathrm{u}) = \mathrm{BN}\big((a W) \mathrm{u}\big)
$$

且有

$$
\frac{\partial \mathrm{BN}\big((a W) \mathrm{u}\big)}{\partial \mathrm{u}} = \frac{\partial \mathrm{BN}(W \mathrm{u})}{\partial \mathrm{u}}
$$

$$
\frac{\partial \mathrm{BN}\big((a W) \mathrm{u}\big)}{\partial (a W)} = \frac{1}{a} \cdot \frac{\partial \mathrm{BN}(W \mathrm{u})}{\partial W}
$$

可见使用标量 $a$ 对权重进行缩放,不会影响这一层的雅可比矩阵,也不会影响梯度的反向传播,且更大的权重会导致更小的梯度,使得参数的增长更加稳定。

  此外 BN 还可以起到正则化的效果。通过 BN,每个训练样本都引入了小批量中其他样本的信息,从而提高网络的泛化能力。

4. 实验结果

  文章首先使用一个简单的网络比较了 BN 在 MNIST 上的性能,如 Figure 1 所示。由 Figure 1 (a) 可见使用 BN 的网路可以更快地达到更高的准确率,由 Figure 1 (b)、(c) 可见使用 BN 的网络的激活值具有更稳定的分布。

Figure 1
Figure 1

  文章将 BN 应用于 Inception 网络,比较了不同变种在 ImageNet 图像分类的性能,包括

  • Inception
  • BN-Baseline:在 Inception 的每个非线性前加入 BN
  • BN-x5:使用 5 倍学习率(0.0075)的 BN-Baseline
  • BN-x30:使用 30 倍学习率(0.045)的 BN-Baseline
  • BN-x5-Sigmoid:在 BN-x5 使用 sigmoid 替换 ReLU 作为非线性函数

结果如 Figure 2、3所示。可见使用 BN 可以大幅提升训练速度,并能达到更高的准确率(除了 BN-x5-Sigmoid)。注意到 BN-x30 在训练初始阶段要差于 BN-x5,但最终收敛到了更高的准确率。即便使用更难训练的 sigmoid,BN 仍能让网络达到一定的性能,而在原 Inception 中使用 sigmoid 则会导致准确率低于 1/1000。

Figure 2
Figure 2
Figure 3
Figure 3

  使用模型融合后,进一步获得了更好的性能,如 Figure 4 所示。

Figure 4
Figure 4