Deep Learning Note: 3-5 不匹配的训练集和开发/测试集

1. 使用具有不同分布的训练集和测试集

  深度学习算法需要大量的数据,有时候不得不从各种途径收集尽可能多的数据用作训练,导致训练集和开发/测试集具有不同的分布。

  以从用户上传的图片中识别猫的应用为例,可以从两种途径获得训练数据:其一是使用应用用户实际上传猫的图片,这些图片通常分辨率不高,拍摄质量也不佳,这是我们真正关心的数据,即我们希望应用能在这些数据上有较好的性能,但这些数据的数量较少,比如有 1 万个样本;其二是从网上收集猫的图片,在网上排名较高的图片往往具有很高分辨率,由专业摄影师拍摄,质量也很高,且数量庞大,比如收集到了 20 万个样本。现在的问题是,要如何使用这些数据。

  一种做法是将从这两种途径获得的数据随机地混合在一起,得到 21 万个样本,划分为训练集(205000 个样本)、开发集和测试集(各 2500 个样本)。这种做法虽然有训练集、开发集和测试集都具有相同的分布的优点,但一个巨大的缺点是,现在开发集中有很大一部分(20万/21万)数据是从网络上收集的,而我们真正关心的是应用用户实际上传的图片。这样的开发集不能反映一个算法的真实性能,即通过开发集选择的是在网络图片上具有最好的性能算法,而不是在用户上传的图片上具有最好性能的算法。因此不推荐这种方法。

  另一种做法是使用全部 20 万个网络图片,再加上 5000 个用户上传图片作为训练集(205000 个样本),使用剩下的 5000 个用户上传的图片作为开发集和测试集(各 2500 个样本)。现在开发集中包含两个来源的数据,而开发集和测试集的数据都来自用户上传图片。这种做法的优点是开发集的分布与我们实际关心的应用场景相同,通过开发集可以选出在实际场景中具有最佳性能的算法。而这种做法的缺点是,现在训练集具有与开发/测试集不同的分布,但在长期上这种做法的效果更好。

  接受具有不同分布的训练集能够大大增加可以使用的数据量,从而提高算法性能,但这样做有时也会带来问题。

2. 不同数据分布下的偏差和方差分析

  通过偏差和方差分析,我们可以定位算法中存在的问题,从而找到进一步提升算法性能的方向。当训练集与开发/测试集具有不同的分布时,偏差和方差分析的方法会有所变化。

  还是以用户上传的图片中识别猫的分类器为例,假设有:

  • 人类错误率:接近 0%
  • 训练集错误率:1%
  • 开发集错误率:10%

  从上面的数据可以看出,开发集错误率和训练集错误率之间的差距较大,根据前文的分析,可以判断该分类器具有高方差问题。以上结论是在训练集和开发/测试集具有相同分布的情况下得到的,当训练集和开发/测试集具有不同的分布时,就不能这样判断。例如使用高质量的网络图片作为训练集,使用用户上传的图片作为开发集,此时训练集的数据质量很高,算法容易得到较好的性能,训练集错误率很低;而开发集数据质量低,对算法来说更加困难,得到的开发集错误率较高。在这种情况下,我们不知道 10% 的开发集错误率中,有多少是因为数据更困难引起的,而又有多少是因为高方差问题引起的,此时算法并不一定有高方差问题,10% 的开发集错误率可能是可以接受的。

  为了厘清训练集和开发集具有不同分布时,开发集错误率的来源,可以建立一个训练-开发集(Training-Dev Set),其中的数据具有和训练集相同的分布,但没有用于训练。即随机从训练集中抽取一小部分数据,不用做训练,作为训练-开发集。就像开发集和测试集具有相同的分布,训练集和训练-开发集具有相同的分布。假设有:

  • 人类错误率:接近 0%
  • 训练集错误率:1%
  • 训练-开发集错误率:9%
  • 开发集错误率:10%

  从上面的数据可以看出,训练-开发集错误率和训练集错误率之间的差距较大。训练-开发集和训练集之间的区别是,没有使用训练-开发集进行训练。算法在训练集上的性能不能很好地泛化到训练-开发集上,可以判断算法存在高方差问题。

  第二个例子,假设有:

  • 人类错误率:接近 0%
  • 训练集错误率:1%
  • 训练-开发集错误率:1.5%
  • 开发集错误率:10%

  从上面的数据可以看出,训练-开发集错误率和训练集错误率之间的差距很小,说明方差问题很小;而开发集错误率和训练-开发集错误率之间的差距很大,开发集和训练-开发集之间的区别是,二者的分布不同,由此可以判断大部分的错误是因为数据分布不同引起的,存在数据不匹配(Data Mismatch)的问题。

  第三个例子,假设有:

  • 人类错误率:接近 0%
  • 训练集错误率:10%
  • 训练-开发集错误率:11%
  • 开发集错误率:12%

  此时训练集错误率和人类错误率之间有很大差距,说明算法有很大的可避免的偏差,存在高偏差问题。

  最后一个例子,假设有:

  • 人类错误率:接近 0%
  • 训练集错误率:10%
  • 训练-开发集错误率:11%
  • 开发集错误率:20%

  此时训练集错误率和人类错误率之间有很大差距,存在高偏差问题;训练-开发集错误率和训练集错误率之间相差不大,方差问题很小;开发集错误率和训练-开发集错误率之间有很大差距,说明存在数据不匹配的问题。

  综上所述,对于具有不同分布的训练集和开发/测试集,分析偏差和方差时要考察以下指标:

  • 人类水平
  • 训练集错误
  • 开发-训练集错误
  • 开发集错误
  • (测试集错误)

  训练集错误和人类水平之间的差距表示可避免的偏差,开发-训练集错误和训练集错误之间的差距表示方差,开发集错误和开发-训练集错误之间的差距表示由数据不匹配带来的错误,测试集错误和开发集错误之间的差距表示了过拟合于开发集的程度。

  通常情况下,以上五个指标会按照从上到下的顺序依次递增,但有时候也会出现特例,如:

  • 人类错误率:4%
  • 训练集错误:7%
  • 开发-训练集错误:10%
  • 开发集错误:6%
  • (测试集错误):6%

  此时开发/测试集错误低于训练集错误率和开发-训练集错误率,其原因可能是训练集和开发-训练集的数据要比开发集中的数据更难。

  例如对于某个通过语音识别查询导航路线的应用,对着应用说出导航目的地(如“我要去西直门”),应用就会识别查询并给出路线。我们能收集到大量的通用语音识别的数据,用作训练集和开发-测试集;另外收集到少量导航查询语音识别的数据,用作开发集和测试集。此时由前面的数据可以得到如表 1 所示的表格:

表 1

通用语音识别数据 导航查询语音识别数据
人类水平 人类错误率:4%
训练过的数据上的错误率 训练集错误:7%
未训练过的数据上的错误率 训练-开发集错误:10% 开发/测试集错误:6%

  虽然有时候使用表 1 的数据就足以分析问题,但通过把表 1 中空白的两格填上,我们可以对问题有更多的洞察。例如通过邀请他人手动标记导航查询语音识别数据,以及使用导航查询语音识别数据进行训练,可以得到导航查询语音识别数据上的人类水平和训练过的数据上的错误率,假设均为 6%,如表 2 所示:

表 2

通用语音识别数据 导航查询语音识别数据
人类水平 人类错误率:4% 6%
训练过的数据上的错误率 训练集错误:7% 6%
未训练过的数据上的错误率 训练-开发集错误:10% 开发/测试集错误:6%

  由表 2 可见,虽然在通用语音识别数据上,该算法的性能与人类错误率(4%)还有差距,但在导航查询语音识别数据上,算法性能已经达到人类水平(6%),算法在导航查询语音识别问题上的性能已经很好了。通过比较人类水平一栏,可以知道对于人类来说,导航查询语音识别比通用语音识别更难。

3. 解决数据不匹配

  使用具有不同分布的数据作为训练集可以极大地扩充训练集的容量,但此时除了潜在的偏差和方差问题,还可能会引入的数据不匹配的问题,数据不匹配可能会带来很高的错误率。对于数据不匹配的问题,虽然目前并没有非常系统的解决办法,但仍可以尝试以下方法:

  通过人工错误分析,尝试理解训练集和开发/测试集之间的区别。例如前面通过语音识别查询导航路线的应用,使用通用语音识别数据作为训练集,使用导航查询语音识别数据作为开发集,通过检查发集和训练集中的样本,可能会发现由于开发集的数据来自实际场景,样本音频的背景中有很大噪声,如汽车引擎等声音,导致对开发集中样本的识别更加困难。有时还会发现其它的一些有用信息,比如发现开发集中道路名称出现得较为频繁,算法能否准确识别道路名称可能会对其性能有较大影响,需要重视对道路名称的学习。

  在理解了训练集和开发/测试集之间的差异之后,可以让训练集中的数据更加一致,或收集更多与开发/测试集相似的数据。例如发现大量的错误来自背景中的汽车噪音,那么可以通过模拟噪音生成更多带噪音的数据;如果发现算法难以正确识别道路名称,那么可以去收集更多带道路名称的数据,加入到训练集。
  
  如果希望训练集中的数据更加一致,可以通过人工数据合成(Artificial Data Synthesis)的方法,生成更多数据。例如我们想要获取带汽车噪声的语音数据用作训练,但这种特定场合的数据并不容易直接获取。通过收集安静背景下的语音识别数据和汽车噪音,将二者混合起来,就可以获得带汽车噪声的语音数据。

  在进行人工数据合成时需要注意的一点是,如果有 10000 小时的安静背景下的语音识别数据和 1 小时的汽车噪音,进行合成的一种方式是将 1 小时的汽车噪音重复 10000 遍,得到 10000 小时的汽车噪音,与 10000 小时的语音数据混合。这样做的结果在人听来没什么问题,但对算法来说,存在着过拟合于这 1 小时汽车噪音的风险,因为这 1 小时的汽车噪音可能只包含众多汽车噪音场景中的一小部分。如果能够收集 10000 小时的汽车噪音,并与 10000 小时的语音数据混合,可能会得到更好的性能。以上问题也是人工数据合成的一个挑战,对人类来说,1 小时的汽车噪音与 10000 小时的汽车噪音没什么区别,如果仅凭直觉使用 1 小时的汽车噪音进行合成,得到的数据可能就无法反应真实场景的多样性。

  再举一例,假设要训练一个从图片中识别汽车的应用,很多人都会建议通过 CG(Computer Graphics)生成汽车的图像,现代 CG 技术能够轻松地生成以假乱真的汽车图像,这些图像在人看来没有什么问题,但生成的汽车和场景的集合可能只是真实汽车和场景集合的一小部分,算法仍存在着过拟合于生成的数据集的风险。例如从竞速类电子游戏中获取汽车图片,游戏中有 20 种车型,这个游戏在人看来是十分丰富和真实的,但现实世界中的汽车类型远远多于 20 种,如果训练集中只有 20 种车,那么算法就可能过拟合于这 20 种车,而仅凭人的直觉很难发现这一问题。

  综上所述,如果发现存在数据不匹配的问题,可以通过错误分析找出训练集和开发/测试集的差别,然后尝试收集更多与开发/测试集类似的数据用于训练。获取数据的一种有效方法是使用人工数据合成,但要考虑周全,不要只模拟了现实情况的一小部分。