
摘要
在独立同分布(i.i.d.)测试集上,过参数化的神经网络通常能够实现较高的平均准确率,但在数据中的非典型群体上却持续表现不佳(例如,由于学习了在整体数据上成立但不适用于特定群体的虚假相关性)。分布鲁棒优化(Distributionally Robust Optimization, DRO)提供了一种方法,使模型能够最小化预定义群体集合中的最坏情况训练损失。然而,我们发现,将传统的组别DRO(group DRO)直接应用于过参数化神经网络会失效:这类模型能够完美拟合训练数据,且只要平均训练损失趋近于零,最坏情况下的训练损失也随之趋近于零。事实上,最差群体性能的下降根源在于某些群体上的泛化能力不足。通过将组别DRO模型与更强的正则化策略相结合——例如采用强于常规的L2正则化或提前停止(early stopping)——我们显著提升了最差群体的准确率,在自然语言推理任务以及两个图像识别任务上,最差群体准确率提升了10至40个百分点,同时仍保持较高的平均准确率。我们的研究结果表明,在过参数化情形下,正则化对于最差群体的泛化至关重要,即使其对平均泛化并非必需。最后,我们提出了一种具有收敛性保证的随机优化算法,可高效训练组别DRO模型。
代码仓库
ssagawa/overparam_spur_corr
pytorch
GitHub 中提及
ys-zong/medfair
pytorch
facebookresearch/DomainBed
pytorch
GitHub 中提及
haoxiang-wang/isr
pytorch
GitHub 中提及
orparask/VS-Loss
pytorch
GitHub 中提及
kohpangwei/group_DRO
官方
pytorch
GitHub 中提及
yangarbiter/dp-dg
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| domain-generalization-on-nico-animal | DRO (Resnet-18) | Accuracy: 77.61 |
| domain-generalization-on-nico-vehicle | DRO (Resnet-18) | Accuracy: 77.61 |
| domain-generalization-on-pacs-2 | GroupDRO (Resnet-50, DomainBed) | Average Accuracy: 84.4 |