Yi TayMostafa DehghaniVamsi AribandiJai GuptaPhilip PhamZhen QinDara BahriDa-Cheng JuanDonald Metzler

摘要
本文提出了一种基于Transformer的全向表征模型——OmniNet。在OmniNet中,每个Token不再局限于传统的水平感受野,而是被允许关注网络中所有其他Token,从而实现对整个网络宽度与深度的全局信息交互。这一机制可被理解为一种极端或高强度的注意力机制,其感受野覆盖整个网络的全部空间维度。为实现这一全向注意力,模型采用一个元学习器(meta-learner)来学习注意力权重,该元学习器本质上是一个基于自注意力机制的模型。为缓解全感受野注意力带来的高计算开销,本文引入高效的自注意力机制作为元学习器,包括基于核函数的方法(Choromanski 等)、低秩注意力(Wang 等)以及 Big Bird(Zaheer 等)等。在自回归语言建模(LM1B、C4)、机器翻译、长程依赖基准测试(Long Range Arena, LRA)以及图像识别等多个任务上进行了大量实验。结果表明,OmniNet在各项任务中均取得了显著提升,尤其在LM1B、WMT’14 En-De/En-Fr以及Long Range Arena任务上达到了当前最优(state-of-the-art)性能。此外,在视觉Transformer(Vision Transformer)中引入全向表征,显著提升了图像识别任务的表现,无论是在少样本学习(few-shot learning)还是微调(fine-tuning)设置下均展现出明显优势。
代码仓库
lucidrains/omninet-pytorch
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| language-modelling-on-one-billion-word | OmniNetT (Large) | Number of params: 100M PPL: 21.5 |
| language-modelling-on-one-billion-word | OmniNetB (Large) | PPL: 22 |
| language-modelling-on-one-billion-word | OmniNetP (Large) | Number of params: 100M PPL: 21.6 |
| machine-translation-on-wmt2014-english-french | OmniNetP | BLEU score: 42.6 Hardware Burden: Operations per network pass: |
| machine-translation-on-wmt2014-english-german | OmniNetP | BLEU score: 29.8 Hardware Burden: Operations per network pass: |
| machine-translation-on-wmt2017-chinese | OmniNetP | BLEU: 23.0 |
| machine-translation-on-wmt2017-english | OmniNetP | BLEU: 20.9 |
| machine-translation-on-wmt2017-english-french | OmniNetP | BLEU: 43.1 |
| machine-translation-on-wmt2017-english-german | OmniNetP | BLEU: 29.0 |
| machine-translation-on-wmt2017-russian | OmniNetP | BLEU: 36.2 |