Prune Once for All: Sparse Pre-Trained Language Models

Abstract

Transformer-based language models are used in a wide range of natural language processing applications. However, these models are computationally inefficient and difficult to deploy. In recent years, many compression algorithms have been proposed to improve the implementation efficiency of large Transformer-based models on target hardware. In this work we present a new method for training sparse pre-trained Transformer language models by integrating weight pruning and model distillation. These sparse pre-trained models can be used for transfer learning on a wide range of tasks while maintaining their sparsity pattern. We demonstrate the method on three known architectures, creating sparse pre-trained BERT-Base, BERT-Large and DistilBERT models. We show that the compressed sparse pre-trained models transfer to five different downstream natural language processing tasks with minimal accuracy loss. Moreover, we further compress the sparse models' weights to 8-bit precision using quantization-aware training. For example, fine-tuning the sparse pre-trained BERT-Large on SQuADv1.1 and quantizing it to 8 bits yields a compression ratio of 40X for the encoder with less than 1% accuracy loss. To the best of our knowledge, these results represent the best compression-to-accuracy trade-off reported for BERT-Base, BERT-Large and DistilBERT.
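The core of the approach is to combine unstructured (magnitude) weight pruning with knowledge distillation from a dense teacher. The following is a minimal sketch of those two ingredients using PyTorch's built-in pruning utilities and a soft-target KL loss; the function names, sparsity values, and training-step outline are illustrative assumptions, not the authors' released implementation.

```python
# Minimal sketch: unstructured magnitude pruning plus a distillation loss.
# Assumes a dense `teacher` and a `student` model (e.g. Hugging Face BERT variants);
# the sparsity level and training step below are illustrative, not the paper's exact recipe.
import torch
import torch.nn.functional as F
from torch.nn.utils import prune

def prune_linear_layers(model: torch.nn.Module, sparsity: float = 0.85) -> None:
    """Zero out the smallest-magnitude weights of every Linear layer (unstructured pruning)."""
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=sparsity)

def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      temperature: float = 2.0) -> torch.Tensor:
    """KL divergence between softened teacher and student output distributions."""
    t = temperature
    return F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)

# One hypothetical training step (`student`, `teacher`, `batch`, `optimizer` assumed to exist):
# prune_linear_layers(student, sparsity=0.85)          # fix the sparsity pattern
# student_logits = student(**batch).logits
# with torch.no_grad():
#     teacher_logits = teacher(**batch).logits
# loss = distillation_loss(student_logits, teacher_logits)
# loss.backward(); optimizer.step(); optimizer.zero_grad()
```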
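The reported 40X encoder compression can be read as a back-of-envelope product of the two techniques, under the assumption that only non-zero weights are stored and ignoring sparse-index and other metadata overhead:

```python
# Back-of-envelope reading of the reported ~40X encoder compression
# (an illustrative decomposition, not a figure taken from the paper's tables).
sparsity = 0.90                      # 90% of encoder weights pruned to zero
remaining_fraction = 1.0 - sparsity  # 10% of weights kept -> ~10x fewer values stored
bits_dense, bits_quantized = 32, 8   # FP32 -> INT8 gives a 4x per-weight size reduction

compression_ratio = (1.0 / remaining_fraction) * (bits_dense / bits_quantized)
print(compression_ratio)             # 40.0
```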

Code Repository

Benchmarks

Natural Language Inference on MultiNLI (dev)

| Method | Matched | Mismatched |
|---|---|---|
| DistilBERT-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | 78.8 | 80.4 |
| BERT-Base-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | 81.4 | 82.51 |
| BERT-Base-uncased-PruneOFA (85% unstruct sparse) | 82.71 | 83.67 |
| DistilBERT-uncased-PruneOFA (90% unstruct sparse) | 80.68 | 81.47 |
| DistilBERT-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | 80.66 | 81.14 |
| BERT-Large-uncased-PruneOFA (90% unstruct sparse) | 83.74 | 84.2 |
| BERT-Base-uncased-PruneOFA (90% unstruct sparse) | 81.45 | 82.43 |
| BERT-Large-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | 83.47 | 84.08 |
| DistilBERT-uncased-PruneOFA (85% unstruct sparse) | 81.35 | 82.03 |

Question Answering on SQuAD v1.1 (dev)

| Method | EM | F1 |
|---|---|---|
| DistilBERT-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | 75.62 | 83.87 |
| BERT-Large-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | 83.22 | 90.02 |
| BERT-Base-uncased-PruneOFA (85% unstruct sparse) | 81.1 | 88.42 |
| DistilBERT-uncased-PruneOFA (85% unstruct sparse) | 78.1 | 85.82 |
| BERT-Base-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | 80.84 | 88.24 |
| BERT-Large-uncased-PruneOFA (90% unstruct sparse) | 83.35 | 90.2 |
| DistilBERT-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | 77.03 | 85.13 |
| DistilBERT-uncased-PruneOFA (90% unstruct sparse) | 76.91 | 84.82 |
| BERT-Base-uncased-PruneOFA (90% unstruct sparse) | 79.83 | 87.25 |
