
摘要
近期在语言模型领域的研究表明,训练大型Transformer模型可以推动自然语言处理应用的最先进水平。然而,由于内存限制,非常大的模型往往难以训练。在这项工作中,我们介绍了训练非常大型Transformer模型的技术,并实现了一种简单高效的层内模型并行方法,该方法使得训练包含数十亿参数的Transformer模型成为可能。我们的方法不需要新的编译器或库更改,与流水线模型并行方法正交且互补,并且可以通过在原生PyTorch中插入几个通信操作来完全实现。我们通过使用512个GPU收敛高达83亿参数的基于Transformer的模型来说明这一方法。与一个能够维持39万亿次浮点运算(相当于峰值浮点运算能力的30%)的强大单GPU基线相比,我们的方法在整个应用程序中实现了15.1千万亿次浮点运算,具有76%的扩展效率。为了证明大型语言模型可以进一步提升最先进水平(SOTA),我们训练了一个类似于GPT-2的83亿参数Transformer语言模型以及一个类似于BERT的39亿参数模型。我们展示了在类似BERT的模型中,随着模型规模的增长,对层归一化位置的仔细考虑对于提高性能至关重要。使用GPT-2模型,我们在WikiText103数据集上取得了最佳结果(困惑度为10.8,优于当前SOTA困惑度15.8),并在LAMBADA数据集上达到了66.5%的准确率(优于当前SOTA准确率63.2%)。我们的BERT模型在RACE数据集上也取得了最佳结果(准确率为90.9%,优于当前SOTA准确率89.4%)。
代码仓库
ezelikman/STaR
jax
GitHub 中提及
NVIDIA/Megatron-LM
官方
pytorch
GitHub 中提及
nvidia/transformerengine
jax
GitHub 中提及
qhduan/CPM-LM-TF2
pytorch
GitHub 中提及
THUDM/ProteinLM
pytorch
GitHub 中提及
facebookresearch/fairscale
pytorch
GitHub 中提及
kingoflolz/mesh-transformer-jax
jax
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| language-modelling-on-wikitext-103 | Megatron-LM | Number of params: 8300M Test perplexity: 10.81 |
| question-answering-on-piqa | MT-NLG 530B (0-shot) | Accuracy: 82.0 |
| reading-comprehension-on-race | Megatron-BERT | Accuracy: 89.5 Accuracy (High): 88.6 Accuracy (Middle): 91.8 |
| reading-comprehension-on-race | Megatron-BERT (ensemble) | Accuracy: 90.9 Accuracy (High): 90.0 Accuracy (Middle): 93.1 |