
Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations

Abstract

Vision Transformers (ViTs) and MLP-Mixers mark a further step toward replacing hand-designed features and inductive biases with general-purpose neural architectures. Existing works improve performance with massive data, such as large-scale pre-training and/or repeated strong data augmentations, yet still report optimization-related problems (e.g., sensitivity to initialization and learning rate). This paper instead studies ViTs and MLP-Mixers from the perspective of loss-landscape geometry, aiming to improve data efficiency during training and generalization at inference. Visualization and Hessian analysis reveal that the converged models sit in extremely sharp local minima. By promoting smoothness with a recently proposed sharpness-aware optimizer, we substantially improve the accuracy and robustness of ViTs and MLP-Mixers across a range of settings, including supervised, adversarial, contrastive, and transfer learning. For example, with simple Inception-style preprocessing alone, ViT-B/16 and Mixer-B/16 gain 5.3% and 11.0% top-1 accuracy on ImageNet, respectively. Further analysis shows that the improved smoothness is mainly attributable to sparser active neurons in the first few layers. The resulting ViTs, trained from scratch without large-scale pre-training or strong data augmentations, outperform ResNets of similar size and throughput. Model checkpoints are available at \url{https://github.com/google-research/vision_transformer}.
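The sharpness-aware optimizer referenced above (SAM) replaces the plain gradient step with a two-step update: first ascend to an approximate worst-case point within a small L2 ball around the current weights, then descend using the gradient evaluated there, which biases training toward flat minima. Below is a minimal numpy sketch of that update on a toy ill-conditioned quadratic; `sam_step`, the toy loss, and the hyperparameter values are illustrative assumptions, not the paper's actual training setup.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.005, rho=0.05):
    """One SAM update: ascend to a worst-case point within an L2 ball
    of radius rho, then descend with the gradient taken there."""
    g = grad_fn(w)
    # Ascent: perturb the weights along the normalized gradient,
    # moving toward higher loss within the rho-ball.
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # Descent: apply the gradient evaluated at the perturbed weights.
    return w - lr * grad_fn(w + eps)

# Toy demo: an ill-conditioned quadratic bowl, loss(w) = 0.5 * w.T @ A @ w,
# whose large curvature along one axis mimics a "sharp" direction.
A = np.diag([100.0, 1.0])
loss = lambda w: 0.5 * w @ A @ w
grad = lambda w: A @ w

w = np.array([1.0, 1.0])
for _ in range(200):
    w = sam_step(w, grad)
```

Because the descent gradient is taken at the perturbed point, directions of high curvature contribute a larger correction, which is what penalizes sharp minima relative to plain gradient descent.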

Code Repositories

google-research/vision_transformer (official, jax) — mentioned in GitHub
ttt496/VisionTransformer (jax) — mentioned in GitHub

Benchmarks

| Benchmark | Method | Metric(s) |
|---|---|---|
| domain-generalization-on-imagenet-c | Mixer-B/8-SAM | Top 1 Accuracy: 48.9 |
| domain-generalization-on-imagenet-c | ResNet-152x2-SAM | Top 1 Accuracy: 55 |
| domain-generalization-on-imagenet-c | ViT-B/16-SAM | Top 1 Accuracy: 56.5 |
| domain-generalization-on-imagenet-r | Mixer-B/8-SAM | Top-1 Error Rate: 76.5 |
| domain-generalization-on-imagenet-r | ResNet-152x2-SAM | Top-1 Error Rate: 71.9 |
| domain-generalization-on-imagenet-r | ViT-B/16-SAM | Top-1 Error Rate: 73.6 |
| fine-grained-image-classification-on-oxford-2 | ResNet-50-SAM | Accuracy: 91.6 |
| fine-grained-image-classification-on-oxford-2 | Mixer-B/16-SAM | Accuracy: 92.5 |
| fine-grained-image-classification-on-oxford-2 | Mixer-S/16-SAM | Accuracy: 88.7 |
| fine-grained-image-classification-on-oxford-2 | ViT-B/16-SAM | Accuracy: 93.1 |
| fine-grained-image-classification-on-oxford-2 | ViT-S/16-SAM | Accuracy: 92.9 |
| fine-grained-image-classification-on-oxford-2 | ResNet-152-SAM | Accuracy: 93.3 |
| image-classification-on-cifar-10 | ResNet-50-SAM | Percentage correct: 97.4 |
| image-classification-on-cifar-10 | Mixer-S/16-SAM | Percentage correct: 96.1 |
| image-classification-on-cifar-10 | ViT-S/16-SAM | Percentage correct: 98.2 |
| image-classification-on-cifar-10 | ViT-B/16-SAM | Percentage correct: 98.6 |
| image-classification-on-cifar-10 | ResNet-152-SAM | Percentage correct: 98.2 |
| image-classification-on-cifar-10 | Mixer-B/16-SAM | Percentage correct: 97.8 |
| image-classification-on-cifar-100 | Mixer-B/16-SAM | Percentage correct: 86.4 |
| image-classification-on-cifar-100 | ViT-B/16-SAM | Percentage correct: 89.1 |
| image-classification-on-cifar-100 | ResNet-50-SAM | Percentage correct: 85.2 |
| image-classification-on-cifar-100 | ViT-S/16-SAM | Percentage correct: 87.6 |
| image-classification-on-cifar-100 | Mixer-S/16-SAM | Percentage correct: 82.4 |
| image-classification-on-flowers-102 | Mixer-S/16-SAM | Accuracy: 87.9 |
| image-classification-on-flowers-102 | ResNet-152-SAM | Accuracy: 91.1 |
| image-classification-on-flowers-102 | ViT-S/16-SAM | Accuracy: 91.5 |
| image-classification-on-flowers-102 | ViT-B/16-SAM | Accuracy: 91.8 |
| image-classification-on-flowers-102 | Mixer-B/16-SAM | Accuracy: 90 |
| image-classification-on-flowers-102 | ResNet-50-SAM | Accuracy: 90 |
| image-classification-on-imagenet | ViT-B/16-SAM | Number of params: 87M; Top 1 Accuracy: 79.9% |
| image-classification-on-imagenet | ResNet-152x2-SAM | Number of params: 236M; Top 1 Accuracy: 81.1% |
| image-classification-on-imagenet | Mixer-B/8-SAM | Number of params: 64M; Top 1 Accuracy: 79% |
| image-classification-on-imagenet-real | ResNet-152x2-SAM | Accuracy: 86.4% |
| image-classification-on-imagenet-real | ViT-B/16-SAM | Accuracy: 85.2% |
| image-classification-on-imagenet-real | Mixer-B/8-SAM | Accuracy: 84.4% |
| image-classification-on-imagenet-v2 | Mixer-B/8-SAM | Top 1 Accuracy: 65.5 |
| image-classification-on-imagenet-v2 | ViT-B/16-SAM | Top 1 Accuracy: 67.5 |
| image-classification-on-imagenet-v2 | ResNet-152x2-SAM | Top 1 Accuracy: 69.6 |
