
Abstract
State space models (SSMs) have demonstrated state-of-the-art sequence modeling performance in some modalities, but still underperform attention in language modeling. Moreover, despite scaling nearly linearly in sequence length instead of quadratically, SSMs remain slower than Transformers in practice because of poor hardware utilization. This paper makes progress on two fronts: understanding the expressivity gap between SSMs and attention in language modeling, and reducing the hardware barrier between SSMs and attention. First, we use synthetic language modeling tasks to probe the gap between SSMs and attention. We find that existing SSMs struggle with two capabilities: recalling earlier tokens in the sequence and comparing tokens across the sequence. To address this, we propose a new SSM layer, H3, that is explicitly designed for these abilities. H3 matches attention on the synthetic languages and comes within 0.4 PPL of Transformers on OpenWebText. Furthermore, a hybrid 125M-parameter H3-attention model that retains only two attention layers surprisingly outperforms Transformers on OpenWebText by a further 1.0 PPL. Second, to improve the efficiency of training SSMs on modern hardware, we propose FlashConv. FlashConv uses a fused block FFT algorithm to improve efficiency on sequences up to 8K, and introduces a novel state-passing algorithm that exploits the recurrent properties of SSMs to scale to longer sequences. FlashConv yields a 2x speedup on the Long-Range Arena benchmark and allows hybrid language models to generate text 2.4x faster than Transformers. Using FlashConv, we scale hybrid H3-attention language models up to 2.7B parameters trained on the Pile and find promising initial results: the models attain lower perplexity than Transformers and outperform Transformers in zero- and few-shot learning on a majority of tasks in the SuperGLUE benchmark.
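The efficiency claims in the abstract rest on a property of SSMs that is worth making concrete: a (diagonal) SSM can be evaluated either as a recurrence over a hidden state (the view FlashConv's state-passing algorithm exploits for long sequences) or as a single long convolution whose kernel is derived from the state matrices, computable in O(N log N) with FFTs (the view the fused block-FFT kernel accelerates). The NumPy sketch below illustrates this equivalence on a toy single-channel SSM; it is not the authors' FlashConv kernel or the H3 layer itself, and the names `A_diag`, `B`, `C`, `ssm_kernel`, etc. are illustrative only.

```python
# Minimal sketch (assumption: a toy 1-channel diagonal SSM, not the paper's implementation)
# of the two equivalent views of an SSM:
#   1. recurrence:   x_t = A x_{t-1} + B u_t,  y_t = C x_t
#   2. convolution:  y = K * u  with kernel K_t = C A^t B, computed via FFT in O(N log N)
import numpy as np

def ssm_kernel(A_diag, B, C, seq_len):
    """Materialize the convolution kernel K_t = C @ diag(A)^t @ B for t = 0..seq_len-1."""
    powers = A_diag[None, :] ** np.arange(seq_len)[:, None]   # (seq_len, state_dim)
    return (powers * (B * C)[None, :]).sum(axis=-1)           # (seq_len,)

def ssm_recurrent(A_diag, B, C, u, x0=None):
    """Sequential (recurrent) evaluation; returns outputs and the final hidden state."""
    x = np.zeros_like(A_diag) if x0 is None else x0
    ys = []
    for u_t in u:
        x = A_diag * x + B * u_t        # state update
        ys.append((C * x).sum())        # readout
    return np.array(ys), x

def ssm_fft_conv(A_diag, B, C, u):
    """Causal convolution y = K * u computed with FFTs."""
    n = len(u)
    K = ssm_kernel(A_diag, B, C, n)
    L = 2 * n                           # zero-pad to avoid circular wrap-around
    y = np.fft.irfft(np.fft.rfft(K, L) * np.fft.rfft(u, L), L)[:n]
    return y

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, n = 4, 512
    A_diag = rng.uniform(0.1, 0.9, d)   # stable diagonal state matrix
    B, C = rng.normal(size=d), rng.normal(size=d)
    u = rng.normal(size=n)
    y_rec, _ = ssm_recurrent(A_diag, B, C, u)
    y_fft = ssm_fft_conv(A_diag, B, C, u)
    print(np.allclose(y_rec, y_fft, atol=1e-6))   # the two views agree
```

The final state returned by `ssm_recurrent` is what a state-passing scheme would carry between sequence chunks, which is how the recurrent view extends FFT-based processing beyond a single chunk length.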
Code Repositories
Benchmarks
| Benchmark | Method | Metrics |
|---|---|---|
| coreference-resolution-on-winograd-schema | Hybrid H3 125M (3-shot, logit scoring) | Accuracy: 43.3 |
| coreference-resolution-on-winograd-schema | H3 125M (0-shot, rank classification) | Accuracy: 61.5 |
| coreference-resolution-on-winograd-schema | H3 125M (3-shot, rank classification) | Accuracy: 63.5 |
| language-modelling-on-the-pile | Transformer 125M | Test perplexity: 10.7 |
| language-modelling-on-the-pile | Hybrid H3 125M | Test perplexity: 10.2 |
| language-modelling-on-wikitext-103 | Hybrid H3 (355M) | Number of params: 355M, Test perplexity: 16.9 |
| language-modelling-on-wikitext-103 | Hybrid H3 125M | Test perplexity: 18.5 |
| language-modelling-on-wikitext-103 | Hybrid H3 (1.3B) | Number of params: 1300M, Test perplexity: 12.5 |
| language-modelling-on-wikitext-103 | Hybrid H3 (2.7B) | Number of params: 2700M, Test perplexity: 10.6 |
| language-modelling-on-wikitext-103 | Hybrid H3 (125M) | Number of params: 125M, Test perplexity: 23.7 |
| natural-language-inference-on-rte | H3 125M (0-shot, rank classification) | Accuracy: 53.1% |
| natural-language-inference-on-rte | Hybrid H3 125M (3-shot, rank classification) | Accuracy: 58.1% |
| natural-language-inference-on-rte | Hybrid H3 125M (3-shot, logit scoring) | Accuracy: 58.1% |
| natural-language-inference-on-rte | H3 125M (3-shot, rank classification) | Accuracy: 52.3% |
| natural-language-inference-on-rte | Hybrid H3 125M (0-shot, logit scoring) | Accuracy: 59.2% |
| question-answering-on-boolq | Hybrid H3 125M (0-shot, logit scoring) | Accuracy: 59.6 |
| question-answering-on-boolq | Hybrid H3 2.7B (3-shot, logit scoring) | Accuracy: 60.6 |
| question-answering-on-boolq | Hybrid H3 1.3B (0-shot, logit scoring) | Accuracy: 61.7 |
| question-answering-on-boolq | Hybrid H3 125M (3-shot, logit scoring) | Accuracy: 56.1 |
| question-answering-on-boolq | Hybrid H3 125M (3-shot, rank classification) | Accuracy: 56.1 |
| question-answering-on-copa | Hybrid H3 125M (0-shot, rank classification) | Accuracy: 67 |
| question-answering-on-copa | H3 125M (0-shot, rank classification) | Accuracy: 51 |
| question-answering-on-copa | Hybrid H3 125M (0-shot, logit scoring) | Accuracy: 67 |
| question-answering-on-copa | Hybrid H3 2.7B (3-shot, logit scoring) | Accuracy: 77 |
| question-answering-on-copa | Hybrid H3 2.7B (0-shot, logit scoring) | Accuracy: 81 |
| question-answering-on-multirc | Hybrid H3 125M (3-shot, logit scoring) | EM: 48.9 |
| question-answering-on-multirc | Hybrid H3 355M (0-shot, logit scoring) | EM: 59.5 |
| question-answering-on-multirc | Hybrid H3 355M (3-shot, logit scoring) | EM: 59.7 |
| question-answering-on-multirc | Hybrid H3 125M (0-shot, logit scoring) | EM: 51.4 |
| word-sense-disambiguation-on-words-in-context | Hybrid H3 125M (0-shot, rank classification) | Accuracy: 51.4 |
| word-sense-disambiguation-on-words-in-context | Hybrid H3 125M (3-shot, logit scoring) | Accuracy: 49.1 |
| word-sense-disambiguation-on-words-in-context | Hybrid H3 125M (0-shot, logit scoring) | Accuracy: 51.4 |