
摘要
在领域泛化(Domain Generalization, DG)设置中,模型在独立训练于若干训练领域后,往往在分布外的测试领域上表现出极其不稳定的性能,而优化过程中的随机性(如随机种子)在其中起到了显著作用。这使得深度学习模型在现实场景中难以可靠应用。我们首先揭示,这种不稳定性不仅存在于不同模型之间,甚至在单个模型的训练优化轨迹中也持续存在。为此,我们提出一种简单的模型平均协议,该方法不仅能显著提升领域泛化能力,还能有效降低随机性的影响,其核心机制在于增强域内验证准确率与域外测试准确率之间的秩相关性,这对实现可靠的早停策略至关重要。基于这一发现,我们进一步表明:相较于实践中常见的对未平均模型进行集成(ensemble),通过对独立训练运行中获得的移动平均模型(Ensemble of Averages, EoA)进行集成,能够进一步提升性能。我们从理论上对模型平均与集成带来的性能提升进行解释,将经典的偏差-方差权衡(Bias-Variance Trade-off)框架推广至领域泛化场景,揭示了其内在机制。在DomainBed基准测试中,当使用预训练的ResNet-50时,该平均模型集成方法实现了平均68.0%的准确率,较原始的ERM方法(无平均或集成)提升了约4%;当使用预训练的RegNetY-16GF时,平均准确率达到76.6%,较基准ERM方法提升6%。我们的代码已开源,地址为:https://github.com/salesforce/ensemble-of-averages。
代码仓库
salesforce/ensemble-of-averages
官方
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| domain-generalization-on-domainnet | Ensemble of Averages (RegNetY-16GF) | Average Accuracy: 60.9 |
| domain-generalization-on-domainnet | Ensemble of Averages (ResNet-50) | Average Accuracy: 47.4 |
| domain-generalization-on-domainnet | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy: 54.6 |
| domain-generalization-on-office-home | Ensemble of Averages (RegNetY-16GF) | Average Accuracy: 83.9 |
| domain-generalization-on-office-home | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy: 80.2 |
| domain-generalization-on-office-home | Ensemble of Averages (ResNet-50) | Average Accuracy: 72.5 |
| domain-generalization-on-pacs-2 | Ensemble of Averages (RegNetY-16GF) | Average Accuracy: 95.8 |
| domain-generalization-on-pacs-2 | Ensemble of Averages (ResNet-50) | Average Accuracy: 88.6 |
| domain-generalization-on-pacs-2 | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy: 93.2 |
| domain-generalization-on-terraincognita | Ensemble of Averages (ResNet-50) | Average Accuracy: 52.3 |
| domain-generalization-on-terraincognita | Ensemble of Averages (RegNetY-16GF) | Average Accuracy: 61.1 |
| domain-generalization-on-terraincognita | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy: 55.2 |
| domain-generalization-on-vlcs | Ensemble of Averages (RegNetY-16GF) | Average Accuracy: 81.1 |
| domain-generalization-on-vlcs | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy: 80.4 |
| domain-generalization-on-vlcs | Ensemble of Averages (ResNet-50) | Average Accuracy: 79.1 |