3 个月前

平均值集成:提升模型选择并增强领域泛化性能

平均值集成:提升模型选择并增强领域泛化性能

摘要

在领域泛化(Domain Generalization, DG)设置中,模型在独立训练于若干训练领域后,往往在分布外的测试领域上表现出极其不稳定的性能,而优化过程中的随机性(如随机种子)在其中起到了显著作用。这使得深度学习模型在现实场景中难以可靠应用。我们首先揭示,这种不稳定性不仅存在于不同模型之间,甚至在单个模型的训练优化轨迹中也持续存在。为此,我们提出一种简单的模型平均协议,该方法不仅能显著提升领域泛化能力,还能有效降低随机性的影响,其核心机制在于增强域内验证准确率与域外测试准确率之间的秩相关性,这对实现可靠的早停策略至关重要。基于这一发现,我们进一步表明:相较于实践中常见的对未平均模型进行集成(ensemble),通过对独立训练运行中获得的移动平均模型(Ensemble of Averages, EoA)进行集成,能够进一步提升性能。我们从理论上对模型平均与集成带来的性能提升进行解释,将经典的偏差-方差权衡(Bias-Variance Trade-off)框架推广至领域泛化场景,揭示了其内在机制。在DomainBed基准测试中,当使用预训练的ResNet-50时,该平均模型集成方法实现了平均68.0%的准确率,较原始的ERM方法(无平均或集成)提升了约4%;当使用预训练的RegNetY-16GF时,平均准确率达到76.6%,较基准ERM方法提升6%。我们的代码已开源,地址为:https://github.com/salesforce/ensemble-of-averages。

代码仓库

salesforce/ensemble-of-averages
官方
pytorch
GitHub 中提及

基准测试

基准方法指标
domain-generalization-on-domainnetEnsemble of Averages (RegNetY-16GF)
Average Accuracy: 60.9
domain-generalization-on-domainnetEnsemble of Averages (ResNet-50)
Average Accuracy: 47.4
domain-generalization-on-domainnetEnsemble of Averages (ResNeXt-50 32x4d)
Average Accuracy: 54.6
domain-generalization-on-office-homeEnsemble of Averages (RegNetY-16GF)
Average Accuracy: 83.9
domain-generalization-on-office-homeEnsemble of Averages (ResNeXt-50 32x4d)
Average Accuracy: 80.2
domain-generalization-on-office-homeEnsemble of Averages (ResNet-50)
Average Accuracy: 72.5
domain-generalization-on-pacs-2Ensemble of Averages (RegNetY-16GF)
Average Accuracy: 95.8
domain-generalization-on-pacs-2Ensemble of Averages (ResNet-50)
Average Accuracy: 88.6
domain-generalization-on-pacs-2Ensemble of Averages (ResNeXt-50 32x4d)
Average Accuracy: 93.2
domain-generalization-on-terraincognitaEnsemble of Averages (ResNet-50)
Average Accuracy: 52.3
domain-generalization-on-terraincognitaEnsemble of Averages (RegNetY-16GF)
Average Accuracy: 61.1
domain-generalization-on-terraincognitaEnsemble of Averages (ResNeXt-50 32x4d)
Average Accuracy: 55.2
domain-generalization-on-vlcsEnsemble of Averages (RegNetY-16GF)
Average Accuracy: 81.1
domain-generalization-on-vlcsEnsemble of Averages (ResNeXt-50 32x4d)
Average Accuracy: 80.4
domain-generalization-on-vlcsEnsemble of Averages (ResNet-50)
Average Accuracy: 79.1

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
平均值集成:提升模型选择并增强领域泛化性能 | 论文 | HyperAI超神经