[Reading] Spatial Transformer Networks

1. 概述

  Spatial Transformer Networks 一文提交于 2015/6,文章提出了一种对特征图进行空间变换的模块,称为 Spatial Transformer(ST)。该模块可以通过学习,对不同的特征图进行适当的变换,增强卷积神经网络对输入数据的空间不变性,从而提高网络性能。

  文章指出,虽然 CNN 在图像识别上获得了很好的性能,但依然无法有效地维持输入数据的空间不变性:无论物体是什么姿态、有何种形变,都不应影响对物体的识别。

  最大池化可以在一定程度上满足特征的空间不变性,但由于最大池化的尺寸通常较小,它只能在卷积网络的较深层上获得更大的感受野,从而实现空间不变性。对于网络中间的部分,如果输入数据有较大的形变等变化,则无法保持空间不变性。这一限制的本质在于,池化的机制是固定的,而输入数据的形态多种多样,固定的池化无法灵活应对输入数据的各种空间变化。

  文章提出的 ST 能够针对不同的输入样本进行不同的变换。ST 直接作用于整个特征图(而不是像池化只作用于局部),在前向传播中通过缩放、旋转、裁剪等变换,对输入数据进行“校正”,使得网络可以自主选择数据中最相关的部分(注意力机制),而且能对选择的区域进行“校正”,使其具有统一的标准“姿态”,来简化后续的识别,如 Figure 1 所示。ST 可以微分,能在训练过程通过反向传播学习如何进行变换,不需要额外的监督学习或修改优化流程,可以很容易地加入到已有网络中。

Figure 1

Figure 1

2. Spatial Transformer

  文章提出的空间变换分为三个步骤,如 Figure 2 所示:

  1. 输入的特征图通过定位网络(localisation network)得到空间变换的参数,这些参数特定于输入特征图;
  2. 网格生成器(grid generator)使用这些参数生成一个采样网格(sampling grid),即要对输入特征图的那些点进行采样;
  3. 采样器(sampler)使用采样网格对输入特征图进行采样。
Figure 2

Figure 2

2.1. 定位网络

  定位网络的输入为特征图 $U \in \mathbb{R}^{H \times W \times C}$,输出为变换的参数 $\theta$,$\theta = f_{\mathrm{loc}}(U)$。$\theta$ 的尺寸由变换的类型决定,例如仿射变换需要 6 个参数。$f_{\mathrm{loc}}()$ 的形式不限,可以使用全连接层或者卷积层,但最后都要有一个回归层来生成参数 $\theta$。

2.2. 参数化采样网格

  空间变换输出的每一个像素都是使用采样核在输入数据的特定位置上计算得到的。记输出像素位于网格 $G = {G_i}$,像素 $G_i = (x_i^t, y_i^t)$,输出的特征图 $V \in \mathbb{R}^{H’ \times W’ \times C’}$,假设要进行的变换 $\mathcal{T}_\theta$ 是一个二维仿射变换 $A_\theta$,即

\begin{equation}
\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix} = \mathcal{T}_\theta(G_i)
= A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
= \begin{matrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{matrix}
\begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix} \tag{1}
\end{equation}

其中 $(x_i^t, y_i^t)$ 是输出特征图的目标坐标, $(x_i^s, y_i^s)$ 是输入特征图的源坐标,$A_\theta$ 是仿射变换矩阵。这里使用的坐标都进行了标准化,有 $-1 \leq x_i^t, y_i^t \leq 1$,$-1 \leq x_i^s, y_i^s) \leq 1$。Figure 3 展示了恒等变换和放射变换的结果。

Figure 3

Figure 3

  注意这里是对目标坐标进行变换,得到源坐标:我们希望输出的特征图具有“标准”的姿态,位于规则的网格坐标系中;为了得到这样的结果,我们需要知道目标坐标系中各个位置上的值,是由源坐标系中的哪些值计算而来的,因此式 $(1)$ 的变换是从目标坐标(已知)到源坐标(未知)的。建立起这样的关系之后,就能按坐标将输入特征图变换为输出特征图。

  变换 $\mathcal{T}_\theta$ 不仅限于仿射变换,更一般地,有 $\mathcal{T}_\theta = M_\theta B$,其中 $B$ 为目标网格,$M_\theta$ 是一个由 $\theta$ 参数化的矩阵,可以同时对 $\theta$ 和 $B$ 进行学习和预测。

2.3. 可微图像采样

  空间变换最终输出 $V$ 的某个特定像素的值,由采样核(sampling kernel)根据对应采样点,即 $\mathcal{T}_\theta(G)$ 中的 $(x_i^s, y_i^s)$,对输入数据进行采样和插值处理得到(通过式 $(1)$ 计算得到的 $(x_i^s, y_i^s)$ 可能是小数,需要进行插值)。该流程可以写为

\begin{equation}
V_i^c = \sum_n^H \sum_m^W U_{nm}^c k(x_i^s – m; \Phi_x) k(y_i^s – n; \Phi_y) \quad \forall i \in [1 \dots H’W’] \quad \forall c in [1 \dots C] \tag{3}
\end{equation}

其中 $\Phi_x$ 和 $\Phi_y$ 是通用采样核 $k()$ 的参数,采样核定义了对图像进行插值的方法(如双线性)。$U_{nm}^c$ 是输入数据在位置 $(n, m)$ 、通道 $c$ 上的值,$V_i^c$ 是输出的第 $i$ 个像素的值,位于 $(x_i^t, y_i^t)$、通道 $c$ 。这里对每个通道都使用相同的采样方法,对每个通道进行的变换也是一样的,以此保持通道间的空间一致性。

  采样核需要对 $x_i^s$ 和 $y_i^s$ 的梯度有定义,例如使用整数采样核,此时式 $(3)$ 变为

\begin{equation}
V_i^c = \sum_n^H \sum_m^W U_{nm}^c \delta(\lfloor x_i^s + 0.5 \rfloor – m) \delta(\lfloor y_i^s + 0.5 \rfloor – n) \tag{4}
\end{equation}

其中 $\lfloor x_i^s + 0.5 \rfloor$ 将 $x$ 圆整到最近的整数,$\delta$ 为 Kronecker delta 函数。式 $(4)$ 所示的采样核相当于使用输入数据中距离 $(x_i^s, y_i^s)$ 最近的像素作为输出数据中位于 $(x_t^s, y_t^s)$ 的像素。

  此外也可以使用双线性采样核

\begin{equation}
V_i^c = \sum_n^H \sum_m^W U_{nm}^c \max(0, 1 – |x_i^s – m|) \max(0, 1 – |y_i^s – n|) \tag{5}
\end{equation}

  在反向传播时,需要计算对 $U$ 和 $G$ 的梯度,对于式 $(5)$ 所示的采样核,有

\begin{equation}
\frac{\partial V_i^c}{\partial U_{nm}^c} = \sum_n^H \sum_m^W \max(0, 1 – |x_i^s – m|) \max(0, 1 – |y_i^s – n|) \tag{6}
\end{equation}

\begin{equation}
\frac{\partial V_i^c}{\partial x_i^s} = \sum_n^H \sum_m^W U_{nm}^c \max(0, 1 – |y_i^s – n|)
\begin{cases}0 & \mathrm{if} \; |m – x_i^s| \geq 1 \\
1 & \mathrm{if} \; m \geq x_i^s \\
-1 & \mathrm{if} \; m < x_i^s \end{cases} \tag{7}
\end{equation}

$\frac{\partial V_i^c}{\partial x_i^s}$ 的计算如式 $(7)$。通过 $\frac{\partial V_i^c}{\partial x_i^s}$ 和 $\frac{\partial V_i^c}{\partial x_i^s}$,可以进一步得到关于变换的参数 $\theta$ 的梯度。

2.4. 空间变换网络

  Figure 2 所示的 ST 模块可以插入到 CNN 架构的任意位置,得到空间变换网络(spatial transformer network)。通过引入空间变换,网络可以主动地对特征图进行变换,为训练过程中最小化全局损失函数提供帮助,网络可以在训练过程中学习如何对不同的输入样本进行合适的变换。定位网络的输出 $\theta$ 编码了目标的姿态和区域等信息,对于部分任务,甚至可以将 $\theta$ 直接输入给后续网络。

  通过调整输出尺寸 $H’$ 和 $w’$,可以使用空间变换来对特征图进行上下采样。然而由于采样核通常只作用在固定的一小块区域,进行下采样会导致混叠效应(aliasing effect)。

  在 CNN 中可以使用多个 ST,应用的层数越深,就能对越抽象的特征进行变换。如果特征图中有多个需要关注的目标,还可以并行地使用多个 ST,但这样一来,ST 的数量就限制了网络能识别的对象的数量。

3. 实验结果

  文章在 MNIST 数据集上进行了各种变换,来验证 ST 学习空间不变性的效果,结果如 Table 1 所示。可见使用 ST 提升了网络性能。注意到在 RTS 的变换下,ST-FCN 和 CNN 错误率都为 0.8%,具有相同的性能,而前者并没有使用卷积或池化来保持空间不变性。另外注意到 ST-CNN 的性能始终优于 ST-FCN,前者的卷积和池化层进一步增强了空间不变性。在各 ST 网路中,TPS 均获得了最低的错误率。由 Table 1 右图可见,ST 将数字变换为竖直的姿态,这是训练集中的平均姿态。

Table 1

Table 1

  文章接着在 Street View House Numbers(SVHN)数据集上进行了验证,数据集中包含了 20 万张真实门牌号码的图片,用于识别其中的数字。结果如 Table 2 所示,可见 ST-CNN 具有最佳的性能,在 128px 上的提升尤为明显,且比之前的最佳模型更加简单,只需一次前向传播即可。由 Table 2 右图可见,ST-CNN 可以对数字区域进行裁剪和缩放,使得网络将注意力集中在数字区域。

Table 2

Table 2

  文章还在 CUB-200-2011 鸟类数据集上对 ST 在细粒度分类任务上的性能进行了验证,该数据集包含了 200 种鸟类的 6000 张训练集图片和 5800 张测试集图片。结果如 Table 3 所示,可见加入 ST 的网络具有最高的准确率。由 Table 3 右图展示了使用不同数量的 ST-CNN 时 ST 所预测的变换,注意到红框代表的 ST 学会了检测鸟头,绿框代表的 ST 则主要检测鸟的身子。

Table 3

Table 3

4. 总结

  ST 作为一个独立的模块,可以很容易地插入到已有的网络中,通过对特征图进行空间变换,进一步提升网络的性能。ST 可以端到端地进行训练,不需要修改损失函数。ST 输出的变换参数还可以直接作为输出,用于后续任务。