
Abstract

Aggressive data augmentation is a key ingredient behind the strong generalization capability of Vision Transformers (ViT). One such data augmentation technique is adversarial training (AT); however, many prior works have shown that it often degrades accuracy on clean samples. To address this, we present Pyramid Adversarial Training (PyramidAT), a simple and effective method for improving ViT's overall performance. We pair it with a "matched" Dropout and stochastic depth regularization, which applies the same Dropout and stochastic-depth configuration to the clean and adversarial samples. Similar to the gains AdvProp brings to CNNs (a method not directly applicable to ViT), our PyramidAT is the first to break the trade-off between in-distribution accuracy and out-of-distribution robustness for ViT and related architectures. Trained on ImageNet-1K data alone, PyramidAT yields a 1.82% absolute improvement in ImageNet clean accuracy for the ViT-B model, while simultaneously boosting performance on 7 ImageNet robustness metrics by absolute margins ranging from 1.76% to 15.68%. Without any extra data, and using only a ViT-B/16 backbone with the proposed PyramidAT, we set new state-of-the-art results on ImageNet-C (41.42 mCE), ImageNet-R (53.92%), and ImageNet-Sketch (41.04%). Code is publicly available at pyramidat.github.io.
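The abstract describes the method only at a high level. The PyTorch-style sketch below illustrates the general idea of a pyramid (multi-scale) adversarial perturbation combined with a clean + adversarial training loss; it is not the released implementation (see pyramidat.github.io). The pyramid scales, per-level weights, attack budget, image range, and the RNG-seeding trick used to approximate "matched" Dropout / stochastic depth are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def pyramid_perturb(x, deltas, scales, weights):
    """Add multi-scale ("pyramid") perturbations to a batch of images in [0, 1].

    deltas[i] has spatial size (H/scales[i], W/scales[i]); each level is
    upsampled to full resolution and scaled by weights[i] before being added.
    """
    x_adv = x
    for delta, s, w in zip(deltas, scales, weights):
        if s > 1:
            delta = F.interpolate(delta, size=x.shape[-2:], mode="nearest")
        x_adv = x_adv + w * delta
    return x_adv.clamp(0.0, 1.0)

def pyramid_adversarial_step(model, x, y, scales=(1, 16, 32),
                             weights=(1.0, 10.0, 20.0),
                             eps=1.0 / 255, steps=5, step_size=1.0 / 255):
    """One training step with clean + pyramid-adversarial samples (sketch only)."""
    b, c, h, w = x.shape
    # One perturbation tensor per pyramid level, initialised to zero.
    deltas = [torch.zeros(b, c, h // s, w // s, device=x.device,
                          requires_grad=True) for s in scales]

    model.eval()  # assumption: keep Dropout inactive while crafting the attack
    for _ in range(steps):
        x_adv = pyramid_perturb(x, deltas, scales, weights)
        loss = F.cross_entropy(model(x_adv), y)
        grads = torch.autograd.grad(loss, deltas)
        with torch.no_grad():
            for d, g in zip(deltas, grads):
                d += step_size * g.sign()   # gradient ascent on the loss
                d.clamp_(-eps, eps)         # per-level L_inf constraint (illustrative)

    model.train()
    # "Matched" regularization: the paper uses the same Dropout / stochastic-depth
    # configuration for the clean and adversarial forward passes. Here we only
    # approximate that by re-seeding the RNG before each pass.
    seed = torch.seed()
    clean_loss = F.cross_entropy(model(x), y)
    torch.manual_seed(seed)
    adv_loss = F.cross_entropy(
        model(pyramid_perturb(x, [d.detach() for d in deltas], scales, weights)), y)
    return clean_loss + adv_loss
```

In an actual training loop, the returned loss would be backpropagated and the optimizer stepped as usual; the relative weighting of the clean and adversarial losses, as well as the per-level budgets, are tuning choices not specified here.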
Benchmarks
| Benchmark | Method | Metrics |
|---|---|---|
| domain-generalization-on-imagenet-a | Pyramid Adversarial Training Improves ViT (384x384) | Top-1 accuracy %: 36.41 |
| domain-generalization-on-imagenet-a | Pyramid Adversarial Training Improves ViT (Im21k) | Top-1 accuracy %: 62.44 |
| domain-generalization-on-imagenet-c | Pyramid Adversarial Training Improves ViT | mean Corruption Error (mCE): 41.42 |
| domain-generalization-on-imagenet-c | Pyramid Adversarial Training Improves ViT (Im21k) | Number of params: 87M; mean Corruption Error (mCE): 36.80 |
| domain-generalization-on-imagenet-r | Pyramid Adversarial Training Improves ViT (Im21k) | Top-1 Error Rate: 42.16 |
| domain-generalization-on-imagenet-r | Pyramid Adversarial Training Improves ViT | Top-1 Error Rate: 46.08 |
| domain-generalization-on-imagenet-sketch | Pyramid Adversarial Training Improves ViT | Top-1 accuracy: 41.04 |
| domain-generalization-on-imagenet-sketch | Pyramid Adversarial Training Improves ViT (Im21k) | Top-1 accuracy: 46.03 |
| image-classification-on-objectnet | RegViT (RandAug) | Top-1 Accuracy: 29.3 |
| image-classification-on-objectnet | MLP-Mixer + Pixel | Top-1 Accuracy: 24.75 |
| image-classification-on-objectnet | Discrete ViT | Top-1 Accuracy: 29.95 |
| image-classification-on-objectnet | RegViT (RandAug) + Adv Pixel | Top-1 Accuracy: 30.11 |
| image-classification-on-objectnet | MLP-Mixer | Top-1 Accuracy: 25.9 |
| image-classification-on-objectnet | RegViT (RandAug) + Random Pixel | Top-1 Accuracy: 28.72 |
| image-classification-on-objectnet | RegViT (RandAug) + Adv Pyramid | Top-1 Accuracy: 32.92 |
| image-classification-on-objectnet | RegViT on 384x384 + Random Pyramid | Top-1 Accuracy: 34.83 |
| image-classification-on-objectnet | RegViT (RandAug) + Random Pyramid | Top-1 Accuracy: 29.41 |
| image-classification-on-objectnet | Discrete ViT + Pixel | Top-1 Accuracy: 30.98 |
| image-classification-on-objectnet | RegViT on 384x384 + Random Pixel | Top-1 Accuracy: 34.12 |
| image-classification-on-objectnet | ViT | Top-1 Accuracy: 17.36 |
| image-classification-on-objectnet | ViT + MixUp | Top-1 Accuracy: 25.65 |
| image-classification-on-objectnet | ViT-B/16 (512x512) + Pyramid | Top-1 Accuracy: 49.39 |
| image-classification-on-objectnet | MLP-Mixer + Pyramid | Top-1 Accuracy: 28.6 |
| image-classification-on-objectnet | Discrete ViT + Pyramid | Top-1 Accuracy: 30.28 |
| image-classification-on-objectnet | ViT-B/16 (512x512) | Top-1 Accuracy: 46.68 |
| image-classification-on-objectnet | RegViT on 384x384 + Adv Pixel | Top-1 Accuracy: 37.41 |
| image-classification-on-objectnet | RegViT on 384x384 | Top-1 Accuracy: 35.59 |
| image-classification-on-objectnet | ViT-B/16 (512x512) + Pixel | Top-1 Accuracy: 47.53 |
| image-classification-on-objectnet | ViT + CutMix | Top-1 Accuracy: 21.61 |
| image-classification-on-objectnet | RegViT on 384x384 + Adv Pyramid | Top-1 Accuracy: 39.79 |