3 个月前

基于预训练模型的互信息正则化域泛化

基于预训练模型的互信息正则化域泛化

摘要

领域泛化(Domain Generalization, DG)旨在仅利用有限的源域数据,训练出一个能够泛化至未见目标域的通用模型。以往的DG方法由于训练域与测试域之间存在显著的域差异,难以仅从源域中学习到域不变的表示。为此,本文提出一种新的DG目标重构方法,引入了基于“理想模型”(oracle model)的互信息约束,其中理想模型是指能够泛化至任意可能域的模型。通过使用预训练模型近似该理想模型,我们推导出一个可计算的变分下界,从而提出一种名为互信息正则化与理想模型(Mutual Information Regularization with Oracle, MIRO)的新方法。大量实验结果表明,MIRO显著提升了模型在分布外数据上的性能表现。此外,缩放实验进一步验证了:预训练模型规模越大,MIRO所带来的性能提升越显著。相关源代码已开源,地址为:https://github.com/kakaobrain/miro。

代码仓库

kakaobrain/miro
官方
pytorch
GitHub 中提及

基准测试

基准方法指标
domain-generalization-on-domainnetMIRO (RegNetY-16GF, SWAD)
Average Accuracy: 60.7
domain-generalization-on-domainnetMIRO (ResNet-50, SWAD)
Average Accuracy: 47.0
domain-generalization-on-office-homeMIRO (ResNet-50, SWAD)
Average Accuracy: 72.4
domain-generalization-on-office-homeMIRO (RegNetY-16GF, SWAD)
Average Accuracy: 83.3
domain-generalization-on-pacs-2MIRO (ResNet-50, SWAD)
Average Accuracy: 88.4
domain-generalization-on-pacs-2MIRO (RegNetY-16GF, SWAD)
Average Accuracy: 96.8
domain-generalization-on-terraincognitaMIRO (ResNet-50, SWAD)
Average Accuracy: 52.9
domain-generalization-on-terraincognitaMIRO (RegNetY-16GF, SWAD)
Average Accuracy: 64.3
domain-generalization-on-vlcsMIRO (RegNetY-16GF, SWAD)
Average Accuracy: 81.7
domain-generalization-on-vlcsMIRO (ResNet-50, SWAD)
Average Accuracy: 79.6

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
基于预训练模型的互信息正则化域泛化 | 论文 | HyperAI超神经