Processing math: 100%

[Reading] Neural Architecture Search with Reinforcement Learning

Neural Architecture Search with Reinforcement Learning (2016/11)

1. 概述

  文章的主要贡献有:

  • 提出了一种自动化搜索神经网络架构的方法,称为神经网络架构搜索(Neural Architecture Search,NAS)。这种方法使用循环神经网络(RNN)生成可变长度网络描述,以此来构建神经网络模型。其中的 RNN 作为控制器,通过强化学习的方式训练,来最大化所生成模型在验证集上准确率的期望。
  • 通过 NAS 分别生成了用于图像识别的卷积神经网络(CNN),以及用于语言建模的 RNN 单元,都获得了很好的效果。证明通过 NAS 可以得到高性能的网络架构,且其性能可以泛化到其他任务。

  文章给出 NAS 的过程如 Figure 1 所示。

Figure 1
Figure 1

神经网络的架构可以用一个可变长度的字符串来描述,因此也就可以使用一个循环网络(控制器)来生成。生成的网络称为子网络(child network),在真实数据上训练后,使用验证集上的准确率作为奖励信号,由此计算策略梯度,来更新控制器。在下个迭代,控制器就会有更高的概率输出具有更高准确率的网络架构。

2. 搜索方法

2.1. RNN 控制器

  文章使用 RNN 作为控制器,来生成网络架构相关的一系列参数,Figure 2 展示了生成 CNN 的情况。控制器逐层生成网络的各参数,当总层数到达一定数量时停止生成,得到一个网络。然后训练这个生成的网络直到收敛,记录在留出的验证集上的准确率,据此更新 RNN 控制器的参数 θc,以最大化生成网络的准确率的期望。此外生成的网络层数也会随着 RNN 控制器的训练而逐渐增加。

Figure 2
Figure 2

2.2. 训练方法

  文章将学习生成神经网络架构看成是一个强化学习的问题,将控制器输出的网络结构参数看成是动作 a1:T,生成的网络经训练收敛后,在验证集上的准确率作为回报 R,训练控制器来最大化期望回报

J(θc)=EP(a1:T;θc)[R]

R 不可微,需要使用策略梯度来迭代更新 θc,文章使用了 REINFORCE 算法

θcJ(θc)=Tt=1EP(a1:T;θc)[θclogP(at|a(t1):1;θc)R]

上式近似为

1mmk=1Tt=1θclogP(at|a(t1):1;θc)Rk

其中 m 是控制器在一个 batch 中生成的不同架构的数量,T 是控制器生成一个架构所需预测的超参数数量,Rk 是生成的第 k 个网络的验证集准确率。为了降低方差,文章使用历史架构准确率的指数滑动平均作为 baseline b,即

1mmk=1Tt=1θclogP(at|a(t1):1;θc)(Rkb)

2.3. 增加架构复杂度

  Figure 2 中所示的结构只能生成逐层堆叠的架构,为了让控制器能够预测类似 ResNet 中的 skip connection,在每层之后添加了一个额外的 anchor point,如 Figure 4 所示。

Figure 4
Figure 4

  第 N 层的 anchor point 其中包含了 N1 个 sigmoid,来预测是否将之前的某层直接连接过来,即

P(Layer j is an input to layer i)=sigmoid(vTtanh(Wprevhj+Wcurrhi))

其中 hj 为第 j 层 anchor point 的隐藏状态,WprevWcurrv 为可训练的参数。

  建立了 skip connection 之后,某一层可能会有多个输入,此时将多个输入在通道维度上拼接起来。但拼接两层的输出具有相同的尺寸,如果尺寸不同,文章使用补零的方法,先对较小的尺寸进行补充,再进行拼接。

  如果网络中的某一层没有与之前的层有任何连接,则这一层作为输入层。在最后一层,将之前所有没有被拼接的输出拼接起来,作为最后分类器的输入。

  上面的例子只是预测了卷积层,通过在控制器中添加额外的步骤来预测层类型,可以引入池化、BatchNorm 等不同类型的层。此外还可以加入学习率等其他超参数的预测。

2.4. 生成 recurrent cell 架构

  文章给出了生成循环网络中 cell 结构的方法。对于循环网络中的 cell,搜索算法需要生成一个形如 ht=f(xt,ht1) 的函数 f,如 ht=tanh(W1xt+W2ht1)。文章将此类函数的结构看成一棵树,树的叶子节点的输入为 xtht1,树的节点代表一个二元函数,根节点输出为 ht,如 Figure 5 左图所示。在 LSTM 中每个 cell 还有额外的输入 ct1 和额外的输出 ct,可以在控制器最后加入两个额外的步骤来预测 ct1ct 和树中的哪个节点相连,如 Figure 5 中图所示。

Figure 5
Figure 5

  Figure 5 中图所示的网络生成步骤为:

  • 在 Tree Index 0 阶段,控制器预测了 Add 和 Tanh,表示计算 a0=tanh(W1x1+W2ht1)
  • 在 Tree Index 1 阶段,控制器预测了 Elem Mult 和 ReLU,表示计算 a1=ReLU((W3x1)(W4ht1))
  • 在 Cell Inject 阶段,控制器预测了 Add 和 ReLU;在 Cell indices 阶段,控制器预测的第 2 个元素为 0,表示将 ct1 连接到 a0,即计算 anew0=ReLU(a0+ct1)
  • 在 Tree Index 2 阶段,控制器预测了 Elem Mult 和 Sigmoid,表示计算 a2=sigmoid(anew0a1),注意 Figure 5 左图树节点最大索引为 2,则 ht=a2
  • 在 Cell indices 阶段,控制器预测的第 1 个元素为 1,表示 ct 为 Tree Index 1 在激活函数前的输出,即 ct=(W3x1)(W4ht1)

由此就构建出了一个完成的 cell 结构。注意 Figure 5 左图的树有两个叶子节点,称为 base 2 架构。文章在实验中使用了 base 8 架构,以获取更强的表达能力。

3. 实验结果

  文章在图像识别也语言建模两个任务上,验证了通过 NAS 所得到的网络的性能。

3.1. 生成 CNN

  文章使用了 CIFAR-10 数据集来生成 CNN 网络并验证其性能。使用了两层 RNN(每层 35 个隐藏单元)作为控制器,搜索空间包括卷积层、ReLU、batch norm、和 skip
connection,分别尝试了固定步长为 1 和预测步长( [1,2,3])两种步长方式。对于生成的网络,从训练集中随机抽取 5000 个样本作为验证集,剩下的 45000 个样本作为训练集,训练 50 个 epoch 后记录验证集准确率,据此更新控制器参数。控制器在初始阶段只生成较浅的网络(6 层),随着训练的进行(每 1600 个子网络样本),不断增加网络深度(深度 ×2)。

  实验一共训练了 12800 个网络,从中选择最佳验证准确率的网络后,通过网格搜索学习率等超参数,然后进行测试,结果如 Table 1 所示。可见 NAS 生成的网络可以达到很好的性能,其中 39 层的模型到达了最佳的 3.65% 错误率,速度是 DenseNet (L = 100; k = 24) 的 1.05 倍。

Table 1
Table 1

3.2. 生成 recurrent cell

  文章使用了 Penn Treebank 数据集来生成 recurrent cell 结构并验证其性能。使用了类似于 3.1 中的 RNN 控制器,搜索的操作包括 [add, elem_mult],搜索的激活函数包括 [identity, tanh, sigmoid, relu],base 为 8。生成 cell 结构后,用其构造一个两层的子网络并训练 35 个 epoch,并使用验证集的困惑度(perplexity)进行评价。

  类似于生成 CNN 的流程,实验选择了最低验证集困惑度(perplexity)的结构,通过网格搜索学习率等超参数,实验结果如 Table 2 所示。可见 NAS 生成的 cell 结构也都获得了很好的性能,最佳模型达到了 62.4 的困惑度,较之前的最佳性能 66 降低了 3.6;其中困惑度为 64 的网络速度是之前最佳网络的 2 倍。

Table 2
Table 2

  为了验证 NAS 生成的 cell 可以泛化到其他任务上,文章将其用于同一数据集的字符语言建模任务,结果如 Table 3 所示,可见新的 cell 结构可以泛化到其他任务,且性能优于 LSTM。

Table 3
Table 3