Mitchell WortsmanGabriel IlharcoSamir Yitzhak GadreRebecca RoelofsRaphael Gontijo-LopesAri S. MorcosHongseok NamkoongAli FarhadiYair CarmonSimon KornblithLudwig Schmidt

摘要
传统的提升模型准确率的方法通常包括两个步骤:(1)使用不同的超参数训练多个模型;(2)从这些模型中选择在预留验证集上表现最佳的单一模型,其余模型则被丢弃。本文在微调大型预训练模型的背景下重新审视了这一流程的第二步。在该场景中,微调后的模型往往聚集于单一低误差区域。我们发现,对采用不同超参数配置微调得到的多个模型的权重进行平均,通常能够显著提升模型的准确率与鲁棒性。与传统集成方法不同,我们可以在不增加推理开销或内存消耗的前提下,对大量模型进行平均——我们将这一方法所得结果称为“模型汤”(model soups)。在微调诸如CLIP、ALIGN以及在JFT数据集上预训练的ViT-G等大型预训练模型时,我们的“模型汤”方法在ImageNet上的表现显著优于超参数搜索中选出的最佳单个模型。由此获得的ViT-G模型在ImageNet上实现了90.94%的Top-1准确率,创造了新的最先进水平。此外,我们进一步证明,该模型汤方法可推广至多种图像分类与自然语言处理任务,能够提升模型在分布外数据上的表现,并增强其在新下游任务上的零样本性能。最后,我们从理论上分析了权重平均与logit集成性能相似性的根源,发现其与损失函数的平坦性以及预测置信度密切相关,并通过实验证实了这一理论关系。相关代码已开源,地址为:https://github.com/mlfoundations/model-soups。
代码仓库
shallowlearn/sportsreid
pytorch
GitHub 中提及
Burf/ModelSoups
tf
GitHub 中提及
facebookresearch/ModelRatatouille
pytorch
GitHub 中提及
flowritecom/flow-merge
pytorch
GitHub 中提及
mlfoundations/model-soups
官方
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| domain-generalization-on-imagenet-a | Model soups (BASIC-L) | Top-1 accuracy %: 94.17 |
| domain-generalization-on-imagenet-a | Model soups (ViT-G/14) | Top-1 accuracy %: 92.67 |
| domain-generalization-on-imagenet-r | Model soups (ViT-G/14) | Top-1 Error Rate: 4.54 |
| domain-generalization-on-imagenet-r | Model soups (BASIC-L) | Top-1 Error Rate: 3.90 |
| domain-generalization-on-imagenet-sketch | Model soups (ViT-G/14) | Top-1 accuracy: 74.24 |
| domain-generalization-on-imagenet-sketch | Model soups (BASIC-L) | Top-1 accuracy: 77.18 |
| image-classification-on-imagenet | Model soups (ViT-G/14) | Number of params: 1843M Top 1 Accuracy: 90.94% |
| image-classification-on-imagenet | Model soups (BASIC-L) | Number of params: 2440M Top 1 Accuracy: 90.98% |
| image-classification-on-imagenet-real | Model soups (ViT-G/14) | Accuracy: 91.20% Params: 1843M |
| image-classification-on-imagenet-real | Model soups (BASIC-L) | Accuracy: 91.03% Params: 2440M |
| image-classification-on-imagenet-real | Baseline (ViT-G/14) | Accuracy: 91.78% |
| image-classification-on-imagenet-v2 | Model soups (ViT-G/14) | Top 1 Accuracy: 84.22 |
| image-classification-on-imagenet-v2 | Model soups (BASIC-L) | Top 1 Accuracy: 84.63 |
| image-classification-on-objectnet | Baseline (ViT-G/14) | Top-1 Accuracy: 79.03 |
| image-classification-on-objectnet | Model soups (ViT-G/14) | Top-1 Accuracy: 78.52 |
| unsupervised-domain-adaptation-on-imagenet-r | Model soups (ViT-G/14) | Top 1 Error: 4.54 |