Finetuning Pretrained Transformers into RNNs

Jungo Kasai, Hao Peng, Yizhe Zhang, Dani Yogatama, Gabriel Ilharco, Nikolaos Pappas, Yi Mao, Weizhu Chen, Noah A. Smith

Abstract

Transformers have outperformed recurrent neural networks (RNNs) in natural language generation. But this comes with a significant computational cost, as the attention mechanism's complexity scales quadratically with sequence length. Efficient transformer variants have received increasing interest in recent works. Among them, a linear-complexity recurrent variant has proven well suited for autoregressive generation. It approximates the softmax attention with randomized or heuristic feature maps, but can be difficult to train and may yield suboptimal accuracy. This work aims to convert a pretrained transformer into its efficient recurrent counterpart, improving efficiency while maintaining accuracy. Specifically, we propose a swap-then-finetune procedure: in an off-the-shelf pretrained transformer, we replace the softmax attention with its linear-complexity recurrent alternative and then finetune. With a learned feature map, our approach provides an improved tradeoff between efficiency and accuracy over the standard transformer and other recurrent variants. We also show that the finetuning process has lower training cost relative to training these recurrent variants from scratch. As many models for natural language tasks are increasingly dependent on large-scale pretrained transformers, this work presents a viable approach to improving inference efficiency without repeating the expensive pretraining process.
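
The swap-then-finetune recipe can be pictured as replacing each attention head's softmax with a kernelized, linear-complexity attention whose feature map is learned during finetuning. The PyTorch sketch below is a minimal illustration under assumed details: the single-layer ReLU feature map, the feature dimension of 32, and the module names are illustrative assumptions, not the authors' released implementation.

```python
import torch
import torch.nn as nn


class LearnedFeatureMap(nn.Module):
    """Single-layer ReLU feature map applied to queries and keys (illustrative)."""

    def __init__(self, head_dim: int, feature_dim: int):
        super().__init__()
        self.proj = nn.Linear(head_dim, feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Nonnegative features keep the attention normalizer positive.
        return torch.relu(self.proj(x))


class CausalLinearAttention(nn.Module):
    """Hypothetical drop-in replacement for one softmax attention head."""

    def __init__(self, head_dim: int, feature_dim: int = 32):
        super().__init__()
        self.phi = LearnedFeatureMap(head_dim, feature_dim)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # q, k, v: (batch, seq_len, head_dim)
        q, k = self.phi(q), self.phi(k)                               # (B, T, F)
        # Prefix sums over keys and values: at generation time these running
        # sums are the constant-size RNN state carried across timesteps.
        kv_state = torch.einsum("btf,btd->btfd", k, v).cumsum(dim=1)  # (B, T, F, D)
        k_state = k.cumsum(dim=1)                                     # (B, T, F)
        num = torch.einsum("btf,btfd->btd", q, kv_state)              # (B, T, D)
        den = torch.einsum("btf,btf->bt", q, k_state).clamp(min=1e-6)
        return num / den.unsqueeze(-1)


if __name__ == "__main__":
    attn = CausalLinearAttention(head_dim=64)
    q = k = v = torch.randn(2, 10, 64)
    print(attn(q, k, v).shape)  # torch.Size([2, 10, 64])
```

At inference, only the running sums for the most recent timestep need to be carried forward, so each generation step costs constant time and memory instead of growing with the prefix length; the feature-map parameters are learned jointly with the pretrained weights during the finetuning step.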

Code Repositories

hazyresearch/lolcats (PyTorch), mentioned in GitHub
yashbonde/RNN-sim (PyTorch), mentioned in GitHub

Benchmarks

Benchmark | Methodology | Metrics
language-modelling-on-wikitext-103 | T2R + Pretrain | Test perplexity: 19.6; Validation perplexity: 19
machine-translation-on-wmt2014-english-french | T2R + Pretrain | BLEU score: 42.1
machine-translation-on-wmt2014-english-german | T2R + Pretrain | BLEU score: 28.7
machine-translation-on-wmt2017-chinese | T2R + Pretrain | BLEU: 23.8
