Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss

Kaidi Cao, Colin Wei, Adrien Gaidon, Nikos Arechiga, Tengyu Ma

Abstract

Deep learning algorithms can fare poorly when the training dataset suffers from heavy class-imbalance but the testing criterion requires good generalization on less frequent classes. We design two novel methods to improve performance in such scenarios. First, we propose a theoretically-principled label-distribution-aware margin (LDAM) loss motivated by minimizing a margin-based generalization bound. This loss replaces the standard cross-entropy objective during training and can be applied with prior strategies for training with class-imbalance such as re-weighting or re-sampling. Second, we propose a simple, yet effective, training schedule that defers re-weighting until after the initial stage, allowing the model to learn an initial representation while avoiding some of the complications associated with re-weighting or re-sampling. We test our methods on several benchmark vision tasks including the real-world imbalanced dataset iNaturalist 2018. Our experiments show that either of these methods alone can already improve over existing techniques and their combination achieves even better performance gains.
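The abstract describes two components: the LDAM loss, which subtracts a per-class margin proportional to n_j^(-1/4) from the true-class logit before the cross-entropy, and deferred re-weighting (DRW), which switches from uniform to class-balanced weights only after an initial training stage. Below is a minimal NumPy sketch of both ideas; the function names are illustrative, and the defaults (max_m = 0.5, scale s = 30, defer epoch 160, beta = 0.9999 for the class-balanced weights of Cui et al.) are assumptions taken from the authors' CIFAR setup, not part of this page.

```python
import numpy as np

def ldam_loss(logits, targets, cls_num_list, max_m=0.5, s=30.0):
    """LDAM loss sketch: class-j margin is proportional to n_j^(-1/4),
    rescaled so the largest (rarest-class) margin equals max_m."""
    m_list = 1.0 / np.power(np.asarray(cls_num_list, dtype=float), 0.25)
    m_list = m_list * (max_m / m_list.max())
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets)
    n = logits.shape[0]
    # Subtract the class-dependent margin from the true-class logit only.
    adjusted = logits.copy()
    adjusted[np.arange(n), targets] -= m_list[targets]
    # Scaled, numerically stable cross-entropy on the adjusted logits.
    z = s * adjusted
    z = z - z.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(n), targets].mean()

def drw_weights(cls_num_list, epoch, defer_epoch=160, beta=0.9999):
    """DRW schedule sketch: uniform weights during the initial stage,
    then class-balanced weights 1 / (1 - beta^n_j), normalized to mean 1."""
    cls_num = np.asarray(cls_num_list, dtype=float)
    if epoch < defer_epoch:
        return np.ones_like(cls_num)
    effective_num = 1.0 - np.power(beta, cls_num)
    w = (1.0 - beta) / effective_num
    return w / w.sum() * len(cls_num)
```

With an imbalanced `cls_num_list` such as `[1000, 10]`, the rare class receives the larger margin, so the LDAM loss on a rare-class example exceeds the plain scaled cross-entropy; likewise `drw_weights` returns uniform weights before the defer epoch and up-weights the rare class afterwards.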

Code Repositories

feidfoe/AdjustBnd4Imbalance (PyTorch)
j3soon/arxiv-utils
kaidic/LDAM-DRW (official implementation, PyTorch)
karurb92/ldam_str_bn (TensorFlow)
orparask/VS-Loss (PyTorch)
ihaeyong/maximum-margin-ldam (PyTorch)

Benchmarks

Benchmark | Methodology | Metrics
long-tail-learning-on-cifar-10-lt-r-10 | Empirical Risk Minimization (ERM, CE) | Error Rate: 13.61
long-tail-learning-on-cifar-10-lt-r-10 | Class-balanced Resampling | Error Rate: 13.21
long-tail-learning-on-cifar-10-lt-r-10 | LDAM-DRW | Error Rate: 11.84
long-tail-learning-on-cifar-10-lt-r-100 | LDAM-DRW | Error Rate: 22.97
long-tail-learning-on-cifar-100-lt-r-10 | LDAM-DRW | Error Rate: 41.29
long-tail-learning-on-cifar-100-lt-r-100 | LDAM-DRW | Error Rate: 57.96
long-tail-learning-on-coco-mlt | LDAM (ResNet-50) | Average mAP: 40.53
long-tail-learning-on-voc-mlt | LDAM (ResNet-50) | Average mAP: 70.73
long-tail-learning-with-class-descriptors-on | LDAM | Long-Tailed Accuracy: 64.1; Per-Class Accuracy: 50.1
long-tail-learning-with-class-descriptors-on-1 | LDAM | Long-Tailed Accuracy: 36.4; Per-Class Accuracy: 29.8
long-tail-learning-with-class-descriptors-on-2 | LDAM | Long-Tailed Accuracy: 93.5; Per-Class Accuracy: 69.1
