将神经网络自动化推向高速档

人工智能领域的一个新领域涉及使用算法自动设计称为神经网络的机器学习系统,这种系统比人工程师开发的更加准确和高效。但是这种所谓的神经结构搜索(NAS)技术在计算上是昂贵的。

最近由Google开发的最先进的NAS算法之一由一组图形处理单元(GPU)进行了48,000小时的工作,以生成单个卷积神经网络,用于图像分类和识别任务。谷歌拥有并行运行数百个GPU和其他专用电路的资金,但这对其他许多人来说是遥不可及的。

在5月份的国际学习代表会议上发表的一篇论文中,麻省理工学院的研究人员描述了一种NAS算法,可以直接学习目标硬件平台的专用卷积神经网络(CNN) - 当在大规模图像数据集上运行时 - 仅200 GPU小时数,可以更广泛地使用这些类型的算法。

研究人员表示,资源匮乏的研究人员和公司可以从节省时间和成本的算法中受益。共同作者,电子工程和计算机科学助理教授,麻省理工学院微系统技术实验室研究员宋汉表示,其主要目标是“实现人工智能民主化”。“我们希望通过在特定硬件上快速运行的按钮解决方案,使AI专家和非专家能够有效地设计神经网络架构。”

Han补充说,这种NAS算法永远不会取代人类工程师。“目的是减轻设计和改进神经网络架构所带来的重复性和繁琐的工作,”Han说,他的团队中的两位研究人员Han Cai和Ligeng Zhu参与了论文。

“路径级”二值化和修剪

在他们的工作中,研究人员开发了删除不必要的神经网络设计组件,缩短计算时间并仅使用一小部分硬件内存来运行NAS算法的方法。另一项创新确保每个输出的CNN在特定硬件平台(CPU,GPU和移动设备)上的运行效率高于传统方法设计的平台。在测试中,研究人员的CNN在移动电话上的测量速度比传统的具有相似精度的金标准模型快1.8倍。

CNN的架构由具有可调参数的计算层组成,称为“过滤器”,以及这些过滤器之间可能的连接。过滤处理正方形网格中的图像像素 - 例如3x3,5x5或7x7 - 每个滤镜覆盖一个正方形。滤镜基本上在图像上移动,并将其覆盖的像素网格的所有颜色组合成单个像素。不同的层可能具有不同大小的过滤器,并且以不同的方式连接以共享数据。输出是一个浓缩图像 - 来自所有过滤器的组合信息 - 可以通过计算机更容易地进行分析。

由于可供选择的可能架构的数量 - 称为“搜索空间” - 如此之大,因此应用NAS在海量图像数据集上创建神经网络在计算上是令人望而却步的。工程师通常在较小的代理数据集上运行NAS,并将他们学习的CNN架构转移到目标任务。然而,这种推广方法降低了模型的准确性。此外,相同的输出架构也适用于所有硬件平台,这导致效率问题。

研究人员在ImageNet数据集中的图像分类任务上训练和测试了他们的新NAS算法,该数据集包含数千个类中的数百万个图像。他们首先创建了一个搜索空间,其中包含所有可能的候选CNN“路径” - 表示层和过滤器如何连接以处理数据。这使得NAS算法可以自由地找到最佳架构。

这通常意味着所有可能的路径必须存储在内存中,这将超过GPU内存限制。为了解决这个问题,研究人员利用了一种称为“路径级二值化”的技术,该技术一次只存储一个采样路径,并节省了一个数量级的内存消耗。他们将这种二值化与“路径级修剪”相结合,这种技术传统上可以在不影响输出的情况下学习神经网络中的哪些“神经元”。然而,研究人员的NAS算法不是丢弃神经元,而是修剪整个路径,这完全改变了神经网络的架构。

在训练中,所有路径最初都被赋予相同的选择概率。该算法然后跟踪路径 - 一次只存储一个 - 以记录其输出的准确性和损失(分配给错误预测的数字惩罚)。然后,它调整路径的概率,以优化准确性和效率。最后,该算法修剪掉所有低概率路径并仅保留具有最高概率的路径 - 这是最终的CNN架构。

硬件识别

Han表示,另一项重要创新是使NAS算法“具有硬件感知能力”,这意味着它将每个硬件平台上的延迟用作优化架构的反馈信号。例如,为了衡量移动设备上的这种延迟,Google等大公司将使用移动设备的“农场”,这非常昂贵。研究人员建立了一个模型,只使用一部手机即可预测延迟。

对于网络的每个所选层,该算法对该延迟预测模型上的架构进行采样。然后,它使用该信息来设计尽可能快地运行的架构,同时实现高精度。在实验中,研究人员的CNN在移动设备上的运行速度几乎是金标准模型的两倍。

Han说,一个有趣的结果是,他们的NAS算法设计的CNN架构长期被认为效率太低 - 但是,在研究人员的测试中,它们实际上针对某些硬件进行了优化。例如,工程师基本上停止使用7x7滤波器,因为它们的计算成本比多个较小的滤波器贵。然而,研究人员的NAS算法发现,具有一些7x7滤波器层的架构在GPU上运行得最佳。这是因为GPU具有高并行性 - 意味着它们同时计算许多计算 - 因此可以比一次处理多个小型过滤器更有效地处理单个大型过滤器。