3 个月前

基于视觉Transformer的学习不平衡数据

基于视觉Transformer的学习不平衡数据

摘要

真实世界的数据通常存在严重的类别不平衡问题,这会显著扭曲数据驱动的深度神经网络性能,使得长尾识别(Long-Tailed Recognition, LTR)成为一个极具挑战性的任务。现有LTR方法极少在长尾数据上从头训练视觉Transformer(Vision Transformers, ViTs),而直接使用现成的预训练权重往往导致不公平的比较。本文系统地研究了ViTs在LTR任务中的表现,并提出LiVT,一种仅使用长尾数据从零开始训练ViTs的新方法。基于观察发现,ViTs在长尾识别任务中面临更为严峻的挑战,我们引入掩码生成式预训练(Masked Generative Pretraining, MGP),以学习更具泛化能力的特征表示。通过充分且坚实的实验证据,我们证明MGP在鲁棒性方面显著优于传统的监督学习方式。此外,尽管二元交叉熵(Binary Cross Entropy, BCE)损失在ViTs上表现出色,但在长尾场景下仍面临性能瓶颈。为此,我们进一步提出平衡型BCE(Bal-BCE),其具有坚实的理论基础。具体而言,我们推导出Sigmoid函数的无偏扩展形式,并引入额外的logit边际补偿机制以实现更优的优化。所提出的Bal-BCE显著加速了ViTs在少数类上的收敛,仅需数个训练周期即可实现稳定性能。大量实验表明,结合MGP与Bal-BCE后,LiVT能够在不依赖任何额外数据的前提下,成功训练出高性能的ViTs,且显著超越现有最先进方法。例如,在iNaturalist 2018数据集上,我们的ViT-B模型在无任何复杂技巧(bells and whistles)的情况下,达到了81.0%的Top-1准确率。代码已开源,地址为:https://github.com/XuZhengzhuo/LiVT。

代码仓库

xuzhengzhuo/livt
官方
pytorch
GitHub 中提及

基准测试

基准方法指标
long-tail-learning-on-cifar-10-lt-r-10ViT-B + Bal-BCE
Error Rate: 8.7
long-tail-learning-on-cifar-10-lt-r-10ViT-B + CB
Error Rate: 10.1
long-tail-learning-on-cifar-10-lt-r-10ViT-B + CE
Error Rate: 10.5
long-tail-learning-on-cifar-10-lt-r-10ViT-B + Bal-CE
Error Rate: 9.3
long-tail-learning-on-cifar-10-lt-r-10ViT-B + LDAM
Error Rate: 11.4

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
基于视觉Transformer的学习不平衡数据 | 论文 | HyperAI超神经