BatchFormer: Learning to Explore Sample Relationships for Robust Representation Learning

Zhi Hou, Baosheng Yu, Dacheng Tao

Abstract

Despite the success of deep neural networks, there are still many challenges in deep representation learning due to data scarcity issues such as data imbalance, unseen distributions, and domain shift. To address these issues, a variety of methods have been devised to explore sample relationships in a vanilla way (i.e., from the perspective of either the input or the loss function), but they fail to exploit the internal structure of deep neural networks for learning with sample relationships. Motivated by this, we propose to equip deep neural networks themselves with the ability to learn sample relationships from each mini-batch. Specifically, we introduce a batch transformer module, BatchFormer, which is applied to the batch dimension of each mini-batch to implicitly explore sample relationships during training. In this way, the proposed method enables collaboration among different samples; for example, head-class samples can also contribute to the learning of tail classes in long-tailed recognition. Furthermore, to mitigate the gap between training and testing, we share the classifier between features computed with and without BatchFormer during training, so that the module can be removed during testing. We perform extensive experiments on over ten datasets, and the proposed method achieves significant improvements on different data scarcity applications without any bells and whistles, including long-tailed recognition, compositional zero-shot learning, domain generalization, and contrastive learning. Code will be made publicly available at https://github.com/zhihou7/BatchFormer.
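The core idea can be sketched in a few lines of PyTorch: a standard transformer encoder layer is applied along the batch dimension, so samples in a mini-batch attend to each other, and during training the original and transformed features are stacked so that a single shared classifier sees both streams. This is a minimal illustrative sketch, not the official implementation; the class name and hyperparameters below are assumptions for demonstration.

```python
# Illustrative BatchFormer-style sketch (assumed names/hyperparameters,
# not the authors' official code).
import torch
import torch.nn as nn

class BatchFormer(nn.Module):
    def __init__(self, dim: int, nhead: int = 4):
        super().__init__()
        # With batch_first=False, the first input dimension is the sequence
        # dimension, so feeding (B, 1, D) makes attention run across the batch.
        self.encoder = nn.TransformerEncoderLayer(
            d_model=dim, nhead=nhead, dim_feedforward=dim, batch_first=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # BatchFormer is removed at test time: identity mapping.
            return x
        # (B, D) -> (B, 1, D): the batch axis becomes the attention sequence.
        y = self.encoder(x.unsqueeze(1)).squeeze(1)
        # Shared-classifier trick: return both original and transformed
        # features so one classifier is trained on both streams.
        return torch.cat([x, y], dim=0)

bf = BatchFormer(dim=64)
bf.train()
feats = torch.randn(8, 64)   # a mini-batch of 8 feature vectors
out = bf(feats)              # shape (16, 64): original + transformed features
bf.eval()
same = bf(feats)             # shape (8, 64): module is a no-op at test time
```

Because the module only mixes information across samples during training, inference cost and the deployed architecture are unchanged, which is why it can be dropped into existing pipelines.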

Code Repositories

zhihou7/batchformer (official, PyTorch; mentioned in GitHub)

Benchmarks

Benchmark | Methodology | Metric
domain-generalization-on-pacs-2 | BatchFormer (ResNet-50, SWAD) | Average Accuracy: 88.6
long-tail-learning-on-cifar-100-lt-r-100 | PaCo + BatchFormer | Error Rate: 47.6
long-tail-learning-on-cifar-100-lt-r-100 | Balanced + BatchFormer | Error Rate: 48.3
long-tail-learning-on-imagenet-lt | BatchFormer (ResNet-50, RIDE) | Top-1 Accuracy: 55.7
long-tail-learning-on-imagenet-lt | BatchFormer (ResNet-50, PaCo) | Top-1 Accuracy: 57.4
long-tail-learning-on-inaturalist-2018 | BatchFormer (ResNet-50, RIDE) | Top-1 Accuracy: 74.1%
