Improving Generalization in Federated Learning by Seeking Flat Minima

Abstract

Models trained in federated settings often suffer from degraded performance and fail to generalize, especially when facing heterogeneous scenarios. In this work, we investigate this behavior through the lens of the geometry of the loss landscape and the eigenspectrum of the Hessian, linking the model's lack of generalization to the sharpness of the solution. Motivated by prior studies connecting the sharpness of the loss surface with the generalization gap, we show that i) training clients locally with Sharpness-Aware Minimization (SAM) or its adaptive version (ASAM), and ii) averaging stochastic weights (SWA) on the server side can substantially improve generalization in federated learning and help bridge the gap with centrally trained models. By seeking parameters in neighborhoods of uniformly low loss, the model converges towards flatter minima, and its generalization improves significantly in both homogeneous and heterogeneous scenarios. Empirical results demonstrate the effectiveness of these optimizers across a variety of benchmark vision datasets (e.g., CIFAR10/100, Landmarks-User-160k, IDDA) and tasks (large-scale classification, semantic segmentation, domain generalization).
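To make the two ingredients concrete, the sketch below illustrates i) a single SAM update on a client batch and ii) a simple server-side running average in the spirit of SWA. This is a minimal PyTorch sketch written for this page, not code from the official debcaldarola/fedsam repository; the function names, the perturbation radius `rho`, and the uniform averaging scheme are assumptions made here for illustration.

```python
import torch

def sam_local_step(model, loss_fn, x, y, base_opt, rho=0.05):
    """One Sharpness-Aware Minimization (SAM) step on a client batch (x, y).

    SAM first ascends to the worst-case weights w + e(w) within an L2 ball
    of radius rho, then applies the base optimizer using the gradient
    computed at that perturbed point.
    """
    base_opt.zero_grad()
    loss_fn(model(x), y).backward()

    # e(w) = rho * g / ||g||: move towards higher loss inside the ball.
    params = [p for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([p.grad.norm(p=2) for p in params]))
    perturbations = []
    with torch.no_grad():
        for p in params:
            e = p.grad * (rho / (grad_norm + 1e-12))
            p.add_(e)
            perturbations.append((p, e))

    # Gradient at the perturbed point, then restore the original weights.
    base_opt.zero_grad()
    loss_fn(model(x), y).backward()
    with torch.no_grad():
        for p, e in perturbations:
            p.sub_(e)
    base_opt.step()  # base-optimizer step with the sharpness-aware gradient

def swa_running_average(swa_state, global_model, n_averaged):
    """Server side: fold the latest aggregated global model into a uniform
    running average of the global models produced after each round.
    swa_state maps parameter/buffer names to averaged float tensors."""
    for k, v in global_model.state_dict().items():
        v = v.detach().float()
        if k not in swa_state:
            swa_state[k] = v.clone()
        else:
            swa_state[k] += (v - swa_state[k]) / (n_averaged + 1)
    return swa_state
```

In use, `swa_state` would start as an empty dict and be updated once per communication round, after FedAvg aggregation. The uniform running average is a simplification: SWA as originally proposed (Izmailov et al.) averages snapshots collected under a cyclical learning-rate schedule.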

Code Repositories

debcaldarola/fedsam
Official
PyTorch
Mentioned in GitHub

Benchmarks

| Benchmark | Method | Metric | Value |
|---|---|---|---|
| federated-learning-on-cifar-100-alpha-0-10 | FedAvg | ACC@1-100Clients | 36.74 |
| federated-learning-on-cifar-100-alpha-0-10 | FedSAM + SWA | ACC@1-100Clients | 39.51 |
| federated-learning-on-cifar-100-alpha-0-10 | FedASAM | ACC@1-100Clients | 39.76 |
| federated-learning-on-cifar-100-alpha-0-10 | FedASAM + SWA | ACC@1-100Clients | 42.64 |
| federated-learning-on-cifar-100-alpha-0-10 | FedSAM | ACC@1-100Clients | 36.93 |
| federated-learning-on-cifar-100-alpha-0-20 | FedASAM | ACC@1-100Clients | 40.81 |
| federated-learning-on-cifar-100-alpha-0-20 | FedAvg | ACC@1-100Clients | 38.59 |
| federated-learning-on-cifar-100-alpha-0-20 | FedSAM + SWA | ACC@1-100Clients | 39.24 |
| federated-learning-on-cifar-100-alpha-0-20 | FedSAM | ACC@1-100Clients | 38.56 |
| federated-learning-on-cifar-100-alpha-0-20 | FedASAM + SWA | ACC@1-100Clients | 41.62 |
| federated-learning-on-cifar-100-alpha-0-5 | FedSAM | ACC@1-100Clients | 31.04 |
| federated-learning-on-cifar-100-alpha-0-5 | FedAvg | ACC@1-100Clients | 30.25 |
| federated-learning-on-cifar-100-alpha-0-5 | FedASAM + SWA | ACC@1-100Clients | 42.01 |
| federated-learning-on-cifar-100-alpha-0-5 | FedASAM | ACC@1-100Clients | 36.04 |
| federated-learning-on-cifar-100-alpha-0-5 | FedSAM + SWA | ACC@1-100Clients | 39.3 |
| federated-learning-on-cifar-100-alpha-0-5-10 | FedSAM + SWA | ACC@1-100Clients | 46.76 |
| federated-learning-on-cifar-100-alpha-0-5-10 | FedAvg | ACC@1-100Clients | 41.27 |
| federated-learning-on-cifar-100-alpha-0-5-10 | FedASAM | ACC@1-100Clients | 46.58 |
| federated-learning-on-cifar-100-alpha-0-5-10 | FedASAM + SWA | ACC@1-100Clients | 48.72 |
| federated-learning-on-cifar-100-alpha-0-5-10 | FedSAM | ACC@1-100Clients | 44.84 |
| federated-learning-on-cifar-100-alpha-0-5-20 | FedASAM + SWA | ACC@1-100Clients | 48.27 |
| federated-learning-on-cifar-100-alpha-0-5-20 | FedSAM | ACC@1-100Clients | 46.05 |
| federated-learning-on-cifar-100-alpha-0-5-20 | FedAvg | ACC@1-100Clients | 42.17 |
| federated-learning-on-cifar-100-alpha-0-5-20 | FedASAM | ACC@1-100Clients | 47.78 |
| federated-learning-on-cifar-100-alpha-0-5-20 | FedSAM + SWA | ACC@1-100Clients | 46.47 |
| federated-learning-on-cifar-100-alpha-0-5-5 | FedASAM + SWA | ACC@1-100Clients | 49.17 |
| federated-learning-on-cifar-100-alpha-0-5-5 | FedASAM | ACC@1-100Clients | 45.61 |
| federated-learning-on-cifar-100-alpha-0-5-5 | FedSAM + SWA | ACC@1-100Clients | 47.96 |
| federated-learning-on-cifar-100-alpha-0-5-5 | FedSAM | ACC@1-100Clients | 44.73 |
| federated-learning-on-cifar-100-alpha-0-5-5 | FedAvg | ACC@1-100Clients | 40.43 |
| federated-learning-on-cifar-100-alpha-1000-10 | FedSAM + SWA | ACC@1-100Clients | 53.67 |
| federated-learning-on-cifar-100-alpha-1000-10 | FedSAM | ACC@1-100Clients | 53.39 |
| federated-learning-on-cifar-100-alpha-1000-10 | FedASAM + SWA | ACC@1-100Clients | 54.79 |
| federated-learning-on-cifar-100-alpha-1000-10 | FedASAM | ACC@1-100Clients | 54.97 |
| federated-learning-on-cifar-100-alpha-1000-10 | FedAvg | ACC@1-100Clients | 50.25 |
| federated-learning-on-cifar-100-alpha-1000-20 | FedASAM + SWA | ACC@1-100Clients | 54.1 |
| federated-learning-on-cifar-100-alpha-1000-20 | FedSAM + SWA | ACC@1-100Clients | 54.36 |
| federated-learning-on-cifar-100-alpha-1000-20 | FedAvg | ACC@1-100Clients | 50.66 |
| federated-learning-on-cifar-100-alpha-1000-20 | FedASAM | ACC@1-100Clients | 54.5 |
| federated-learning-on-cifar-100-alpha-1000-20 | FedSAM | ACC@1-100Clients | 53.97 |
| federated-learning-on-cifar-100-alpha-1000-5 | FedASAM + SWA | ACC@1-100Clients | 53.86 |
| federated-learning-on-cifar-100-alpha-1000-5 | FedAvg | ACC@1-100Clients | 49.92 |
| federated-learning-on-cifar-100-alpha-1000-5 | FedSAM | ACC@1-100Clients | 54.01 |
| federated-learning-on-cifar-100-alpha-1000-5 | FedSAM + SWA | ACC@1-100Clients | 53.9 |
| federated-learning-on-cifar-100-alpha-1000-5 | FedASAM | ACC@1-100Clients | 54.81 |
| federated-learning-on-cityscapes | FedAvg | mIoU | 38.65 |
| federated-learning-on-cityscapes | FedSAM | mIoU | 41.22 |
| federated-learning-on-cityscapes | SiloBN + ASAM | mIoU | 49.75 |
| federated-learning-on-cityscapes | FedASAM | mIoU | 42.27 |
| federated-learning-on-cityscapes | FedASAM + SWA | mIoU | 43.02 |
| federated-learning-on-cityscapes | FedSAM + SWA | mIoU | 43.42 |
| federated-learning-on-cityscapes | SiloBN + SAM | mIoU | 49.1 |
| federated-learning-on-cityscapes | SiloBN | mIoU | 45.96 |
| federated-learning-on-cityscapes | FedAvg + SWA | mIoU | 42.48 |
| federated-learning-on-landmarks-user-160k | FedSAM | Acc@1-1262Clients | 63.72 |
| federated-learning-on-landmarks-user-160k | FedASAM | Acc@1-1262Clients | 64.23 |
| federated-learning-on-landmarks-user-160k | FedASAM + SWA | Acc@1-1262Clients | 68.32 |
| federated-learning-on-landmarks-user-160k | FedAvg + SWA | Acc@1-1262Clients | 67.52 |
| federated-learning-on-landmarks-user-160k | FedAvg | Acc@1-1262Clients | 61.91 |
| federated-learning-on-landmarks-user-160k | FedSAM + SWA | Acc@1-1262Clients | 68.12 |
| image-classification-on-cifar-100-alpha-0-20 | FedAvgM + ASAM + SWA | ACC@1-100Clients | 51.58 |
