
Neural Architecture Transfer

Abstract

Neural architecture search (NAS) has emerged as a promising avenue for automatically designing task-specific neural networks. Existing NAS approaches require one complete search per hardware configuration or optimization objective, which is computationally impractical given the sheer diversity of possible deployment scenarios. This paper proposes Neural Architecture Transfer (NAT) to overcome this limitation. NAT is designed to efficiently generate task-specific custom models that remain competitive under multiple conflicting objectives. To achieve this, we learn task-specific supernets from which specialized subnets can be sampled without any additional training. The core of the approach is an integrated online transfer learning and multi-objective evolutionary search procedure: a pre-trained supernet is iteratively adapted while simultaneously searching for task-specific subnets. We demonstrate the efficacy of NAT on 11 benchmark image classification tasks, ranging from large-scale multi-class to small-scale fine-grained datasets. In all cases, including ImageNet, the models produced by NAT outperform state-of-the-art methods under mobile settings (≤600M multiply-add operations). Surprisingly, small-scale fine-grained datasets benefit the most from NAT. At the same time, architecture search and transfer is orders of magnitude more efficient than existing NAS methods. Overall, the experimental results show that, across diverse image classification tasks and computational constraints, NAT is an appreciably more effective alternative to conventional transfer learning, i.e., fine-tuning the weights of a network architecture pretrained on standard datasets. Code is available at https://github.com/human-analysis/neural-architecture-transfer
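The core procedure, alternating supernet weight adaptation with a multi-objective evolutionary search over subnets, can be illustrated with a minimal sketch. The snippet below is not the authors' implementation: the search space, the `accuracy_of` and `macs_of` proxies, and the plain random sampling standing in for NAT's guided crossover/mutation are all illustrative assumptions. It only shows how a Pareto archive of subnets is maintained under two conflicting objectives (accuracy vs. multiply-adds).

```python
# Conceptual sketch of a NAT-style outer loop (not the released code):
# alternate between (a) adapting supernet weights on the target task and
# (b) multi-objective search over subnet encodings sampled from the supernet.
import random

# Hypothetical per-stage search space (depth, expansion ratio, kernel size).
SEARCH_SPACE = {
    "depth":  [2, 3, 4],
    "width":  [3, 4, 6],
    "kernel": [3, 5, 7],
}

def random_subnet(num_stages=5):
    """Sample one subnet encoding from the supernet's search space."""
    return [{k: random.choice(v) for k, v in SEARCH_SPACE.items()}
            for _ in range(num_stages)]

def macs_of(subnet):
    """Toy proxy for multiply-adds: grows with depth, width and kernel size."""
    return sum(s["depth"] * s["width"] * s["kernel"] ** 2 for s in subnet)

def accuracy_of(subnet):
    """Placeholder: in NAT this would be the subnet's validation accuracy
    after inheriting weights from the (continually adapted) supernet."""
    return 50.0 + 0.05 * macs_of(subnet) + random.gauss(0, 0.5)

def dominates(a, b):
    """a dominates b if it is no worse in both objectives and strictly
    better in at least one (maximize accuracy, minimize MACs)."""
    return (a["acc"] >= b["acc"] and a["macs"] <= b["macs"]
            and (a["acc"] > b["acc"] or a["macs"] < b["macs"]))

def pareto_front(population):
    """Keep only the non-dominated candidates."""
    return [p for p in population
            if not any(dominates(q, p) for q in population if q is not p)]

archive = []
for iteration in range(10):
    # (b) Search step: random sampling here for brevity; NAT instead evolves
    # candidates with crossover/mutation guided by performance predictors.
    candidates = [random_subnet() for _ in range(20)]
    evaluated = [{"net": c, "acc": accuracy_of(c), "macs": macs_of(c)}
                 for c in candidates]
    archive = pareto_front(archive + evaluated)
    # (a) Supernet adaptation would happen here: fine-tune the shared weights
    # on the task using the currently promising subnets.

for p in sorted(archive, key=lambda q: q["macs"]):
    print(f"MACs proxy={p['macs']:4d}  accuracy proxy={p['acc']:.1f}")
```

In the actual method, step (a) fine-tunes the shared supernet weights on the target task using the promising subnets found so far, so later samples inherit increasingly task-adapted weights without being trained from scratch.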

Code Repositories

awesomelemon/encas (PyTorch; mentioned on GitHub)
human-analysis/neural-architecture-transfer (official; PyTorch; mentioned on GitHub)

Benchmarks

Results are listed per benchmark, method (NAT model), and reported metrics:
architecture-search-on-cifar-10-image
  NAT-M1 | FLOPS: 232M | Params: 4.3M | Percentage error: 2.6
  NAT-M2 | FLOPS: 291M | Params: 4.6M | Percentage error: 2.1
  NAT-M3 | FLOPS: 392M | Params: 6.2M | Percentage error: 1.8
  NAT-M4 | FLOPS: 468M | Params: 6.9M | Percentage error: 1.6

fine-grained-image-classification-on-fgvc
  NAT-M1 | Accuracy: 87.0% | FLOPS: 175M | Params: 3.2M
  NAT-M2 | Accuracy: 89.0% | FLOPS: 235M | Params: 3.4M
  NAT-M3 | Accuracy: 90.1% | FLOPS: 388M | Params: 5.1M
  NAT-M4 | Accuracy: 90.8% | FLOPS: 581M | Params: 5.3M

fine-grained-image-classification-on-food-101
  NAT-M1 | Accuracy: 87.4 | FLOPS: 198M | Params: 3.1M
  NAT-M2 | Accuracy: 88.5 | FLOPS: 266M | Params: 4.1M
  NAT-M3 | Accuracy: 89.0 | FLOPS: 299M | Params: 3.9M
  NAT-M4 | Accuracy: 89.4 | FLOPS: 361M | Params: 4.5M

fine-grained-image-classification-on-oxford
  NAT-M1 | FLOPS: 152M | Params: 3.3M
  NAT-M2 | Accuracy: 97.9 | FLOPS: 195M | Params: 3.4M
  NAT-M3 | Accuracy: 98.1 | FLOPS: 250M | Params: 3.7M
  NAT-M4 | Accuracy: 98.3 | FLOPS: 400M | Params: 4.2M

fine-grained-image-classification-on-oxford-1
  NAT-M1 | FLOPS: 160M | Params: 4.0M

fine-grained-image-classification-on-oxford-2
  NAT-M2 | Accuracy: 93.5 | FLOPS: 306M | Params: 5.5M | Top-1 Error Rate: 6.5%
  NAT-M3 | Accuracy: 94.1 | FLOPS: 471M | Params: 5.7M | Top-1 Error Rate: 5.9%
  NAT-M4 | Accuracy: 94.3 | FLOPS: 744M | Params: 8.5M | Top-1 Error Rate: 5.7%

fine-grained-image-classification-on-stanford
  NAT-M1 | Accuracy: 90.9% | FLOPS: 165M | Params: 2.4M
  NAT-M2 | Accuracy: 92.2% | FLOPS: 222M | Params: 2.7M
  NAT-M3 | Accuracy: 92.6% | FLOPS: 289M | Params: 3.5M
  NAT-M4 | Accuracy: 92.9% | FLOPS: 369M | Params: 3.7M

image-classification-on-cifar-10
  NAT-M1 | Params: 4.3M | Percentage correct: 97.4 | Top-1 Accuracy: 97.4
  NAT-M2 | Params: 4.6M | Percentage correct: 97.9 | Top-1 Accuracy: 97.9
  NAT-M3 | Params: 6.2M | Percentage correct: 98.2 | Top-1 Accuracy: 98.2
  NAT-M4 | Params: 6.9M | Percentage correct: 98.4 | Top-1 Accuracy: 98.4

image-classification-on-cifar-100
  NAT-M1 | Params: 3.8M | Percentage correct: 86.0
  NAT-M2 | Params: 6.4M | Percentage correct: 87.5
  NAT-M3 | Params: 7.8M | Percentage correct: 87.7
  NAT-M4 | Params: 9.0M | Percentage correct: 88.3

image-classification-on-cinic-10
  NAT-M1 | Accuracy: 93.4 | FLOPS: 317M | Params: 4.6M
  NAT-M2 | Accuracy: 94.1 | FLOPS: 411M | Params: 6.2M
  NAT-M3 | Accuracy: 94.3 | FLOPS: 501M | Params: 8.1M

image-classification-on-flowers-102
  NAT-M1 | FLOPS: 152M | Params: 3.3M
  NAT-M2 | Accuracy: 97.9% | FLOPS: 195M | Params: 3.4M
  NAT-M3 | Accuracy: 98.1% | FLOPS: 250M | Params: 3.7M
  NAT-M4 | Accuracy: 98.3% | FLOPS: 400M | Params: 4.2M

image-classification-on-imagenet
  NAT-M4 | Params: 9.1M | Top-1 Accuracy: 80.5%

image-classification-on-stl-10
  NAT-M1 | FLOPS: 240M | Params: 4.4M | Percentage correct: 96.7
  NAT-M2 | FLOPS: 303M | Params: 5.1M | Percentage correct: 97.2
  NAT-M3 | FLOPS: 436M | Params: 7.5M | Percentage correct: 97.8
  NAT-M4 | FLOPS: 573M | Params: 7.5M | Percentage correct: 97.9

neural-architecture-search-on-cifar-10
  NAT-M1 | FLOPS: 232M | Params: 4.3M | Search Time (GPU days): 1.0 | Top-1 Error Rate: 2.6%
  NAT-M2 | FLOPS: 291M | Params: 4.6M | Search Time (GPU days): 1.0 | Top-1 Error Rate: 2.1%
  NAT-M3 | FLOPS: 392M | Params: 6.2M | Search Time (GPU days): 1.0 | Top-1 Error Rate: 1.8%
  NAT-M4 | FLOPS: 468M | Params: 6.9M | Search Time (GPU days): 1.0 | Top-1 Error Rate: 1.6%

neural-architecture-search-on-cifar-100-1
  NAT-M1 | FLOPS: 261M | Params: 3.8M | Percentage Error: 14.0
  NAT-M2 | FLOPS: 398M | Params: 6.4M | Percentage Error: 12.5
  NAT-M3 | FLOPS: 492M | Params: 7.8M | Percentage Error: 12.3
  NAT-M4 | FLOPS: 796M | Params: 9.0M | Percentage Error: 11.7

neural-architecture-search-on-cinic-10
  NAT-M1 | Accuracy (%): 93.4 | FLOPS: 317M | Params: 4.6M
  NAT-M2 | Accuracy (%): 94.1 | FLOPS: 411M | Params: 6.2M
  NAT-M3 | Accuracy (%): 94.3 | FLOPS: 501M | Params: 8.1M
  NAT-M4 | Accuracy (%): 94.8 | FLOPS: 710M | Params: 9.1M

neural-architecture-search-on-dtd
  NAT-M1 | Accuracy (%): 76.1 | FLOPS: 136M | Params: 2.2M
  NAT-M2 | Accuracy (%): 77.6 | FLOPS: 297M | Params: 4.0M
  NAT-M3 | Accuracy (%): 78.4 | FLOPS: 347M | Params: 4.1M
  NAT-M4 | Accuracy (%): 79.1 | FLOPS: 560M | Params: 6.3M

neural-architecture-search-on-fgvc-aircraft
  NAT-M1 | Accuracy (%): 87.0 | FLOPS: 175M | Params: 3.2M
  NAT-M2 | Accuracy (%): 89.0 | FLOPS: 235M | Params: 3.4M
  NAT-M3 | Accuracy (%): 90.1 | FLOPS: 388M | Params: 5.1M
  NAT-M4 | Accuracy (%): 90.8 | FLOPS: 581M | Params: 5.3M

neural-architecture-search-on-food-101
  NAT-M1 | Accuracy (%): 87.4 | FLOPS: 198M | Params: 3.1M
  NAT-M2 | Accuracy (%): 88.5 | FLOPS: 266M | Params: 4.1M
  NAT-M3 | Accuracy (%): 89.0 | FLOPS: 299M | Params: 3.9M
  NAT-M4 | Accuracy (%): 89.4 | FLOPS: 361M | Params: 4.5M

neural-architecture-search-on-imagenet
  NAT-M1 | Accuracy: 77.5 | MACs: 225M | Params: 6.0M | Top-1 Error Rate: 22.5
  NAT-M2 | Accuracy: 78.6 | MACs: 312M | Params: 7.7M | Top-1 Error Rate: 21.4
  NAT-M3 | Accuracy: 79.9 | MACs: 490M | Params: 9.1M | Top-1 Error Rate: 20.1
  NAT-M4 | Accuracy: 80.5 | MACs: 600M | Params: 9.1M | Top-1 Error Rate: 19.5

neural-architecture-search-on-oxford-102
  NAT-M1 | Accuracy (%): 97.5 | FLOPS: 152M | Params: 3.3M
  NAT-M2 | Accuracy (%): 97.9 | FLOPS: 195M | Params: 3.4M
  NAT-M3 | Accuracy (%): 98.1 | FLOPS: 250M | Params: 3.7M
  NAT-M4 | Accuracy (%): 98.3 | FLOPS: 400M | Params: 4.2M

neural-architecture-search-on-oxford-iiit
  NAT-M1 | Accuracy (%): 91.8 | FLOPS: 160M | Params: 4.0M
  NAT-M2 | Accuracy (%): 93.5 | FLOPS: 306M | Params: 5.5M
  NAT-M3 | Accuracy (%): 94.1 | FLOPS: 471M | Params: 5.7M
  NAT-M4 | Accuracy (%): 94.3 | FLOPS: 744M | Params: 8.5M

neural-architecture-search-on-stanford-cars
  NAT-M1 | Accuracy (%): 90.0 | FLOPS: 165M | Params: 2.4M
  NAT-M2 | Accuracy (%): 92.2 | FLOPS: 222M | Params: 2.7M
  NAT-M3 | Accuracy (%): 92.6 | FLOPS: 289M | Params: 3.5M
  NAT-M4 | Accuracy (%): 92.9 | FLOPS: 369M | Params: 3.7M

neural-architecture-search-on-stl-10
  NAT-M1 | Accuracy (%): 96.7 | FLOPS: 240M | Params: 4.4M
  NAT-M2 | Accuracy (%): 97.2 | FLOPS: 303M | Params: 5.1M
  NAT-M3 | Accuracy (%): 97.8 | FLOPS: 436M | Params: 7.5M
  NAT-M4 | Accuracy (%): 97.9 | FLOPS: 573M | Params: 7.5M
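Because NAT reports a family of models (NAT-M1 through NAT-M4) spanning the accuracy/compute trade-off, deployment reduces to picking the most accurate reported model that fits a compute budget. The helper below is not part of the released code; it simply hard-codes the ImageNet numbers from the table above in a hypothetical `pick_model` utility to show that selection.

```python
# Reported NAT ImageNet results (accuracy in %, MACs in millions), taken from
# the neural-architecture-search-on-imagenet rows above.
NAT_IMAGENET = [
    {"model": "NAT-M1", "macs_m": 225, "top1": 77.5},
    {"model": "NAT-M2", "macs_m": 312, "top1": 78.6},
    {"model": "NAT-M3", "macs_m": 490, "top1": 79.9},
    {"model": "NAT-M4", "macs_m": 600, "top1": 80.5},
]

def pick_model(budget_macs_m):
    """Return the highest-accuracy reported model within the MACs budget."""
    feasible = [m for m in NAT_IMAGENET if m["macs_m"] <= budget_macs_m]
    return max(feasible, key=lambda m: m["top1"]) if feasible else None

# Example: a 350M multiply-add budget selects NAT-M2 (78.6% top-1).
print(pick_model(350))
```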
