
摘要
长尾数据集的特点是头部类别包含远多于尾部类别的训练样本,这导致识别模型倾向于偏向头部类别,产生偏差。加权损失(Weighted loss)是缓解该问题最常用的方法之一。近期研究提出,相较于传统上依赖类别频率,类别难度(class-difficulty)可能是更优的权重分配依据。然而,先前工作采用启发式方法来量化难度,而我们通过实证发现,最优的难度量化形式会因数据集特性不同而异。为此,本文提出 Difficulty-Net,该方法在元学习(meta-learning)框架下,自动学习预测各类别的难度。为使模型在与其他类别相互关联的上下文中合理估计类别难度,我们引入两个关键概念:相对难度(relative difficulty)与驱动损失(driver loss)。其中,相对难度使 Difficulty-Net 在计算某一类别难度时能够综合考虑其他类别的影响;而驱动损失则对学习过程起到关键引导作用,确保模型朝有意义的方向优化。在多个主流长尾数据集上的大量实验表明,所提出方法具有显著有效性,并在多个数据集上取得了当前最优(state-of-the-art)性能。
代码仓库
hitachi-rd-cv/Difficulty_Net
官方
pytorch
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| long-tail-learning-on-cifar-100-lt-r-10 | Difficulty-Net | Error Rate: 34.78 |
| long-tail-learning-on-cifar-100-lt-r-100 | Difficulty-Net | Error Rate: 47.04 |
| long-tail-learning-on-cifar-100-lt-r-50 | Difficulty-Net | Error Rate: 43.1 |
| long-tail-learning-on-imagenet-lt | Difficulty-Net (ResNet-50 w/o using RandAugment, single model) | Top-1 Accuracy: 54.0 |
| long-tail-learning-on-imagenet-lt | Difficulty-Net (ResNet-10 w/o using RandAugment, single model | Top-1 Accuracy: 44.6 |
| long-tail-learning-on-imagenet-lt | Difficulty-Net (ResNet-50 using RandAugment, single model) | Top-1 Accuracy: 57.4 |
| long-tail-learning-on-places-lt | Difficulty-Net (ResNet-152) | Top-1 Accuracy: 41.7 |