
摘要
大规模Transformer模型在多项任务上均能取得当前最优性能,但其训练过程往往成本高昂,尤其是在处理长序列时尤为显著。本文提出两种技术以提升Transformer模型的效率。首先,我们用基于局部敏感哈希(locality-sensitive hashing)的注意力机制替代传统的点积注意力,将计算复杂度从O($L^2$)降低至O($L\log L$),其中$L$表示序列长度。其次,我们采用可逆残差层(reversible residual layers)替代标准残差结构,使得在训练过程中只需存储一次激活值,而非传统方法中的$N$次($N$为网络层数)。由此构建的模型——Reformer,在性能上与传统Transformer模型相当,同时在处理长序列时展现出显著更低的内存占用和更快的运行速度。
代码仓库
lucashueda/long_sentence_transformer
pytorch
GitHub 中提及
lucidrains/reformer-pytorch
pytorch
huggingface/transformers
pytorch
GitHub 中提及
sliao-mi-luku/NLP-Chatbot-Reformer-Trax
pytorch
GitHub 中提及
Rick-McCoy/Reformer-pytorch
pytorch
google/trax/tree/master/trax/models/reformer
官方
jax
GitHub 中提及
sliao-mi-luku/Chatbot-Reformer
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| d4rl-on-d4rl | Reformer | Average Reward: 63.9 |
| image-generation-on-imagenet-64x64 | Reformer (6 layers) | Bits per dim: 3.740 |
| image-generation-on-imagenet-64x64 | Reformer (12 layers) | Bits per dim: 3.710 |
| language-modelling-on-wikitext-103 | Reformer 125M | Test perplexity: 26.0 |
| open-domain-question-answering-on-searchqa | Locality-Sensitive Hashing | EM: 66.0 |
| question-answering-on-natural-questions-long | Locality-Sensitive Hashing | F1: 75.5 |
| question-answering-on-quasart-t | Locality-Sensitive Hashing | EM: 53.2 |