3 个月前

FlexMatch:通过课程伪标签提升半监督学习

FlexMatch:通过课程伪标签提升半监督学习

摘要

近期提出的FixMatch在多数半监督学习(SSL)基准测试中取得了最先进性能。然而,与其他现代SSL算法类似,FixMatch采用对所有类别均固定的常数阈值来筛选参与训练的无标签数据,未能考虑不同类别在学习过程中所处的不同状态及学习难度的差异。为解决这一问题,我们提出课程伪标签(Curriculum Pseudo Labeling, CPL),一种基于模型当前学习状态动态利用无标签数据的课程学习方法。CPL的核心思想是在每个训练时间步灵活调整不同类别的阈值,从而允许具有信息量的无标签数据及其伪标签通过,以促进模型学习。CPL无需引入额外参数,也不增加前向或反向传播的计算开销。我们将CPL应用于FixMatch,提出改进后的算法FlexMatch。FlexMatch在多种SSL基准测试中均达到最先进水平,尤其在标注数据极度稀缺或任务具有挑战性的情况下表现尤为突出。例如,在CIFAR-100和STL-10数据集上,当每类仅有4个标注样本时,FlexMatch相比FixMatch分别实现了13.96%和18.96%的错误率降低。此外,CPL显著提升了模型的收敛速度:FlexMatch仅需FixMatch约1/5的训练时间,即可达到甚至超越其性能。进一步实验表明,CPL可轻松适配至其他主流SSL算法,并显著提升其性能。相关代码已开源,地址为:https://github.com/TorchSSL/TorchSSL。

代码仓库

torchssl/torchssl
官方
pytorch
GitHub 中提及
beandkay/sequencematch
pytorch
GitHub 中提及

基准测试

基准方法指标
semi-supervised-image-classification-on-2FlexMatch
Top 1 Accuracy: 64.79%
Top 5 Accuracy: 86.04%
semi-supervised-image-classification-on-cifarFlexMatch
Percentage error: 4.19±0.01
semi-supervised-image-classification-on-cifar-2FlexMatch
Percentage error: 21.90±0.15
semi-supervised-image-classification-on-cifar-6FlexMatch
Percentage error: 4.8±0.06
semi-supervised-image-classification-on-cifar-7FlexMatch
Percentage error: 4.99±0.16
semi-supervised-image-classification-on-cifar-8FlexMatch
Percentage error: 39.94±1.62
semi-supervised-image-classification-on-cifar-9FlexMatch
Percentage error: 26.49±0.20

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
FlexMatch:通过课程伪标签提升半监督学习 | 论文 | HyperAI超神经