DeepMind, Microsoft, Allen AI & UW Researchers Convert Pretrained Transformers into RNNs, Lowering Memory Cost While Retaining High Accuracy


Powerful transformer models have been widely used in autoregressive generation, where they have advanced the state of the art beyond recurrent neural networks (RNNs). However, these models predict output words one at a time, each conditioned on the prefix generated so far, and every new word attends to all of the words before it. As a result, generation takes time quadratic in the sequence length, and the memory needed to hold past keys and values grows linearly with it.
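To make the quadratic cost concrete, here is a toy decoding loop in pure Python. It is an illustrative sketch, not code from the paper: the function names are hypothetical, and random vectors stand in for each token's projected query, key and value. Each new query is scored against every cached key, so step t costs t dot products, N steps cost N(N+1)/2 in total, and the key/value cache grows by one entry per token.

```python
import math
import random


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def softmax_attention_step(q, key_cache, value_cache):
    """Attend from one query to every cached key/value with standard softmax attention."""
    scores = [dot(q, k) / math.sqrt(len(q)) for k in key_cache]
    top = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(value_cache[0])
    return [sum(w * v[j] for w, v in zip(weights, value_cache)) / total for j in range(dim)]


def generate(num_steps, dim, seed=0):
    """Toy decoding loop; random vectors stand in for each new token's q, k and v."""
    rng = random.Random(seed)
    key_cache, value_cache, outputs = [], [], []
    dot_products = 0
    for _ in range(num_steps):
        q, k, v = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(3)]
        key_cache.append(k)    # the cache grows by one key ...
        value_cache.append(v)  # ... and one value per generated token
        outputs.append(softmax_attention_step(q, key_cache, value_cache))
        dot_products += len(key_cache)  # step t scores the query against t keys
    return outputs, dot_products, len(key_cache)


outputs, dot_products, cache_len = generate(num_steps=100, dim=8)
print(dot_products, cache_len)  # 5050 100
```

Doubling the number of steps roughly quadruples the dot-product count, which is the quadratic growth described above.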

As state-of-the-art performance increasingly relies on large-scale pretrained transformers, this cost of long-sequence generation has become more problematic. To address it, a research team from the University of Washington, Microsoft, DeepMind and the Allen Institute for AI has developed a method that converts a pretrained transformer into an efficient RNN. Their Transformer-to-RNN (T2R) approach speeds up generation and reduces memory cost.

A number of recent studies have attempted to reduce the overhead of autoregressive transformers. These efficient transformer variants approximate standard softmax attention with recurrent alternatives. Like RNNs, they represent the context with a small, fixed-size recurrent state, which gives them linear time and constant memory complexity during generation. These models have their limits, however: the small state size tends to degrade generation quality.
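The recurrent alternatives described above can be illustrated with causal linear attention. The sketch below is our own pure-Python illustration (the class name is hypothetical); it uses the feature map elu(x) + 1 from ELU-based linear attention and keeps two running sums in place of a key/value cache, so every step does the same amount of work and stores the same number of floats no matter how long the sequence gets.

```python
import math
import random


def elu_plus_one(x):
    """Feature map phi(x) = elu(x) + 1, which keeps every feature strictly positive."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


class LinearAttentionState:
    """Fixed-size recurrent state for causal linear attention.

    S accumulates phi(k_i) v_i^T (a feat_dim x value_dim matrix) and z accumulates
    phi(k_i), so the state never grows with the number of processed tokens.
    """

    def __init__(self, feat_dim, value_dim):
        self.S = [[0.0] * value_dim for _ in range(feat_dim)]
        self.z = [0.0] * feat_dim

    def step(self, q, k, v):
        """Fold in one (k, v) pair, then read the state out with query q."""
        fk = elu_plus_one(k)
        for a, fa in enumerate(fk):
            self.z[a] += fa
            for c, vc in enumerate(v):
                self.S[a][c] += fa * vc
        fq = elu_plus_one(q)
        denom = sum(fa * za for fa, za in zip(fq, self.z))
        return [sum(fq[a] * self.S[a][c] for a in range(len(fq))) / denom
                for c in range(len(v))]

    def num_floats(self):
        """Number of floats stored, independent of sequence length."""
        return sum(len(row) for row in self.S) + len(self.z)


rng = random.Random(0)
state = LinearAttentionState(feat_dim=4, value_dim=4)
for _ in range(1000):
    q, k, v = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(3)]
    out = state.step(q, k, v)
print(state.num_floats())  # 20, however many tokens were processed
```

The fixed feature map is also the weakness the article points to: everything the model remembers has to fit into S and z.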

In the new paper Finetuning Pretrained Transformers into RNNs, researchers propose a conversion approach that improves the balance between efficiency and accuracy. Instead of training a recurrent alternative from scratch, they convert a pretrained transformer into an efficient RNN of linear time and constant space complexity via a swap-then-finetune process.

The swap-then-finetune procedure modifies the attention computation of a pretrained transformer and then finetunes the model on the task objective. The researchers first replace the exponential similarity function in softmax attention with a dot product between features produced by a single-layer MLP feature map applied to the queries and keys, then finetune the MLP together with the other network parameters.
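A minimal pure-Python sketch of the swapped attention follows. It assumes a single feature map φ(x) = relu(Wx + b) shared by queries and keys, which is our reading of the single-layer MLP; the function names, the EPS guard and the random W and b (standing in for finetuned weights) are hypothetical. It computes causal attention twice, once in the quadratic, transformer-style form and once in the recurrent form, so the two can be compared.

```python
import random

EPS = 1e-6  # assumption: small constant that guards against an all-zero denominator


def mlp_feature_map(x, W, b):
    """phi(x) = relu(W x + b): a single-layer MLP mapping a d-vector to k features."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x)) + bi) for row, bi in zip(W, b)]


def causal_attention_parallel(Q, K, V, W, b):
    """Quadratic form: position i averages v_j over j <= i, weighted by phi(q_i) . phi(k_j)."""
    fQ = [mlp_feature_map(q, W, b) for q in Q]
    fK = [mlp_feature_map(k, W, b) for k in K]
    out = []
    for i, fq in enumerate(fQ):
        sims = [sum(x * y for x, y in zip(fq, fK[j])) for j in range(i + 1)]
        denom = sum(sims) + EPS
        out.append([sum(s * V[j][c] for j, s in enumerate(sims)) / denom
                    for c in range(len(V[0]))])
    return out


def causal_attention_recurrent(Q, K, V, W, b):
    """RNN form: a k x d_v matrix S and a k-vector z replace the growing key/value cache."""
    k_feat, d_v = len(W), len(V[0])
    S = [[0.0] * d_v for _ in range(k_feat)]
    z = [0.0] * k_feat
    out = []
    for q, key, v in zip(Q, K, V):
        fk = mlp_feature_map(key, W, b)
        for a in range(k_feat):
            z[a] += fk[a]
            for c in range(d_v):
                S[a][c] += fk[a] * v[c]
        fq = mlp_feature_map(q, W, b)
        denom = sum(fq[a] * z[a] for a in range(k_feat)) + EPS
        out.append([sum(fq[a] * S[a][c] for a in range(k_feat)) / denom for c in range(d_v)])
    return out


# Hypothetical sizes; random W and b stand in for the parameters finetuning would learn.
rng = random.Random(0)
d, k, n = 8, 4, 6
W = [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(k)]
b = [rng.gauss(0.0, 0.1) for _ in range(k)]
Q, K, V = [[[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)] for _ in range(3)]

parallel = causal_attention_parallel(Q, K, V, W, b)
recurrent = causal_attention_recurrent(Q, K, V, W, b)
max_gap = max(abs(p - r) for row_p, row_r in zip(parallel, recurrent)
              for p, r in zip(row_p, row_r))
print(f"largest difference between the two forms: {max_gap:.2e}")
```

The two forms agree up to floating-point round-off because φ(q)·φ(k) factorizes: the sums over past keys can be accumulated once into S and z. The exponential similarity exp(q·k) does not factorize this way, which is why the swap has to happen before the model can run as an RNN.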


The team also analyses the computation steps and time complexity of pretrained transformers and T2R models, where M and N are the lengths of the key/value and query sequences, h is the model dimension, and k is the feature size. For pretrained transformers, the query, key, and value vectors take O(h), O(Mh), and O(Mh) space, respectively, at every generation step. For T2R, the feature size k is much smaller than the input sequence lengths, so the cost of the attention stage drops from O(MNh) to O(Mhk) + O(Nhk), which gives T2R a substantial speedup.
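The complexity figures above can be checked with a little arithmetic. The helpers below are a hypothetical illustration: they keep only the leading terms of the big-O expressions, ignore constant factors and multiple heads, and use example sizes (M = N = 2048, h = 512, k = 32) that are our own choices rather than the paper's settings.

```python
def attention_cost(M, N, h, k):
    """Leading-order multiply counts for one attention layer (constant factors dropped).

    M: key/value length, N: query length, h: model dimension, k: feature size.
    """
    softmax = M * N * h          # every query scored against every key: O(MNh)
    t2r = M * h * k + N * h * k  # fold M keys/values into the state, read it with N queries
    return softmax, t2r


def generation_memory(M, h, k):
    """Floats kept between generation steps after M tokens (single head, illustrative)."""
    kv_cache = 2 * M * h         # cached keys and values: O(Mh) each
    rnn_state = h * k + k        # k x h matrix S plus k-vector z: O(hk), independent of M
    return kv_cache, rnn_state


softmax, t2r = attention_cost(M=2048, N=2048, h=512, k=32)
print(softmax, t2r, softmax // t2r)            # 2147483648 67108864 32
print(generation_memory(M=2048, h=512, k=32))  # (2097152, 16416)
```

With equal lengths M = N, the linearized form is cheaper exactly when N > 2k (64 tokens for k = 32); at N = 2k the two counts tie, and the gap widens linearly as sequences grow.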

The team conducted extensive experiments on standard benchmarks for language modelling (WikiText-103) and machine translation (WMT14 EN-DE, WMT14 EN-FR and WMT17 ZH-EN).


In language modelling, the T2R 75% model outperformed other transformer models by more than two perplexity points in the setting where models are initialized from a pretrained transformer, and it had the lowest training time. In machine translation, T2R achieved a speedup of more than 15 percent over the ELU (linear attention with an elu-based feature map) and RFA (random feature attention) baselines.

Overall, the results show that T2R achieves efficient autoregressive generation while retaining high accuracy, suggesting that large-scale pretrained models can be converted into efficient inference models that facilitate downstream applications.

The paper Finetuning Pretrained Transformers into RNNs is on arXiv.

Author: Hecate He | Editor: Michael Sarazen
