
摘要
深度学习正在推动许多计算机视觉应用达到新的水平。然而,它依赖于大规模标注数据集,而如何捕捉现实世界数据的无约束特性仍然是一个未解决的问题。半监督学习(SSL)通过利用大量未标注数据来补充标注训练数据,从而降低标注成本。传统的SSL方法假设未标注数据与标注数据来自相同的分布。最近,引入了一种更为现实的SSL问题,称为开放世界SSL(Open-World SSL),在这种情况下,未标注数据可能包含来自未知类别的样本。在本文中,我们提出了一种基于伪标签的新方法来应对开放世界的SSL问题。该方法的核心在于利用样本不确定性并结合关于类别分布的先验知识,为属于已知和未知类别的未标注数据生成可靠的类别分布感知伪标签。我们的广泛实验展示了该方法在多个基准数据集上的有效性,其性能在七个不同的数据集上显著优于现有最先进方法,包括CIFAR-100(约17%)、ImageNet-100(约5%)和Tiny ImageNet(约9%)。我们还强调了该方法在解决新类别发现任务中的灵活性,展示了其在处理不平衡数据时的稳定性,并补充了一种估计新类别数量的技术。
代码仓库
nayeemrizve/trssl
官方
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| open-world-semi-supervised-learning-on-1 | TRSSL (ResNet-50) | All accuracy (10% Labeled): 75.4 Novel accuracy (10% Labeled): 67.8 Seen accuracy (10% Labeled): 82.6 |
| open-world-semi-supervised-learning-on-cifar | TRSSL (ResNet-18) | All accuracy (10% Labeled): 92.2 Novel accuracy (10% Labeled): 89.6 Seen accuracy (10% Labeled): 94.9 |
| open-world-semi-supervised-learning-on-cifar-1 | TRSSL (ResNet-18) | All accuracy (10% Labeled): 60.3 Novel accuracy (10% Labeled): 52.1 Seen accuracy (10% Labeled): 68.5 |