
Abstract
Transformer-based language models are widely used in many natural language processing applications. However, these models are computationally inefficient and difficult to deploy. In recent years, many compression algorithms have been proposed to improve the implementation efficiency of large Transformer-based models on target hardware. In this work we present a new method for training sparse pre-trained Transformer language models by combining weight pruning and model distillation. These sparse pre-trained models can be used for transfer learning across a wide range of tasks while maintaining their sparsity pattern. We demonstrate the method on three known architectures, creating sparse pre-trained BERT-Base, BERT-Large, and DistilBERT models. We show that the compressed sparse pre-trained models transfer to five different downstream natural language processing tasks with minimal accuracy loss. Furthermore, we show how to compress the sparse models' weights to 8-bit precision using quantization-aware training. For example, fine-tuning our sparse pre-trained BERT-Large on SQuADv1.1 and quantizing it to 8-bit yields a compression ratio of 40x for the encoder with less than 1% accuracy loss. To the best of our knowledge, these results represent the best compression-to-accuracy trade-off reported for BERT-Base, BERT-Large, and DistilBERT.
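The core recipe (unstructured magnitude pruning combined with a distillation loss from a dense teacher) can be illustrated with standard PyTorch utilities. The sketch below is a minimal, illustrative reduction rather than the authors' implementation: the toy two-layer "encoders", the 85% sparsity target, the temperature, and all hyper-parameters are assumptions chosen only to show how pruning masks and a knowledge-distillation loss interact in a single training step.

```python
# Minimal sketch (not the paper's implementation): unstructured magnitude
# pruning of a student's linear layers combined with a distillation loss
# from a frozen teacher, using only standard PyTorch. Models, data, and
# hyper-parameters are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Toy "teacher" and "student" networks standing in for dense / sparse BERT.
teacher = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 30522))
student = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 30522))
teacher.eval()

# Apply unstructured magnitude pruning to the student's weight matrices,
# e.g. 85% sparsity as in the sparse BERT-Base configuration.
for module in student:
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.85)

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence between softened teacher and student distributions."""
    t = temperature
    return F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)

optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

# One illustrative training step on random inputs.
x = torch.randn(8, 128)
with torch.no_grad():
    teacher_logits = teacher(x)
student_logits = student(x)
loss = distillation_loss(student_logits, teacher_logits)
loss.backward()
optimizer.step()

# The pruning re-parameterization keeps weight_orig and weight_mask, so
# masked positions receive zero gradient and the sparsity pattern of the
# effective weights is preserved through training.
sparsity = (student[0].weight == 0).float().mean().item()
print(f"layer-0 effective sparsity after one step: {sparsity:.2%}")
```

In the paper's pipeline, quantization-aware training is applied after fine-tuning to push the sparse weights to 8-bit precision; standard PyTorch quantization tooling could play that role, but it is omitted here for brevity.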
Code Repositories
- intellabs/model-compression-research-package (official; PyTorch; mentioned in GitHub); see the loading sketch after this list
- intel/intel-extension-for-transformers (PyTorch; mentioned in GitHub)
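The repositories above host the released sparse checkpoints. The sketch below shows one hedged way to load such a checkpoint with Hugging Face Transformers and verify its sparsity; the model id is an assumption, so consult the repositories for the exact published names.

```python
# Minimal sketch of inspecting a sparse pre-trained checkpoint with Hugging
# Face Transformers. The model id below is an assumption; check the
# intellabs/model-compression-research-package repository for the released
# checkpoints and their exact names.
from transformers import AutoModelForMaskedLM

model_id = "Intel/bert-base-uncased-sparse-90-unstructured-pruneofa"  # assumed id
model = AutoModelForMaskedLM.from_pretrained(model_id)

# Report the fraction of zero entries in the encoder's weight matrices to
# verify that the sparsity pattern survived serialization.
zeros, total = 0, 0
for name, param in model.named_parameters():
    if "encoder" in name and param.dim() == 2:
        zeros += (param == 0).sum().item()
        total += param.numel()
print(f"encoder weight sparsity: {zeros / total:.2%}")
```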
Benchmarks
| Benchmark | Method | Metrics |
|---|---|---|
| natural-language-inference-on-multinli-dev | DistilBERT-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | Matched: 78.8 Mismatched: 80.4 |
| natural-language-inference-on-multinli-dev | BERT-Base-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | Matched: 81.4 Mismatched: 82.51 |
| natural-language-inference-on-multinli-dev | BERT-Base-uncased-PruneOFA (85% unstruct sparse) | Matched: 82.71 Mismatched: 83.67 |
| natural-language-inference-on-multinli-dev | DistilBERT-uncased-PruneOFA (90% unstruct sparse) | Matched: 80.68 Mismatched: 81.47 |
| natural-language-inference-on-multinli-dev | DistilBERT-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | Matched: 80.66 Mismatched: 81.14 |
| natural-language-inference-on-multinli-dev | BERT-Large-uncased-PruneOFA (90% unstruct sparse) | Matched: 83.74 Mismatched: 84.2 |
| natural-language-inference-on-multinli-dev | BERT-Base-uncased-PruneOFA (90% unstruct sparse) | Matched: 81.45 Mismatched: 82.43 |
| natural-language-inference-on-multinli-dev | BERT-Large-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | Matched: 83.47 Mismatched: 84.08 |
| natural-language-inference-on-multinli-dev | DistilBERT-uncased-PruneOFA (85% unstruct sparse) | Matched: 81.35 Mismatched: 82.03 |
| question-answering-on-squad11-dev | DistilBERT-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | EM: 75.62 F1: 83.87 |
| question-answering-on-squad11-dev | BERT-Large-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | EM: 83.22 F1: 90.02 |
| question-answering-on-squad11-dev | BERT-Base-uncased-PruneOFA (85% unstruct sparse) | EM: 81.1 F1: 88.42 |
| question-answering-on-squad11-dev | DistilBERT-uncased-PruneOFA (85% unstruct sparse) | EM: 78.1 F1: 85.82 |
| question-answering-on-squad11-dev | BERT-Base-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | EM: 80.84 F1: 88.24 |
| question-answering-on-squad11-dev | BERT-Large-uncased-PruneOFA (90% unstruct sparse) | EM: 83.35 F1: 90.2 |
| question-answering-on-squad11-dev | DistilBERT-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | EM: 77.03 F1: 85.13 |
| question-answering-on-squad11-dev | DistilBERT-uncased-PruneOFA (90% unstruct sparse) | EM: 76.91 F1: 84.82 |
| question-answering-on-squad11-dev | BERT-Base-uncased-PruneOFA (90% unstruct sparse) | EM: 79.83 F1: 87.25 |