3 个月前

并非所有标签都同等重要:通过标签分组与协同训练增强半监督学习

并非所有标签都同等重要:通过标签分组与协同训练增强半监督学习

摘要

伪标签(Pseudo-labeling)是半监督学习(Semi-Supervised Learning, SSL)中的关键组件。其核心思想是通过迭代地利用模型为未标记数据生成人工标签,并以此进行训练。现有各类伪标签方法的一个共同特性是:它们仅依据模型的预测结果进行标签决策,而未考虑类别之间的视觉相似性先验知识。本文指出,这一做法会降低伪标签的质量,因为其在伪标签数据池中难以准确反映视觉上相似的类别之间的关系。为此,我们提出了一种名为 SemCo 的新方法,该方法结合了标签语义信息与协同训练(co-training)机制,以缓解上述问题。SemCo 通过两个具有不同类别标签视角的分类器进行联合训练:其中一个分类器采用独热编码(one-hot)视角的标签,忽略类别间的潜在相似性;另一个分类器则采用分布式(distributed)标签视角,将可能相似的类别进行分组。随后,我们通过协同训练机制,促使两个分类器基于彼此的预测分歧进行学习。实验结果表明,所提方法在多种半监督学习任务中均取得了当前最优性能。例如,在 Mini-ImageNet 数据集上,仅使用 1000 个标注样本时,准确率提升了 5.6%。此外,我们的方法在达到最优性能时所需的批量大小(batch size)更小,训练迭代次数也更少,显著提升了训练效率。相关代码已开源,地址为:https://github.com/islam-nassar/semco。

代码仓库

islam-nassar/semco
官方
pytorch
GitHub 中提及

基准测试

基准方法指标
semi-supervised-image-classification-on-cifarSemCo (μ=7)
Percentage error: 3.8±0.08
semi-supervised-image-classification-on-cifar-2SemCo (μ=7)
Percentage error: 24.45±0.12
semi-supervised-image-classification-on-miniSemCo (μ=3)
Accuracy: 53.99±0.93
semi-supervised-image-classification-on-miniSemCo (μ=7)
Accuracy: 50.54±2.20
semi-supervised-image-classification-on-mini-1SemCo (μ=7)
Accuracy: 57.22±0.35
semi-supervised-image-classification-on-mini-1SemCo (μ=3)
Accuracy: 58.75±0.76
semi-supervised-image-classification-on-mini-2SemCo (μ=7)
Accuracy: 40.65±0.23
semi-supervised-image-classification-on-mini-2SemCo (μ=3)
Accuracy: 44.65±0.71

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
并非所有标签都同等重要:通过标签分组与协同训练增强半监督学习 | 论文 | HyperAI超神经