Retentive Network: A Successor to Transformer for Large Language Models
Abstract
In this work, we propose Retentive Network (RetNet) as a foundation architecture for large language models, simultaneously achieving training parallelism, low-cost inference, and good performance. We theoretically derive the connection between recurrence and attention. Then we propose the retention mechanism for sequence modeling, which supports three computation paradigms, i.e., parallel, recurrent, and chunkwise recurrent. Specifically, the parallel representation allows for training parallelism. The recurrent representation enables low-cost O(1) inference, which improves decoding throughput, latency, and GPU memory without sacrificing performance. The chunkwise recurrent representation facilitates efficient long-sequence modeling with linear complexity, where each chunk is encoded in parallel while recurrently summarizing the chunks. Experimental results on language modeling show that RetNet achieves favorable scaling results, parallel training, low-cost deployment, and efficient inference. These intriguing properties make RetNet a strong successor to Transformer for large language models. Code will be available at https://aka.ms/retnet.
Community
I don't see why there are three representations of retention or how they fit together. And as for whether retention is equivalent in expressive power to attention, I didn't see a proof or explanation in the paper. Is there any kind person who can teach me?
Table 1 - I don't quite understand the O(N) "inference cost" of the Transformer. I thought the Transformer's time complexity is simply O(N^2). Why is the "inference cost" O(N)?
Is it "inference cost per token"?
If inference is O(1) as suggested, does that mean we can remove tokenizers altogether and run byte-level models directly? That would be like dropping SIFT and SURF in favor of convolutions in ConvNets...
I loved seeing that they trained it on AMD GPUs!
Paper summary in video format: https://www.youtube.com/watch?v=JaIL1VAEwZ8
> Table 1 - I don't quite understand the O(N) "inference cost" of the Transformer. I thought the Transformer's time complexity is simply O(N^2). Why is the "inference cost" O(N)? Is it "inference cost per token"?
Yes, it should be per token.
Did I miss something about the GroupNorm? It seems to make the model no longer causal. For example, in the parallel mode, GroupNorm is applied to the output, which has shape `(batch_size, num_head, seq_len, v_dim)`. GroupNorm normalizes each batch sample and each head separately by computing the mean/std over the `seq_len` and `v_dim` dimensions. Wouldn't that leak statistics from later time steps to earlier time steps, making the model no longer causal?
Unofficial implementation (repo): https://github.com/Jamie-Stirling/RetNet
> Did I miss something about the GroupNorm? It seems to make the model no longer causal. For example, in the parallel mode, GroupNorm is applied to the output, which has shape `(batch_size, num_head, seq_len, v_dim)`. GroupNorm normalizes each batch sample and each head separately by computing the mean/std over the `seq_len` and `v_dim` dimensions. Wouldn't that leak statistics from later time steps to earlier time steps, making the model no longer causal?
I don't think the GroupNorm leaks information in this way. The output isn't of the shape you specified; it's just `(batch, seq_len, model_dim)`, since the head outputs have been concatenated at this point.
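For what it's worth, here is a minimal PyTorch sketch (not the official code; the shapes and group count are made up) showing that when the normalization statistics are computed per position, over the channel/head dimensions only, nothing leaks across time:

```python
import torch
import torch.nn as nn

# Hypothetical shapes and group count, purely for illustration.
batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
x = torch.randn(batch, seq_len, num_heads * head_dim)

# One group per head: statistics are computed per (sample, position, head),
# i.e. over channel dimensions only, never over seq_len.
gn = nn.GroupNorm(num_groups=num_heads, num_channels=num_heads * head_dim)
y = gn(x.reshape(batch * seq_len, -1)).reshape(batch, seq_len, -1)

# Perturb only the last position: earlier outputs are unchanged, so causality holds.
x2 = x.clone()
x2[:, -1] += 1.0
y2 = gn(x2.reshape(batch * seq_len, -1)).reshape(batch, seq_len, -1)
assert torch.allclose(y[:, :-1], y2[:, :-1])
```

If the statistics were instead taken over the `seq_len` dimension (as in the shape described in the question), the assert would fail, so the concern really comes down to which dimensions the normalization runs over in the actual implementation.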
A complete (unofficial) implementation of RetNet with all forward types (and without complex numbers) https://github.com/syncdoth/RetNet
If I understood things correctly, RetNet is still just a recurrent network, passing a hidden state between time steps, and thus it still has the size of the hidden state as a bottleneck, right? Unlike Transformer attention, where every token attends to all others, which is in principle O(n^2) but can probably be reduced to O(n log n) via convolution computed in frequency space with the FFT, similar to what people do for state space models (Hyena, HiPPO, etc.). I assume there cannot be a faster way of doing n-to-n attention than that.
But in theory RetNet shouldn't be good at long-range dependency tasks, like the associative recall task used for testing Hyena (https://arxiv.org/abs/2302.10866).
It is interesting to ask just how important those long dependencies actually are, given the success of LLMs with shorter contexts.
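As a side note, here is a tiny sketch of that O(n log n) route (my own toy example, with a made-up decaying kernel): a causal, Toeplitz-style mixing of the sequence applied via FFT convolution instead of an explicit n x n matrix:

```python
import torch

# Causal convolution two ways: direct O(n^2) sum vs. FFT-based O(n log n).
# The decaying kernel here is made up purely for illustration.
n = 8
x = torch.randn(n)
filt = 0.9 ** torch.arange(n, dtype=torch.float32)

# Direct: y[t] = sum_{m <= t} filt[t - m] * x[m]
y_direct = torch.stack([sum(filt[t - m] * x[m] for m in range(t + 1)) for t in range(n)])

# FFT: zero-pad to 2n so the circular convolution matches the linear (causal) one.
y_fft = torch.fft.irfft(torch.fft.rfft(x, 2 * n) * torch.fft.rfft(filt, 2 * n), 2 * n)[:n]

assert torch.allclose(y_direct, y_fft, atol=1e-5)
```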
Proposes Retentive Network (RetNet) with a retention mechanism (recurrence combined with attention) that supports parallel, recurrent (constant-time inference), and chunkwise recurrent (time linear in input length) paradigms; a scalable and efficient alternative to Transformers for LLMs.
- Architecture: a stack of identical blocks, each containing multi-scale retention (MSR, in place of MHA) and a feed-forward network (FFN).
- Retention mechanism: the sequence is projected to a 1D sequence and processed through each layer using d-dimensional intermediate states; attention over the previous state, a key projection over the (1D) input, and a query projection to the output state (which can be simplified); queries and keys are content-aware (derived from the unprojected input). Representing the state matrix in diagonal form (a decomposition) and simplifying the product for the output yields xPos (relative position embeddings for Transformers) and a parallelizable formulation (see the toy sketch below).
- Chunkwise recurrent representation (a hybrid of parallel and recurrent): the input is broken into fixed-length chunks, with the parallel representation within each chunk and the recurrent representation across chunks.
- Gated multi-scale retention (MSR): concatenate the outputs of multiple heads, apply GroupNorm, apply a Swish gate to the result, then mix the output through learnable weights; head-wise normalization (different variance statistics per head). Normalization of retention scores and causal masking are added in the parallel retention (the final combination).
- Single RetNet block: input passed through LayerNorm, then MSR, then a residual with the input; that goes to LN and FFN with a residual (the FFN is two linear layers with a GeLU activation in the middle).
- Results: RetNet has lower validation perplexity (PPL) with scale compared to Transformers. Language modeling data comes from The Pile, C4, and The Stack; trained with TorchScale on 512 AMD MI200 GPUs. Lower memory use and higher throughput than Transformers (even with FlashAttention). Better in-domain and out-of-domain performance than Transformer variants (RWKV, H3, Hyena, and Linear Transformer). Hyperparameters of the RetNet models are in the appendix. From Tsinghua and Microsoft.
Links: GitHub (part of microsoft/unilm, trained using microsoft/torchscale), Website
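To make the parallel/recurrent duality concrete, here is a toy single-head sketch (my own simplification, in the spirit of the paper's retention equations, ignoring the xPos rotation and all normalization) showing both forms producing the same outputs from the same Q, K, V:

```python
import torch

# Toy single-head retention: parallel vs. recurrent form.
torch.manual_seed(0)
T, d = 6, 4
gamma = 0.9
Q, K, V = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)

# Parallel form: D[n, m] = gamma^(n - m) for n >= m, else 0.
n = torch.arange(T, dtype=torch.float32)
D = (gamma ** (n[:, None] - n[None, :])) * (n[:, None] >= n[None, :])
out_parallel = (Q @ K.T * D) @ V

# Recurrent form: S_n = gamma * S_{n-1} + K_n^T V_n, and o_n = Q_n S_n.
S = torch.zeros(d, d)
outs = []
for t in range(T):
    S = gamma * S + K[t][:, None] @ V[t][None, :]
    outs.append(Q[t] @ S)
out_recurrent = torch.stack(outs)

assert torch.allclose(out_parallel, out_recurrent, atol=1e-5)
```

The recurrent form only ever keeps the d x d state S, which is where the O(1) per-token inference cost comes from.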
The authors report memory consumption and throughput for models from 1.3B to 13B parameters but only 1.3B to 6.7B for perplexity. I am curious why they didn't report perplexity for 13B?
> I don't see why there are three representations of retention or how they fit together. And as for whether retention is equivalent in expressive power to attention, I didn't see a proof or explanation in the paper. Is there any kind person who can teach me?
If I understood correctly, it's something like this:
- Parallel form: for fast training
- Recurrent form: for fast inference
- Chunkwise recurrent form: for fast training on very long sequences
The three forms are factorized such that you can switch between them with the same weights. I don't recall seeing any mathematical proof that retention is equivalent to attention, just empirical performance comparisons. Just finished reading the paper myself, so take all of this with a grain of salt! (A rough sketch of the chunkwise form is below.)
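Here is that rough sketch of the chunkwise recurrent form (my own toy version in the spirit of the paper's chunkwise equation, single head, no xPos rotation or normalization): the computation is parallel inside each chunk, a recurrent state is carried across chunks, and the result matches the full parallel form:

```python
import torch

# Toy single-head chunkwise recurrent retention.
torch.manual_seed(0)
T, d, B = 8, 4, 4                      # sequence length, head dim, chunk size (T % B == 0)
gamma = 0.9
Q, K, V = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)

i = torch.arange(B, dtype=torch.float32)
D = (gamma ** (i[:, None] - i[None, :])) * (i[:, None] >= i[None, :])  # intra-chunk decay mask

S = torch.zeros(d, d)                  # cross-chunk recurrent state
outs = []
for start in range(0, T, B):
    q, k, v = Q[start:start+B], K[start:start+B], V[start:start+B]
    inner = (q @ k.T * D) @ v                           # intra-chunk, parallel
    cross = (gamma ** (i + 1))[:, None] * (q @ S)       # contribution of earlier chunks
    outs.append(inner + cross)
    # Update the state: decay the old state by gamma^B, add this chunk's keys/values.
    S = (gamma ** B) * S + (k * (gamma ** (B - 1 - i))[:, None]).T @ v
out_chunkwise = torch.cat(outs)

# Sanity check against the full parallel form over the whole sequence.
m = torch.arange(T, dtype=torch.float32)
D_full = (gamma ** (m[:, None] - m[None, :])) * (m[:, None] >= m[None, :])
assert torch.allclose(out_chunkwise, (Q @ K.T * D_full) @ V, atol=1e-5)
```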
> The authors report memory consumption and throughput for models from 1.3B to 13B parameters but only 1.3B to 6.7B for perplexity. I am curious why they didn't report perplexity for 13B?
The reason is quite simple: training a 13B model requires more computing resources, while measuring memory and throughput does not.
> I don't see why there are three representations of retention or how they fit together. And as for whether retention is equivalent in expressive power to attention, I didn't see a proof or explanation in the paper. Is there any kind person who can teach me?
>
> If I understood correctly, it's something like this:
> - Parallel form: for fast training
> - Recurrent form: for fast inference
> - Chunkwise recurrent form: for fast training on very long sequences
> The three forms are factorized such that you can switch between them with the same weights. I don't recall seeing any mathematical proof that retention is equivalent to attention, just empirical performance comparisons. Just finished reading the paper myself, so take all of this with a grain of salt!
The math here is not complicated. If you spend some time, you can follow our Equations (1) through (7). As for expressive power, I'd like to see your theoretical proof that attention is the most expressive and other mechanisms are not.
@sunyt32 I've found that the torchscale repo now has the official implementation of RetNet; is there any plan for huggingface integration in the future? Also, will the pretrained weights be made public? Thanks!
The checkpoint in the paper is not suitable as a released version. A high-quality pretrained checkpoint is in our plans; at that time, the Hugging Face integration will be ready.
Hi, I tried to implement the basic model here: https://github.com/prateekstark/retnet for learning, since I was a bit skeptical. I will update the repository more frequently now. I would like to invite other collaborators who would like to improve the code for learning purposes.
> As for expressive power, I'd like to see your theoretical proof that attention is the most expressive and other mechanisms are not.
Obviously, an RNN can only pass long-range dependencies through its hidden state, so a small recurrent hidden size should greatly limit the number of locations one can attend to. Therefore a task like associative recall, or a slightly more complex one like "give me all values for a given key", shouldn't be solvable (by any RNN, not only RetNet) once a certain sequence length is reached for a given hidden size, while the same task should in principle be solvable by a Transformer with O(n^2) attention independently of the length (as long as the O(n^2) stays computable).
The real question is whether LLMs really need long-range dependencies or not. Until now, they have been quite successful even with short context lengths.
By all means, RetNet is a very cool paper (as are S4, HiPPO, Hyena, RWKV, etc.).
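For readers who haven't seen it, here is a minimal sketch of what an associative-recall example looks like (my own toy generator, not the Hyena benchmark code): a list of key-value pairs followed by a query key whose value must be recalled. The point above is that a fixed-size recurrent state has to compress all the pairs, while attention can look the key up directly.

```python
import random

def make_example(num_pairs: int, keys: str = "abcdefghij", values: str = "0123456789"):
    """Build one toy associative-recall prompt and its expected answer."""
    ks = random.sample(keys, num_pairs)            # distinct keys
    vs = [random.choice(values) for _ in ks]       # one value per key
    query = random.choice(ks)                      # key to recall at the end
    prompt = " ".join(f"{k} {v}" for k, v in zip(ks, vs)) + f" {query}"
    answer = vs[ks.index(query)]
    return prompt, answer

print(make_example(4))   # e.g. ('c 7 a 2 f 5 h 1 f', '5')
```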
I don't get why the Training Parallelization column in Table 1 says ✗ for RWKV. RWKV has better complexity, O(dn), compared to RetNet's O(dn(b + h)).
> As for expressive power, I'd like to see your theoretical proof that attention is the most expressive and other mechanisms are not.
>
> Obviously, an RNN can only pass long-range dependencies through its hidden state, so a small recurrent hidden size should greatly limit the number of locations one can attend to. Therefore a task like associative recall, or a slightly more complex one like "give me all values for a given key", shouldn't be solvable (by any RNN, not only RetNet) once a certain sequence length is reached for a given hidden size, while the same task should in principle be solvable by a Transformer with O(n^2) attention independently of the length (as long as the O(n^2) stays computable).
> The real question is whether LLMs really need long-range dependencies or not. Until now, they have been quite successful even with short context lengths.
> By all means, RetNet is a very cool paper (as are S4, HiPPO, Hyena, RWKV, etc.).
Long-range dependency is not just a question of recurrent hidden size; it is also bounded by the embedding size. You can see that the earlier Linear Transformer works well on some long-range datasets. My answer is: the upper bound always exists, but current applications don't reach it for now.
> I don't get why the Training Parallelization column in Table 1 says ✗ for RWKV. RWKV has better complexity, O(dn), compared to RetNet's O(dn(b + h)).
The previous RWKV implementation uses a recurrent fashion, so when the training length is very long it can't fully utilize the GPU. Training parallelization has nothing to do with complexity. I personally believe RetNet has the best tradeoff between complexity and performance; a smaller context size would seriously hurt performance.
I think there is a mistake in the pseudocode for RecurrentRetention: the shape of `q, k, v` should be `bsz × num_head × qkv_dim`, not `bsz × num_head × len × qkv_dim`. Also, I'm not sure why the authors adopt a different form of multiplication in the pseudocode vs. in Equation (6); they're essentially identical. (A small sketch of the per-step shapes is below.)
Also, the intra-chunk row subscript i of ξ_{ij} can be confused with the chunk index i, so I would recommend using k instead. The last sentence (the [i] notation) is also problematic, since i starts at 0; it can be removed, as Q[i], K[i], V[i] are already defined.
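For illustration, here is that minimal per-step sketch with the shapes described above (my own simplification, not the paper's pseudocode: key and value dims are kept equal, and the xPos rotation and group normalization are left out):

```python
import torch

def recurrent_retention_step(q, k, v, prev_S, gamma):
    """One decoding step of (simplified) recurrent retention.

    q, k, v:  (bsz, num_head, qkv_dim) for the current token only
    prev_S:   (bsz, num_head, qkv_dim, qkv_dim) recurrent state
    gamma:    scalar or (num_head, 1, 1) per-head decay
    """
    S = gamma * prev_S + k.unsqueeze(-1) @ v.unsqueeze(-2)   # decay state, add outer product k^T v
    out = (q.unsqueeze(-2) @ S).squeeze(-2)                  # (bsz, num_head, qkv_dim)
    return out, S
```

The `len` dimension only appears in the parallel/chunkwise forms; in the recurrent form each call consumes a single token.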
Are there any plans to train RetNet, or do I have to do it myself?
> @sunyt32 I've found that the torchscale repo now has the official implementation of RetNet; is there any plan for huggingface integration in the future? Also, will the pretrained weights be made public? Thanks!
>
> The checkpoint in the paper is not suitable as a released version. A high-quality pretrained checkpoint is in our plans; at that time, the Hugging Face integration will be ready.
I am already experimenting with RetNet; the efficiency is pretty cool.
Looking forward to the release of the weights.
This paper is really exciting. I hope to see some RetNet checkpoints that I can use for my tasks.
> @sunyt32 I've found that the torchscale repo now has the official implementation of RetNet; is there any plan for huggingface integration in the future? Also, will the pretrained weights be made public? Thanks!
>
> The checkpoint in the paper is not suitable as a released version. A high-quality pretrained checkpoint is in our plans; at that time, the Hugging Face integration will be ready.
May I ask whether there is an expected release date? I'm very interested and was planning to train a RetNet on my own; it would be much more convenient if a public checkpoint were available.
Is the Hugging Face integration still a work in progress? No update has been mentioned here. @sunyt32
Hi, how is everything going? Are any pretrained models available?
In Eq. (1), what happened to s_0? The expansion should have another term Q_n A^n s_0, right?