David KruegerEthan CaballeroJoern-Henrik JacobsenAmy ZhangJonathan BinasDinghuai ZhangRemi Le PriolAaron Courville

摘要
分布偏移(distributional shift)是将机器学习预测系统从实验室环境迁移至现实世界时面临的主要障碍之一。为应对这一挑战,我们假设训练域之间的变化能够代表测试阶段可能遇到的变化,但同时认为测试阶段的分布偏移在幅度上可能更为极端。特别地,我们证明:通过减小训练域间风险的差异,可以降低模型对多种极端分布偏移的敏感性,包括输入中同时包含因果与反因果要素这一极具挑战性的场景。我们提出一种名为风险外推(Risk Extrapolation, REx)的方法,将其视为对扩展域扰动集合(MM-REx)进行鲁棒优化的一种形式,并进一步提出一种基于训练风险方差惩罚的简化变体(V-REx)。理论分析表明,REx的各类变体不仅能够恢复目标变量的因果机制,同时还能在输入分布发生变化(即“协变量偏移”)的情况下提供一定的鲁棒性。通过合理权衡对因果性导致的分布偏移与协变量偏移的鲁棒性,REx在两类偏移共存的情形下,能够超越如不变风险最小化(Invariant Risk Minimization, IRM)等现有方法的性能表现。
代码仓库
capybaralet/REx_code_release
官方
pytorch
thuml/Transfer-Learning-Library
pytorch
GitHub 中提及
facebookresearch/DomainBed
pytorch
GitHub 中提及
lingxiaoyuan/ood_mechanics
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| domain-generalization-on-pacs-2 | VREx (Alexnet) | Average Accuracy: 71.14 |
| image-classification-on-colored-mnist-with | MLP-REx | Accuracy : 68.70 |