
Abstract
Neural Architecture Search (NAS) has emerged as a promising avenue for automatically designing task-specific neural networks. Existing NAS approaches require one complete search for every hardware configuration or optimization objective, which is computationally impractical given how diverse application scenarios can be. In this paper, we propose Neural Architecture Transfer (NAT) to overcome this limitation. NAT is designed to efficiently generate task-specific custom models that remain competitive under multiple conflicting objectives. To realize this goal, we learn task-specific supernets from which specialized subnets can be sampled without any additional training. The key to our approach is an integrated procedure combining online transfer learning with many-objective evolutionary search: a pre-trained supernet is iteratively adapted while the search for task-specific subnets proceeds simultaneously. We demonstrate the efficacy of NAT on 11 benchmark image classification tasks, ranging from large-scale multi-class to small-scale fine-grained datasets. In all cases, including ImageNet, the models produced by NAT improve upon state-of-the-art methods under mobile settings (≤ 600M multiply-add operations). Surprisingly, small-scale fine-grained datasets benefit the most from NAT. At the same time, the architecture search and transfer procedure is orders of magnitude more efficient than existing NAS methods. Overall, the experimental results indicate that, across diverse image classification tasks and computational constraints, NAT is an appreciably more effective alternative to conventional transfer learning, i.e., fine-tuning the weights of a network architecture pre-trained on standard datasets. Code is available at https://github.com/human-analysis/neural-architecture-transfer
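The procedure described above alternates supernet adaptation with many-objective evolutionary selection of subnets that trade accuracy against compute. Below is a minimal, self-contained sketch of the evolutionary half of that loop. The encoding, the toy surrogate objectives, and all function names are illustrative assumptions, not the paper's actual predictor-guided evaluation:

```python
import random

def evaluate(subnet):
    """Toy surrogate objectives: wider/deeper choices raise both a
    stand-in accuracy score (maximize) and FLOPs (minimize)."""
    flops = sum(subnet)
    accuracy = sum(x ** 0.5 for x in subnet)
    return accuracy, flops

def dominates(a, b):
    """True if fitness a Pareto-dominates b: no worse on both
    objectives (higher accuracy, lower FLOPs), strictly better on one."""
    acc_a, fl_a = a
    acc_b, fl_b = b
    return acc_a >= acc_b and fl_a <= fl_b and (acc_a > acc_b or fl_a < fl_b)

def pareto_front(population):
    """Return the non-dominated members of the population."""
    scored = [(s, evaluate(s)) for s in population]
    return [s for s, f in scored
            if not any(dominates(g, f) for _, g in scored)]

def evolve(num_blocks=5, choices=(1, 2, 3, 4),
           pop_size=20, generations=30, seed=0):
    """Evolve subnet encodings (one choice per block) toward the
    accuracy/FLOPs Pareto front via mutation of surviving members."""
    rng = random.Random(seed)
    pop = [[rng.choice(choices) for _ in range(num_blocks)]
           for _ in range(pop_size)]
    for _ in range(generations):
        front = pareto_front(pop)
        children = []
        while len(front) + len(children) < pop_size:
            child = list(rng.choice(front))
            child[rng.randrange(num_blocks)] = rng.choice(choices)
            children.append(child)
        pop = front + children
    return pareto_front(pop)
```

In NAT itself, the evaluation step would query the adapted supernet (or an accuracy predictor) rather than a closed-form surrogate, and the supernet weights would be fine-tuned between generations; this sketch only shows how a Pareto set of architectures can be maintained across multiple conflicting objectives.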
Code Repositories
- awesomelemon/encas (pytorch, mentioned in GitHub)
- human-analysis/neural-architecture-transfer (official, pytorch, mentioned in GitHub)
Benchmarks
| Benchmark | Method | Metrics |
|---|---|---|
| architecture-search-on-cifar-10-image | NAT-M2 | FLOPS: 291M Params: 4.6M Percentage error: 2.1 |
| architecture-search-on-cifar-10-image | NAT-M1 | FLOPS: 232M Params: 4.3M Percentage error: 2.6 |
| architecture-search-on-cifar-10-image | NAT-M3 | FLOPS: 392M Params: 6.2M Percentage error: 1.8 |
| architecture-search-on-cifar-10-image | NAT-M4 | FLOPS: 468M Params: 6.9M Percentage error: 1.6 |
| fine-grained-image-classification-on-fgvc | NAT-M2 | Accuracy: 89.0% FLOPS: 235M PARAMS: 3.4M |
| fine-grained-image-classification-on-fgvc | NAT-M3 | Accuracy: 90.1% FLOPS: 388M PARAMS: 5.1M |
| fine-grained-image-classification-on-fgvc | NAT-M1 | Accuracy: 87.0% FLOPS: 175M PARAMS: 3.2M |
| fine-grained-image-classification-on-fgvc | NAT-M4 | Accuracy: 90.8% FLOPS: 581M PARAMS: 5.3M |
| fine-grained-image-classification-on-food-101 | NAT-M4 | Accuracy: 89.4 FLOPS: 361M PARAMS: 4.5M |
| fine-grained-image-classification-on-food-101 | NAT-M1 | Accuracy: 87.4 FLOPS: 198M PARAMS: 3.1M |
| fine-grained-image-classification-on-food-101 | NAT-M2 | Accuracy: 88.5 FLOPS: 266M PARAMS: 4.1M |
| fine-grained-image-classification-on-food-101 | NAT-M3 | Accuracy: 89.0 FLOPS: 299M PARAMS: 3.9M |
| fine-grained-image-classification-on-oxford | NAT-M3 | Accuracy: 98.1 FLOPS: 250M PARAMS: 3.7M |
| fine-grained-image-classification-on-oxford | NAT-M1 | FLOPS: 152M PARAMS: 3.3M |
| fine-grained-image-classification-on-oxford | NAT-M2 | Accuracy: 97.9 FLOPS: 195M PARAMS: 3.4M |
| fine-grained-image-classification-on-oxford | NAT-M4 | Accuracy: 98.3 FLOPS: 400M PARAMS: 4.2M |
| fine-grained-image-classification-on-oxford-1 | NAT-M1 | FLOPS: 160M PARAMS: 4.0M |
| fine-grained-image-classification-on-oxford-2 | NAT-M3 | Accuracy: 94.1 FLOPS: 471M PARAMS: 5.7M Top-1 Error Rate: 5.9% |
| fine-grained-image-classification-on-oxford-2 | NAT-M2 | Accuracy: 93.5 FLOPS: 306M PARAMS: 5.5M Top-1 Error Rate: 6.5% |
| fine-grained-image-classification-on-oxford-2 | NAT-M4 | Accuracy: 94.3 FLOPS: 744M PARAMS: 8.5M Top-1 Error Rate: 5.7% |
| fine-grained-image-classification-on-stanford | NAT-M4 | Accuracy: 92.9% FLOPS: 369M PARAMS: 3.7M |
| fine-grained-image-classification-on-stanford | NAT-M2 | Accuracy: 92.2% FLOPS: 222M PARAMS: 2.7M |
| fine-grained-image-classification-on-stanford | NAT-M1 | Accuracy: 90.9% FLOPS: 165M PARAMS: 2.4M |
| fine-grained-image-classification-on-stanford | NAT-M3 | Accuracy: 92.6% FLOPS: 289M PARAMS: 3.5M |
| image-classification-on-cifar-10 | NAT-M3 | Parameters: 6.2M Percentage correct: 98.2 Top-1 Accuracy: 98.2 |
| image-classification-on-cifar-10 | NAT-M4 | Parameters: 6.9M Percentage correct: 98.4 Top-1 Accuracy: 98.4 |
| image-classification-on-cifar-10 | NAT-M2 | Parameters: 4.6M Percentage correct: 97.9 Top-1 Accuracy: 97.9 |
| image-classification-on-cifar-10 | NAT-M1 | Parameters: 4.3M Percentage correct: 97.4 Top-1 Accuracy: 97.4 |
| image-classification-on-cifar-100 | NAT-M1 | PARAMS: 3.8M Percentage correct: 86.0 |
| image-classification-on-cifar-100 | NAT-M3 | PARAMS: 7.8M Percentage correct: 87.7 |
| image-classification-on-cifar-100 | NAT-M4 | PARAMS: 9.0M Percentage correct: 88.3 |
| image-classification-on-cifar-100 | NAT-M2 | PARAMS: 6.4M Percentage correct: 87.5 |
| image-classification-on-cinic-10 | NAT-M2 | Accuracy: 94.1 FLOPS: 411M PARAMS: 6.2M |
| image-classification-on-cinic-10 | NAT-M1 | Accuracy: 93.4 FLOPS: 317M PARAMS: 4.6M |
| image-classification-on-cinic-10 | NAT-M3 | Accuracy: 94.3 FLOPS: 501M PARAMS: 8.1M |
| image-classification-on-flowers-102 | NAT-M1 | FLOPS: 152M PARAMS: 3.3M |
| image-classification-on-flowers-102 | NAT-M3 | Accuracy: 98.1% FLOPS: 250M PARAMS: 3.7M |
| image-classification-on-flowers-102 | NAT-M4 | Accuracy: 98.3% FLOPS: 400M PARAMS: 4.2M |
| image-classification-on-flowers-102 | NAT-M2 | Accuracy: 97.9% FLOPS: 195M PARAMS: 3.4M |
| image-classification-on-imagenet | NAT-M4 | Number of params: 9.1M Top-1 Accuracy: 80.5% |
| image-classification-on-stl-10 | NAT-M3 | FLOPS: 436M PARAMS: 7.5M Percentage correct: 97.8 |
| image-classification-on-stl-10 | NAT-M1 | FLOPS: 240M PARAMS: 4.4M Percentage correct: 96.7 |
| image-classification-on-stl-10 | NAT-M4 | FLOPS: 573M PARAMS: 7.5M Percentage correct: 97.9 |
| image-classification-on-stl-10 | NAT-M2 | FLOPS: 303M PARAMS: 5.1M Percentage correct: 97.2 |
| neural-architecture-search-on-cifar-10 | NAT-M1 | FLOPS: 232M Parameters: 4.3M Search Time (GPU days): 1.0 Top-1 Error Rate: 2.6% |
| neural-architecture-search-on-cifar-10 | NAT-M4 | FLOPS: 468M Parameters: 6.9M Search Time (GPU days): 1.0 Top-1 Error Rate: 1.6% |
| neural-architecture-search-on-cifar-10 | NAT-M2 | FLOPS: 291M Parameters: 4.6M Search Time (GPU days): 1.0 Top-1 Error Rate: 2.1% |
| neural-architecture-search-on-cifar-10 | NAT-M3 | FLOPS: 392M Parameters: 6.2M Search Time (GPU days): 1.0 Top-1 Error Rate: 1.8% |
| neural-architecture-search-on-cifar-100-1 | NAT-M1 | FLOPS: 261M PARAMS: 3.8M Percentage Error: 14.0 |
| neural-architecture-search-on-cifar-100-1 | NAT-M2 | FLOPS: 398M PARAMS: 6.4M Percentage Error: 12.5 |
| neural-architecture-search-on-cifar-100-1 | NAT-M3 | FLOPS: 492M PARAMS: 7.8M Percentage Error: 12.3 |
| neural-architecture-search-on-cifar-100-1 | NAT-M4 | FLOPS: 796M PARAMS: 9.0M Percentage Error: 11.7 |
| neural-architecture-search-on-cinic-10 | NAT-M2 | Accuracy (%): 94.1 FLOPS: 411M PARAMS: 6.2M |
| neural-architecture-search-on-cinic-10 | NAT-M3 | Accuracy (%): 94.3 FLOPS: 501M PARAMS: 8.1M |
| neural-architecture-search-on-cinic-10 | NAT-M1 | Accuracy (%): 93.4 FLOPS: 317M PARAMS: 4.6M |
| neural-architecture-search-on-cinic-10 | NAT-M4 | Accuracy (%): 94.8 FLOPS: 710M PARAMS: 9.1M |
| neural-architecture-search-on-dtd | NAT-M4 | Accuracy (%): 79.1 FLOPS: 560M PARAMS: 6.3M |
| neural-architecture-search-on-dtd | NAT-M1 | Accuracy (%): 76.1 FLOPS: 136M PARAMS: 2.2M |
| neural-architecture-search-on-dtd | NAT-M2 | Accuracy (%): 77.6 FLOPS: 297M PARAMS: 4.0M |
| neural-architecture-search-on-dtd | NAT-M3 | Accuracy (%): 78.4 FLOPS: 347M PARAMS: 4.1M |
| neural-architecture-search-on-fgvc-aircraft | NAT-M1 | Accuracy (%): 87.0 FLOPS: 175M PARAMS: 3.2M |
| neural-architecture-search-on-fgvc-aircraft | NAT-M2 | Accuracy (%): 89.0 FLOPS: 235M PARAMS: 3.4M |
| neural-architecture-search-on-fgvc-aircraft | NAT-M3 | Accuracy (%): 90.1 FLOPS: 388M PARAMS: 5.1M |
| neural-architecture-search-on-fgvc-aircraft | NAT-M4 | Accuracy (%): 90.8 FLOPS: 581M PARAMS: 5.3M |
| neural-architecture-search-on-food-101 | NAT-M2 | Accuracy (%): 88.5 FLOPS: 266M PARAMS: 4.1M |
| neural-architecture-search-on-food-101 | NAT-M1 | Accuracy (%): 87.4 FLOPS: 198M PARAMS: 3.1M |
| neural-architecture-search-on-food-101 | NAT-M4 | Accuracy (%): 89.4 FLOPS: 361M PARAMS: 4.5M |
| neural-architecture-search-on-food-101 | NAT-M3 | Accuracy (%): 89.0 FLOPS: 299M PARAMS: 3.9M |
| neural-architecture-search-on-imagenet | NAT-M4 | Accuracy: 80.5 MACs: 600M Params: 9.1M Top-1 Error Rate: 19.5 |
| neural-architecture-search-on-imagenet | NAT-M3 | Accuracy: 79.9 MACs: 490M Params: 9.1M Top-1 Error Rate: 20.1 |
| neural-architecture-search-on-imagenet | NAT-M1 | Accuracy: 77.5 MACs: 225M Params: 6.0M Top-1 Error Rate: 22.5 |
| neural-architecture-search-on-imagenet | NAT-M2 | Accuracy: 78.6 MACs: 312M Params: 7.7M Top-1 Error Rate: 21.4 |
| neural-architecture-search-on-oxford-102 | NAT-M1 | Accuracy (%): 97.5 FLOPS: 152M PARAMS: 3.3M |
| neural-architecture-search-on-oxford-102 | NAT-M2 | Accuracy (%): 97.9 FLOPS: 195M PARAMS: 3.4M |
| neural-architecture-search-on-oxford-102 | NAT-M3 | Accuracy (%): 98.1 FLOPS: 250M PARAMS: 3.7M |
| neural-architecture-search-on-oxford-102 | NAT-M4 | Accuracy (%): 98.3 FLOPS: 400M PARAMS: 4.2M |
| neural-architecture-search-on-oxford-iiit | NAT-M2 | Accuracy (%): 93.5 FLOPS: 306M PARAMS: 5.5M |
| neural-architecture-search-on-oxford-iiit | NAT-M4 | Accuracy (%): 94.3 FLOPS: 744M PARAMS: 8.5M |
| neural-architecture-search-on-oxford-iiit | NAT-M3 | Accuracy (%): 94.1 FLOPS: 471M PARAMS: 5.7M |
| neural-architecture-search-on-oxford-iiit | NAT-M1 | Accuracy (%): 91.8 FLOPS: 160M PARAMS: 4.0M |
| neural-architecture-search-on-stanford-cars | NAT-M3 | Accuracy (%): 92.6 FLOPS: 289M PARAMS: 3.5M |
| neural-architecture-search-on-stanford-cars | NAT-M4 | Accuracy (%): 92.9 FLOPS: 369M PARAMS: 3.7M |
| neural-architecture-search-on-stanford-cars | NAT-M1 | Accuracy (%): 90.0 FLOPS: 165M PARAMS: 2.4M |
| neural-architecture-search-on-stanford-cars | NAT-M2 | Accuracy (%): 92.2 FLOPS: 222M PARAMS: 2.7M |
| neural-architecture-search-on-stl-10 | NAT-M1 | Accuracy (%): 96.7 FLOPS: 240M PARAMS: 4.4M |
| neural-architecture-search-on-stl-10 | NAT-M4 | Accuracy (%): 97.9 FLOPS: 573M PARAMS: 7.5M |
| neural-architecture-search-on-stl-10 | NAT-M2 | Accuracy (%): 97.2 FLOPS: 303M PARAMS: 5.1M |
| neural-architecture-search-on-stl-10 | NAT-M3 | Accuracy (%): 97.8 FLOPS: 436M PARAMS: 7.5M |