Pyramid Adversarial Training Improves ViT Performance

Abstract

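Aggressive data augmentation is a key component of the strong generalization capabilities of the Vision Transformer (ViT). One such data augmentation technique is adversarial training (AT); however, many prior works have shown that it often reduces accuracy on clean samples. To address this, we propose pyramid adversarial training (PyramidAT), a simple and effective technique for improving ViT's overall performance. We pair it with a "matched" Dropout and stochastic depth regularization, which uses the same Dropout and stochastic depth configuration for the clean and adversarial samples. Similar to the improvements that AdvProp brings to convolutional neural networks (CNNs) (a method not directly applicable to ViT), PyramidAT is the first to break the trade-off between in-distribution accuracy and out-of-distribution robustness for ViT and related architectures. When trained only on ImageNet-1K data, PyramidAT improves the ViT-B model's clean ImageNet accuracy by 1.82 percentage points, while also improving 7 ImageNet robustness metrics by absolute margins ranging from 1.76% to 15.68%. Without extra data, and using only the ViT-B/16 backbone with PyramidAT, the method sets a new state of the art on ImageNet-C (41.42 mCE), ImageNet-R (53.92%), and ImageNet-Sketch (41.04%). The code is publicly available at pyramidat.github.io.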
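The abstract names two ingredients without spelling them out: a pyramid-structured perturbation and a "matched" Dropout. The following is a minimal NumPy sketch of what these could look like, not the authors' implementation. It assumes that "pyramid" means a sum of perturbations defined at several spatial scales, and that "matched" means one Dropout mask shared by the clean and adversarial branches. The function names, the scales `[16, 4, 1]`, the multipliers, and `eps` are all hypothetical, and random values stand in for the perturbations that adversarial training would obtain by gradient ascent on the loss.

```python
import numpy as np

def pyramid_perturbation(deltas, scales, multipliers, height, width):
    """Sum coarse-to-fine perturbations into one full-resolution perturbation.

    deltas[i] has shape (height // scales[i], width // scales[i], channels).
    Each entry is repeated over a scales[i] x scales[i] patch (nearest-neighbour
    upsampling), weighted by multipliers[i], and added to the total.
    """
    total = np.zeros((height, width, deltas[0].shape[-1]))
    for delta, s, m in zip(deltas, scales, multipliers):
        upsampled = np.repeat(np.repeat(delta, s, axis=0), s, axis=1)
        total += m * upsampled
    return total

def matched_dropout(clean_feats, adv_feats, rate, rng):
    """Apply one shared Dropout mask to both the clean and adversarial inputs."""
    keep = rng.random(clean_feats.shape) >= rate   # True where a unit is kept
    scale = 1.0 / (1.0 - rate)                     # inverted-dropout rescaling
    return clean_feats * keep * scale, adv_feats * keep * scale

# Hypothetical settings, chosen only for illustration.
rng = np.random.default_rng(0)
H = W = 32
scales = [16, 4, 1]
multipliers = [4.0, 2.0, 1.0]
eps = 1.0 / 255.0

# Random stand-ins for the per-scale perturbations (an adversarial variant
# would optimise these against the model's loss instead of sampling them).
deltas = [rng.uniform(-eps, eps, size=(H // s, W // s, 3)) for s in scales]
delta = pyramid_perturbation(deltas, scales, multipliers, H, W)

image = rng.random((H, W, 3))
adv_image = np.clip(image + delta, 0.0, 1.0)
clean_out, adv_out = matched_dropout(image, adv_image, rate=0.1, rng=rng)
```

Because the coarse levels repeat one value over a whole patch, they shift large image regions coherently, while the finest level still allows per-pixel changes. Sharing the mask means the clean and perturbed inputs pass through the same randomly thinned network within a training step.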

Benchmarks

| Benchmark | Method | Metrics |
|---|---|---|
| domain-generalization-on-imagenet-a | Pyramid Adversarial Training Improves ViT (384x384) | Top-1 accuracy %: 36.41 |
| domain-generalization-on-imagenet-a | Pyramid Adversarial Training Improves ViT (Im21k) | Top-1 accuracy %: 62.44 |
| domain-generalization-on-imagenet-c | Pyramid Adversarial Training Improves ViT | mean Corruption Error (mCE): 41.42 |
| domain-generalization-on-imagenet-c | Pyramid Adversarial Training Improves ViT (Im21k) | Number of params: 87M; mean Corruption Error (mCE): 36.80 |
| domain-generalization-on-imagenet-r | Pyramid Adversarial Training Improves ViT (Im21k) | Top-1 Error Rate: 42.16 |
| domain-generalization-on-imagenet-r | Pyramid Adversarial Training Improves ViT | Top-1 Error Rate: 46.08 |
| domain-generalization-on-imagenet-sketch | Pyramid Adversarial Training Improves ViT | Top-1 accuracy: 41.04 |
| domain-generalization-on-imagenet-sketch | Pyramid Adversarial Training Improves ViT (Im21k) | Top-1 accuracy: 46.03 |
| image-classification-on-objectnet | RegViT (RandAug) | Top-1 Accuracy: 29.3 |
| image-classification-on-objectnet | MLP-Mixer + Pixel | Top-1 Accuracy: 24.75 |
| image-classification-on-objectnet | Discrete ViT | Top-1 Accuracy: 29.95 |
| image-classification-on-objectnet | RegViT (RandAug) + Adv Pixel | Top-1 Accuracy: 30.11 |
| image-classification-on-objectnet | MLP-Mixer | Top-1 Accuracy: 25.9 |
| image-classification-on-objectnet | RegViT (RandAug) + Random Pixel | Top-1 Accuracy: 28.72 |
| image-classification-on-objectnet | RegViT (RandAug) + Adv Pyramid | Top-1 Accuracy: 32.92 |
| image-classification-on-objectnet | RegViT on 384x384 + Random Pyramid | Top-1 Accuracy: 34.83 |
| image-classification-on-objectnet | RegViT (RandAug) + Random Pyramid | Top-1 Accuracy: 29.41 |
| image-classification-on-objectnet | Discrete ViT + Pixel | Top-1 Accuracy: 30.98 |
| image-classification-on-objectnet | RegViT on 384x384 + Random Pixel | Top-1 Accuracy: 34.12 |
| image-classification-on-objectnet | ViT | Top-1 Accuracy: 17.36 |
| image-classification-on-objectnet | ViT + MixUp | Top-1 Accuracy: 25.65 |
| image-classification-on-objectnet | ViT-B/16 (512x512) + Pyramid | Top-1 Accuracy: 49.39 |
| image-classification-on-objectnet | MLP-Mixer + Pyramid | Top-1 Accuracy: 28.6 |
| image-classification-on-objectnet | Discrete ViT + Pyramid | Top-1 Accuracy: 30.28 |
| image-classification-on-objectnet | ViT-B/16 (512x512) | Top-1 Accuracy: 46.68 |
| image-classification-on-objectnet | RegViT on 384x384 + Adv Pixel | Top-1 Accuracy: 37.41 |
| image-classification-on-objectnet | RegViT on 384x384 | Top-1 Accuracy: 35.59 |
| image-classification-on-objectnet | ViT-B/16 (512x512) + Pixel | Top-1 Accuracy: 47.53 |
| image-classification-on-objectnet | ViT + CutMix | Top-1 Accuracy: 21.61 |
| image-classification-on-objectnet | RegViT on 384x384 + Adv Pyramid | Top-1 Accuracy: 39.79 |
