
摘要
在当今高度过参数化的模型中,训练损失值对模型泛化能力的保证极为有限。事实上,仅优化训练损失值(这是目前普遍采用的做法)极易导致模型质量不佳。受先前研究中关于损失曲面几何结构与泛化能力之间关联的启发,我们提出了一种新颖且高效的优化方法,即同时最小化损失值与损失的“尖锐度”(sharpness)。具体而言,我们的方法——尖锐度感知最小化(Sharpness-Aware Minimization, SAM)——旨在寻找位于损失值普遍较低的邻域内的模型参数;这一目标可形式化为一个高效的梯度下降可求解的极小-极大优化问题。实验结果表明,SAM在多种基准数据集(如CIFAR-10、CIFAR-100、ImageNet以及微调任务)和模型架构上均显著提升了模型的泛化性能,为其中若干任务带来了新的最先进(SOTA)表现。此外,我们发现SAM天然具备与专门针对标签噪声设计的先进方法相当的鲁棒性,能够有效应对标签噪声问题。相关代码已开源,地址为:\url{https://github.com/google-research/sam}。
代码仓库
ys-zong/medfair
pytorch
Jannoshh/simple-sam
tf
GitHub 中提及
denizyuret/playground
pytorch
GitHub 中提及
rollovd/LookSAM
pytorch
GitHub 中提及
moskomule/sam.pytorch
pytorch
GitHub 中提及
Janus-Shiau/SAM-tf2
tf
GitHub 中提及
google-research/sam
官方
jax
simon20010923/DDAMFN
pytorch
GitHub 中提及
Yuheon/Sharp-Aware-Minimization
pytorch
GitHub 中提及
mhassann22/GCSAM
pytorch
GitHub 中提及
NiMlr/pynlqn
GitHub 中提及
davda54/sam
pytorch
GitHub 中提及
borealisai/perturbed-forgetting
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| fine-grained-image-classification-on-birdsnap | EffNet-L2 (SAM) | Accuracy: 90.07% |
| fine-grained-image-classification-on-fgvc | EffNet-L2 (SAM) | Top-1 Error Rate: 4.82 |
| fine-grained-image-classification-on-food-101 | EffNet-L2 (SAM) | Accuracy: 96.18 |
| fine-grained-image-classification-on-oxford-2 | EffNet-L2 (SAM) | Accuracy: 97.10 Top-1 Error Rate: 2.90% |
| fine-grained-image-classification-on-stanford | EffNet-L2 (SAM) | Accuracy: 95.96% |
| image-classification-on-cifar-100 | PyramidNet (SAM) | Percentage correct: 89.7 |
| image-classification-on-cifar-100 | CNN39 | Percentage correct: 42.64 |
| image-classification-on-cifar-100 | EffNet-L2 (SAM) | Percentage correct: 96.08 |
| image-classification-on-cifar-100 | CNN36 | Percentage correct: 36.07 |
| image-classification-on-flowers-102 | EffNet-L2 (SAM) | Accuracy: 99.65% |
| image-classification-on-imagenet | EfficientNet-L2-475 (SAM) | Number of params: 480M Top 1 Accuracy: 88.61% |
| image-classification-on-imagenet | ResNet-152 (SAM) | Top 1 Accuracy: 81.6% |