
Hungry Hungry Hippos: Towards Language Modeling with State Space Models

Abstract

State space models (SSMs) have demonstrated state-of-the-art sequence modeling performance in some modalities, but still underperform attention in language modeling. Moreover, despite scaling nearly linearly in sequence length instead of quadratically, SSMs are still slower than Transformers in practice due to poor hardware utilization. This paper makes progress on two fronts: understanding the expressivity gap between SSMs and attention in language modeling, and reducing the hardware barrier between SSMs and attention. First, we use synthetic language modeling tasks to probe the gap between SSMs and attention. We find that existing SSMs struggle with two capabilities: recalling tokens that appeared earlier in the sequence, and comparing tokens across different positions in the sequence. To address this, we propose a new SSM layer, H3, explicitly designed for these two abilities. H3 matches attention on the synthetic languages and comes within 0.4 PPL of Transformers on OpenWebText. Furthermore, a hybrid 125M-parameter H3-attention model that retains only two attention layers surprisingly outperforms Transformers on OpenWebText by 1.0 PPL. Next, to improve the efficiency of training SSMs on modern hardware, we propose FlashConv. FlashConv uses a fused block FFT algorithm to improve efficiency on sequences up to 8K, and introduces a novel state-passing algorithm that exploits the recurrent properties of SSMs to scale to longer sequences. FlashConv yields a 2x speedup on the Long-Range Arena benchmark and lets hybrid language models generate text 2.4x faster than Transformers. Using FlashConv, we scale hybrid H3-attention language models up to 2.7B parameters trained on the Pile, with promising initial results: they achieve lower perplexity than Transformers and outperform Transformers in zero- and few-shot learning on a majority of tasks in the SuperGLUE benchmark.
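To make the mechanism in the abstract concrete, below is a minimal PyTorch sketch of an H3-style layer. It is not the official hazyresearch/h3 implementation and makes simplifying assumptions: the shift and diagonal SSMs are stood in for by freely learnable long-convolution kernels (in the real layer these kernels are generated from SSM parameters), heads are collapsed into elementwise gating, and the plain FFT convolution shown here is the O(L log L) primitive that FlashConv accelerates with fused block-FFT kernels.

```python
import torch

class H3LayerSketch(torch.nn.Module):
    """Simplified sketch of an H3-style SSM layer (not the official implementation).

    H3 composes two SSMs with multiplicative gating:
        output = Q * SSM_diag(SSM_shift(K) * V)
    The shift SSM lets the layer recall earlier tokens; the diagonal SSM over the
    gated product K*V lets tokens at different positions be compared. Here both
    SSMs are approximated by learnable long-convolution kernels.
    """

    def __init__(self, d_model: int, seq_len: int):
        super().__init__()
        self.q = torch.nn.Linear(d_model, d_model)
        self.k = torch.nn.Linear(d_model, d_model)
        self.v = torch.nn.Linear(d_model, d_model)
        self.out = torch.nn.Linear(d_model, d_model)
        # Stand-ins for the shift / diagonal SSM kernels (assumed shape: (D, L)).
        self.shift_kernel = torch.nn.Parameter(torch.randn(d_model, seq_len) * 0.02)
        self.diag_kernel = torch.nn.Parameter(torch.randn(d_model, seq_len) * 0.02)

    @staticmethod
    def fft_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Causal convolution via FFT: O(L log L) instead of O(L^2).

        Zero-padding to 2L turns the FFT's circular convolution into a linear
        (causal) one; FlashConv fuses and blocks exactly this computation.
        """
        L = u.shape[1]
        u_f = torch.fft.rfft(u, n=2 * L, dim=1)      # (B, L+1, D), complex
        k_f = torch.fft.rfft(k.t(), n=2 * L, dim=0)  # (L+1, D), complex
        return torch.fft.irfft(u_f * k_f, n=2 * L, dim=1)[:, :L, :]

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, L, D), L == seq_len
        q, k, v = self.q(x), self.k(x), self.v(x)
        k = self.fft_conv(k, self.shift_kernel)       # shift SSM: recall earlier tokens
        kv = self.fft_conv(k * v, self.diag_kernel)   # diagonal SSM over the gated product
        return self.out(q * kv)                       # second gate: compare across positions

if __name__ == "__main__":
    layer = H3LayerSketch(d_model=64, seq_len=128)
    y = layer(torch.randn(2, 128, 64))
    print(y.shape)  # torch.Size([2, 128, 64])
```

The two multiplicative interactions (k * v and q * kv) are what give the layer the recall and comparison abilities the abstract identifies as missing from prior SSMs.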

Code Repositories

- hazyresearch/safari (PyTorch), mentioned on GitHub
- lindermanlab/S5 (JAX), mentioned on GitHub
- hazyresearch/h3 (PyTorch, official), mentioned on GitHub

Benchmarks

| Benchmark | Method | Metric |
|---|---|---|
| coreference-resolution-on-winograd-schema | Hybrid H3 125M (3-shot, logit scoring) | Accuracy: 43.3 |
| coreference-resolution-on-winograd-schema | H3 125M (0-shot, rank classification) | Accuracy: 61.5 |
| coreference-resolution-on-winograd-schema | H3 125M (3-shot, rank classification) | Accuracy: 63.5 |
| language-modelling-on-the-pile | Transformer 125M | Test perplexity: 10.7 |
| language-modelling-on-the-pile | Hybrid H3 125M | Test perplexity: 10.2 |
| language-modelling-on-wikitext-103 | Hybrid H3 (355M) | Test perplexity: 16.9 |
| language-modelling-on-wikitext-103 | Hybrid H3 125M | Test perplexity: 18.5 |
| language-modelling-on-wikitext-103 | Hybrid H3 (1.3B) | Test perplexity: 12.5 |
| language-modelling-on-wikitext-103 | Hybrid H3 (2.7B) | Test perplexity: 10.6 |
| language-modelling-on-wikitext-103 | Hybrid H3 (125M) | Test perplexity: 23.7 |
| natural-language-inference-on-rte | H3 125M (0-shot, rank classification) | Accuracy: 53.1% |
| natural-language-inference-on-rte | Hybrid H3 125M (3-shot, rank classification) | Accuracy: 58.1% |
| natural-language-inference-on-rte | Hybrid H3 125M (3-shot, logit scoring) | Accuracy: 58.1% |
| natural-language-inference-on-rte | H3 125M (3-shot, rank classification) | Accuracy: 52.3% |
| natural-language-inference-on-rte | Hybrid H3 125M (0-shot, logit scoring) | Accuracy: 59.2% |
| question-answering-on-boolq | Hybrid H3 125M (0-shot, logit scoring) | Accuracy: 59.6 |
| question-answering-on-boolq | Hybrid H3 2.7B (3-shot, logit scoring) | Accuracy: 60.6 |
| question-answering-on-boolq | Hybrid H3 1.3B (0-shot, logit scoring) | Accuracy: 61.7 |
| question-answering-on-boolq | Hybrid H3 125M (3-shot, logit scoring) | Accuracy: 56.1 |
| question-answering-on-boolq | Hybrid H3 125M (3-shot, rank classification) | Accuracy: 56.1 |
| question-answering-on-copa | Hybrid H3 125M (0-shot, rank classification) | Accuracy: 67 |
| question-answering-on-copa | H3 125M (0-shot, rank classification) | Accuracy: 51 |
| question-answering-on-copa | Hybrid H3 125M (0-shot, logit scoring) | Accuracy: 67 |
| question-answering-on-copa | Hybrid H3 2.7B (3-shot, logit scoring) | Accuracy: 77 |
| question-answering-on-copa | Hybrid H3 2.7B (0-shot, logit scoring) | Accuracy: 81 |
| question-answering-on-multirc | Hybrid H3 125M (3-shot, logit scoring) | EM: 48.9 |
| question-answering-on-multirc | Hybrid H3 355M (0-shot, logit scoring) | EM: 59.5 |
| question-answering-on-multirc | Hybrid H3 355M (3-shot, logit scoring) | EM: 59.7 |
| question-answering-on-multirc | Hybrid H3 125M (0-shot, logit scoring) | EM: 51.4 |
| word-sense-disambiguation-on-words-in-context | Hybrid H3 125M (0-shot, rank classification) | Accuracy: 51.4 |
| word-sense-disambiguation-on-words-in-context | Hybrid H3 125M (3-shot, logit scoring) | Accuracy: 49.1 |
| word-sense-disambiguation-on-words-in-context | Hybrid H3 125M (0-shot, logit scoring) | Accuracy: 51.4 |
