3 个月前

用于高效提升泛化能力的锐度感知最小化

用于高效提升泛化能力的锐度感知最小化

摘要

在当今高度过参数化的模型中,训练损失值对模型泛化能力的保证极为有限。事实上,仅优化训练损失值(这是目前普遍采用的做法)极易导致模型质量不佳。受先前研究中关于损失曲面几何结构与泛化能力之间关联的启发,我们提出了一种新颖且高效的优化方法,即同时最小化损失值与损失的“尖锐度”(sharpness)。具体而言,我们的方法——尖锐度感知最小化(Sharpness-Aware Minimization, SAM)——旨在寻找位于损失值普遍较低的邻域内的模型参数;这一目标可形式化为一个高效的梯度下降可求解的极小-极大优化问题。实验结果表明,SAM在多种基准数据集(如CIFAR-10、CIFAR-100、ImageNet以及微调任务)和模型架构上均显著提升了模型的泛化性能,为其中若干任务带来了新的最先进(SOTA)表现。此外,我们发现SAM天然具备与专门针对标签噪声设计的先进方法相当的鲁棒性,能够有效应对标签噪声问题。相关代码已开源,地址为:\url{https://github.com/google-research/sam}。

基准测试

基准方法指标
fine-grained-image-classification-on-birdsnapEffNet-L2 (SAM)
Accuracy: 90.07%
fine-grained-image-classification-on-fgvcEffNet-L2 (SAM)
Top-1 Error Rate: 4.82
fine-grained-image-classification-on-food-101EffNet-L2 (SAM)
Accuracy: 96.18
fine-grained-image-classification-on-oxford-2EffNet-L2 (SAM)
Accuracy: 97.10
Top-1 Error Rate: 2.90%
fine-grained-image-classification-on-stanfordEffNet-L2 (SAM)
Accuracy: 95.96%
image-classification-on-cifar-100PyramidNet (SAM)
Percentage correct: 89.7
image-classification-on-cifar-100CNN39
Percentage correct: 42.64
image-classification-on-cifar-100EffNet-L2 (SAM)
Percentage correct: 96.08
image-classification-on-cifar-100CNN36
Percentage correct: 36.07
image-classification-on-flowers-102EffNet-L2 (SAM)
Accuracy: 99.65%
image-classification-on-imagenetEfficientNet-L2-475 (SAM)
Number of params: 480M
Top 1 Accuracy: 88.61%
image-classification-on-imagenetResNet-152 (SAM)
Top 1 Accuracy: 81.6%

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
用于高效提升泛化能力的锐度感知最小化 | 论文 | HyperAI超神经