
摘要
少样本学习的目标是在每个类别仅有有限数量的训练实例时,仍能训练出具有良好泛化能力的分类器。最近引入的元学习方法通过在大量多类分类任务中学习一个通用分类器,并将模型推广到新任务来解决这一问题。然而,即使采用了这样的元学习方法,新颖分类任务中的低数据量问题仍然存在。本文提出了一种新的元学习框架——传递传播网络(Transductive Propagation Network, TPN),用于传递推理,旨在一次性对整个测试集进行分类以缓解低数据量问题。具体而言,我们提出通过学习一个利用数据流形结构的图构建模块,从有标签的实例向无标签的测试实例传播标签。TPN 以端到端的方式联合学习特征嵌入参数和图构建。我们在多个基准数据集上验证了 TPN 的性能,结果表明它显著优于现有的少样本学习方法,并达到了最先进的水平。
代码仓库
csyanbin/TPN-pytorch
官方
pytorch
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| few-shot-image-classification-on-mini-12 | TPN (Higher Shot) | Accuracy: 38.4 |
| few-shot-image-classification-on-mini-12 | Label Propagation | Accuracy: 35.2 |
| few-shot-image-classification-on-mini-13 | TPN (Higher Shot) | Accuracy: 52.8 |
| few-shot-image-classification-on-mini-13 | Label Propagation | Accuracy: 51.2 |
| few-shot-image-classification-on-tiered-2 | TPN (Higher Shot) | Accuracy: 44.8 |
| few-shot-image-classification-on-tiered-2 | Label Propagation | Accuracy: 39.4 |
| few-shot-image-classification-on-tiered-3 | TPN (Higher Shot) | Accuracy: 59.4 |
| few-shot-image-classification-on-tiered-3 | Label Propagation | Accuracy: 57.9 |