
摘要
本文提出了一种名为基于支持样本预测视图分配(Predicting View Assignments with Support Samples, PAWS)的新颖学习方法。该方法通过训练模型最小化一致性损失(consistency loss),确保同一未标记样本的不同视图被分配相似的伪标签。伪标签采用非参数化方式生成,即通过将图像视图的表示与一组随机采样的已标记图像表示进行比较来获得。视图表示与已标记表示之间的距离被用于对类别标签进行加权,我们将其解释为一种软伪标签。通过这种非参数化方式引入已标记样本,PAWS 将自监督学习方法(如 BYOL 和 SwAV)中使用的距离度量损失扩展至半监督学习场景。尽管该方法设计简洁,但在多种网络架构上均优于现有的半监督学习方法,在仅使用 ImageNet 数据集 10% 或 1% 标注数据的情况下,分别在 ResNet-50 上取得了 75.5% 和 66.5% 的 Top-1 准确率,刷新了该任务的最新性能纪录。此外,PAWS 所需的训练时间仅为此前最优方法的 1/4 到 1/12。
代码仓库
facebookresearch/msn
pytorch
GitHub 中提及
sayakpaul/PAWS-TF
tf
GitHub 中提及
beresandras/semisupervised-classification-keras
tf
GitHub 中提及
facebookresearch/suncet
官方
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| image-classification-on-imagenet | PAWS (ResNet-50, 10% labels) | Top 1 Accuracy: 75.5% |
| image-classification-on-imagenet | PAWS (ResNet-50, 1% labels) | Top 1 Accuracy: 66.5% |
| semi-supervised-image-classification-on-1 | PAWS (ResNet-50) | Top 1 Accuracy: 66.5% |
| semi-supervised-image-classification-on-1 | PAWS (ResNet-50 2x) | Top 1 Accuracy: 69.6% |
| semi-supervised-image-classification-on-1 | PAWS (ResNet-50 4x) | Top 1 Accuracy: 69.9% |
| semi-supervised-image-classification-on-2 | PAWS (ResNet-50) | Top 1 Accuracy: 75.5% |
| semi-supervised-image-classification-on-2 | PAWS (ResNet-50 2x) | Top 1 Accuracy: 77.8% |
| semi-supervised-image-classification-on-2 | PAWS (ResNet-50 4x) | Top 1 Accuracy: 79.0% |
| semi-supervised-image-classification-on-cifar | PAWS-NN (WRN-28-2) | Percentage error: 4.0 ± 0.25 |