
摘要
最近的半监督深度学习(深度SSL)方法大多采用了类似的范式:利用网络预测更新伪标签,并利用伪标签迭代更新网络参数。然而,这些方法缺乏理论支持,无法解释为什么预测结果是伪标签的良好候选者。在本文中,我们提出了一种原则性的端到端框架,称为深度解密(Deep Decipher,D2),用于半监督学习(SSL)。在D2框架内,我们证明了伪标签与网络预测之间存在指数链接函数关系,这为使用预测结果作为伪标签提供了理论依据。此外,我们还展示了通过网络预测更新伪标签会使它们变得不确定。为了解决这一问题,我们提出了一种称为重复再预测(Repetitive Reprediction,R2)的训练策略。最后,所提出的R2-D2方法在大规模ImageNet数据集上进行了测试,并比现有最先进方法提高了5个百分点。
代码仓库
DoctorKey/R2D2.pytorch
官方
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| semi-supervised-image-classification-on-2 | R2-D2 (ResNet-18) | Top 5 Accuracy: 90.48% |
| semi-supervised-image-classification-on-cifar | R2-D2 (Shake-Shake) | Percentage error: 5.72 |
| semi-supervised-image-classification-on-cifar-2 | R2-D2 (CNN-13) | Percentage error: 32.87 |
| semi-supervised-image-classification-on-svhn | R2-D2 (CNN-13) | Accuracy: 96.36 |