How RWKV creates more efficient LLMs

How to improve RNNs to match Transformers in NLP

Devansh
Oct 3, 2024

Non-Transformer LLMs hold a lot of promise, and they might end up addressing one of the biggest weaknesses of LLMs: their high energy costs.

RNN-based networks offer more efficient inference, mostly because they replace the costly attention mechanism with a cheaper variant. One such approach is RWKV, the open-source “Transformer Killer”. Through some very clever modifications, it gets very close to the “best of both worlds”: the efficient inference of an RNN combined with the scale of a Transformer. Let’s talk about how that becomes possible.

Table 1: Inference complexity comparison with different Transformers. Here T denotes the sequence length, d the feature dimension, c MEGA’s chunk size for quadratic attention, and s the size of the local window for AFT.

How RWKV Modernizes RNNs

RWKV splits the full RNN network into multiple smaller layers, “where each layer’s hidden state can be used independently to compute the next token hidden state for the same layer. This allows for the next token states to be computed partially in parallel, while awaiting the complete calculation of the first hidden state, in a cascading-like pattern… Effectively, this allows the RNN network to operate like a transformer network when rolled out side by side, where it can be trained ‘like a transformer’ and ‘executed like an RNN’ (the best of both worlds).”
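To make the “cascading-like pattern” concrete, here is a toy scheduling sketch (my own illustration, not code from the RWKV project). It assumes that the hidden state for layer l at token t needs only two things: the same layer’s state at token t-1 and the layer below’s output at token t. Under that assumption, every (layer, token) cell on the same diagonal l + t is independent of the others, so L layers over T tokens can finish in L + T - 1 waves instead of L × T sequential steps. The function name `cascade_schedule` is hypothetical.

```python
def cascade_schedule(num_layers: int, num_tokens: int) -> list[list[tuple[int, int]]]:
    """Group (layer, token) cells into waves that could run in parallel.

    Toy model of the cascading pattern: cell (l, t) is assumed to depend
    only on (l, t - 1) in the same layer and (l - 1, t) in the layer below,
    so all cells with the same l + t are independent of each other.
    """
    waves: list[list[tuple[int, int]]] = []
    for step in range(num_layers + num_tokens - 1):
        wave = [
            (layer, step - layer)
            for layer in range(num_layers)
            if 0 <= step - layer < num_tokens
        ]
        waves.append(wave)
    return waves


if __name__ == "__main__":
    for i, wave in enumerate(cascade_schedule(num_layers=3, num_tokens=4)):
        print(f"wave {i}: {wave}")
```

For 3 layers and 4 tokens this yields 6 waves covering all 12 cells, with up to 3 cells computed at once.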

Figure 7: Cumulative time on text generation for LLMs. Unlike transformers, RWKV exhibits linear scaling.
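A back-of-the-envelope sketch of why the curves in Figure 7 diverge. It assumes (a deliberate simplification) that generating token t with full attention touches all t tokens in the context, while an RNN-style model such as RWKV touches only a fixed-size state. Summing over the generated sequence gives quadratic versus linear cumulative cost. The “cost” unit is a hypothetical count of context reads, not a measured time, and the function names are illustrative.

```python
def cumulative_attention_cost(num_tokens: int) -> int:
    """Token t (1-indexed) attends over all t tokens so far: 1 + 2 + ... + T."""
    return sum(range(1, num_tokens + 1))


def cumulative_recurrent_cost(num_tokens: int, state_cost: int = 1) -> int:
    """Each step reads a fixed-size state, independent of position."""
    return num_tokens * state_cost


if __name__ == "__main__":
    for n in (1_000, 2_000, 4_000):
        print(n, cumulative_attention_cost(n), cumulative_recurrent_cost(n))
```

Doubling the sequence length roughly quadruples the attention total but only doubles the recurrent one.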

Before we look into the architectural innovations, it would be good to understand the four letters that make up the RWKV name:

  • “R: The Receptance vector acts as the receiver of past information.
  • W: The Weight signifies the positional weight decay vector, a trainable parameter within the model.
  • K: The Key vector performs a role analogous to K in traditional attention mechanisms.
  • V: The Value vector functions similarly to V in conventional attention processes.”
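To see what “positional weight decay” means in practice, here is a minimal sketch. It assumes the RWKV-4 paper’s formulation where, ignoring the key term, a past token d steps before the most recent one (d = 0 being the most recent past token) is scaled by exp(-d · w) on each channel, with w a trainable, non-negative per-channel vector. The function name `decay_weights` and the example decay rates are hypothetical.

```python
import math


def decay_weights(w: list[float], distance: int) -> list[float]:
    """Per-channel positional decay for a past token `distance` steps back.

    Assumes RWKV-4-style decay: channel c is scaled by exp(-distance * w[c]),
    where w[c] >= 0 is a learned rate. Larger w[c] means faster forgetting.
    """
    if any(wc < 0 for wc in w):
        raise ValueError("decay rates are assumed to be non-negative")
    return [math.exp(-distance * wc) for wc in w]


if __name__ == "__main__":
    w = [0.0, 0.1, 1.0]  # hypothetical decay rates for three channels
    for d in (0, 1, 5, 20):
        print(d, [round(x, 6) for x in decay_weights(w, d)])
```

A channel with w = 0 never forgets, while a channel with w = 1 has shrunk its weight on a token 20 steps back to about 2e-9, so different channels can specialize in short- or long-range memory.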

The architecture looks like this:

Figure 2: Elements within an RWKV block (left) and the complete RWKV residual block, equipped with a final head for language modeling (right).

Here is what I think is important:

  • Token Shifting: Instead of considering only the current input and hidden state, we also factor in the previous input. This retains more information and helps mitigate the loss of context we experience when traditional RNNs compress long sentences into a single hidden state.
The token shift is applied in time mixing, and the same applies to channel mixing.
  • Channel mixing: Acts kind of like a feed-forward layer in Transformers. It takes a weighted sum of the previous and current inputs (the token shift) and applies a non-linearity to it: a squared ReLU on the key path, gated by a sigmoid of the receptance.

This creates a short-term memory block.

  • Time mixing is a similar (but more complicated) process. It enables longer-term memory by accounting for both the previous state and learned weights to determine how to combine previous computations with new ones. Its core has two parts: a decay-weighted sum of all the previous values, and a separate term that tells you how much to consider the current value.
“To circumvent any potential degradation of W, we introduce a vector U that separately attends to the current token.”
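Token shifting and channel mixing can be sketched in a few lines of plain Python. This is a minimal sketch assuming the RWKV-4 channel-mixing equations: r′ = W′_r · (μ′_r ⊙ x_t + (1 − μ′_r) ⊙ x_{t−1}), k′ = W′_k · (μ′_k ⊙ x_t + (1 − μ′_k) ⊙ x_{t−1}), and output σ(r′) ⊙ (W′_v · max(k′, 0)²). Lists stand in for tensors, the helper names (`token_shift`, `matvec`, `channel_mix`) are hypothetical, and the tiny identity matrices in the example replace real learned weights.

```python
import math

Vector = list[float]
Matrix = list[list[float]]


def token_shift(x_t: Vector, x_prev: Vector, mu: Vector) -> Vector:
    """Per-channel blend of current and previous token: mu*x_t + (1-mu)*x_prev."""
    return [m * a + (1.0 - m) * b for a, b, m in zip(x_t, x_prev, mu)]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product (stand-in for a bias-free linear layer)."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def channel_mix(
    x_t: Vector,
    x_prev: Vector,
    mu_r: Vector,
    mu_k: Vector,
    w_r: Matrix,
    w_k: Matrix,
    w_v: Matrix,
) -> Vector:
    """Sketch of an RWKV-4-style channel-mixing block for a single token."""
    r = matvec(w_r, token_shift(x_t, x_prev, mu_r))  # receptance
    k = matvec(w_k, token_shift(x_t, x_prev, mu_k))  # key
    hidden = [max(k_i, 0.0) ** 2 for k_i in k]  # squared ReLU
    v = matvec(w_v, hidden)
    # sigmoid(r) gates how much of each channel is let through.
    return [sigmoid(r_i) * v_i for r_i, v_i in zip(r, v)]


if __name__ == "__main__":
    eye = [[1.0, 0.0], [0.0, 1.0]]
    print(channel_mix([1.0, -2.0], [3.0, 0.0], [0.5, 0.5], [0.5, 0.5], eye, eye, eye))
```

Setting a mixing coefficient μ to 1 switches the token shift off for that channel, so the model can learn, channel by channel, how much of the previous token to blend in.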

Time mixing is a very powerful idea because it has an interesting advantage over Transformers: unlike Transformers, which have fixed context windows, it can in theory be extended indefinitely. Also, notice that the time-mixing recurrence has no non-linear dependency on previous outputs: the running sums are updated linearly, and the non-linearity is applied after the block. This means the computation can be parallelized across the sequence during training, enabling a much larger scale.
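The time-mixing core (often called the WKV operator) can be written directly as a weighted average. This sketch assumes the RWKV-4 form for a single channel: past token i at position t gets weight exp(-(t − 1 − i)·w + k_i), and the current token gets the separate bonus weight exp(u + k_t), which is the vector U from the quote above acting on one channel. The name `wkv_parallel` is hypothetical, and the code omits the numerical-stability tricks a real implementation needs, since exp can overflow for large keys.

```python
import math


def wkv_parallel(keys: list[float], values: list[float], w: float, u: float) -> list[float]:
    """Single-channel WKV computed directly from its weighted-sum definition.

    For position t (0-indexed), each past token i < t gets weight
    exp(-(t - 1 - i) * w + keys[i]); the current token gets the bonus
    weight exp(u + keys[t]). The output is the weighted average of values.
    """
    outputs = []
    for t in range(len(keys)):
        # Current-token term, controlled by the bonus u.
        num = math.exp(u + keys[t]) * values[t]
        den = math.exp(u + keys[t])
        # Decay-weighted contributions from every earlier token.
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + keys[i])
            num += weight * values[i]
            den += weight
        outputs.append(num / den)
    return outputs


if __name__ == "__main__":
    # With no decay and zero keys, this reduces to a running mean.
    print(wkv_parallel([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], w=0.0, u=0.0))  # [1.0, 1.5, 2.0]
```

The loop over t is sequential for readability, but no iteration reads another iteration’s output: every position depends only on the keys and values, which is what allows all positions to be processed in parallel during training. This naive form costs O(T²) per channel.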

Recurrent networks commonly utilize the output at state t as input at state t+1. This usage is also observed in the autoregressive decoding inference of language models, where each token must be computed before being passed to the next step. RWKV takes advantage of this RNN-like structure, known as time-sequential mode. In this context, RWKV can be conveniently formulated recursively for decoding during inference…

This is a good observation, and it is what lets RWKV act as a bridge of sorts between Transformers and RNNs.
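The same WKV quantity can be computed recursively, which is the time-sequential mode used for decoding. This is a minimal sketch assuming the RWKV-4 recurrence for one channel: keep a decayed numerator a and denominator b, compute the output from them plus the current-token bonus exp(u + k_t), then fold the current token into the state. Function names are hypothetical and numerical stabilization is omitted.

```python
import math


def wkv_step(a: float, b: float, k_t: float, v_t: float, w: float, u: float) -> tuple[float, float, float]:
    """One decoding step for a single channel.

    a and b are the running (decayed) numerator and denominator carried
    over from earlier tokens. Returns (output, new_a, new_b).
    """
    bonus = math.exp(u + k_t)  # extra weight for the current token only
    out = (a + bonus * v_t) / (b + bonus)
    decay = math.exp(-w)
    # Fold the current token into the state (without the bonus) for the next step.
    new_a = decay * a + math.exp(k_t) * v_t
    new_b = decay * b + math.exp(k_t)
    return out, new_a, new_b


def wkv_recurrent(keys: list[float], values: list[float], w: float, u: float) -> list[float]:
    """Run wkv_step over a sequence while keeping only two floats of state."""
    a, b = 0.0, 0.0
    outputs = []
    for k_t, v_t in zip(keys, values):
        out, a, b = wkv_step(a, b, k_t, v_t, w, u)
        outputs.append(out)
    return outputs


if __name__ == "__main__":
    print(wkv_recurrent([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], w=0.0, u=0.0))  # [1.0, 1.5, 2.0]
```

Unrolling the recurrence gives back the weighted-sum definition: before step t, a holds the sum of exp(-(t − 1 − i)·w + k_i)·v_i over all past tokens i. However long the sequence grows, the state per channel stays at two floats, which is where constant-memory, linear-time generation comes from.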

One thing that stood out to me was the following quote: “These design elements not only enhance the training dynamics of deep neural networks but also facilitate the stacking of multiple layers, leading to superior performance over conventional RNN models by capturing complex patterns across different levels of abstraction”. This sounds very convolutional. I don’t have anything profound to add here, but I found it worth noting.

All of this results in an RNN that can hold its own against Transformers on various tasks.

Figure 5: Zero-Shot Performance of RWKV on common language modeling evaluation benchmarks. Additional plots can be found in Appendix J.

This is a snippet from my deep-dive into RWKV and how it can completely shake up the AI market. Read more about it here.

I put a lot of work into writing this newsletter. To do so, I rely on you for support. If a few more people choose to become paid subscribers, the Chocolate Milk Cult can continue to provide high-quality and accessible education and opportunities to anyone who needs it. If you think this mission is worth contributing to, please consider a premium subscription. You can do so for less than the cost of a Netflix subscription (pay what you want here).

I provide various consulting and advisory services. If you’d like to explore how we can work together, reach out to me through any of my socials over here or reply to this email.

I regularly share mini-updates on what I read on the microblogging sites X (https://twitter.com/Machine01776819), Threads (https://www.threads.net/@iseethings404), and TikTok (https://www.tiktok.com/@devansh_ai_made_simple), so follow me there if you’re interested in keeping up with my learnings.

Reach out to me

Use the links below to check out my other content, learn more about tutoring, reach out to me about projects, or just to say hi.

Small Snippets about Tech, AI and Machine Learning over here

AI Newsletter- https://artificialintelligencemadesimple.substack.com/

My grandma’s favorite Tech Newsletter- https://codinginterviewsmadesimple.substack.com/

Check out my other articles on Medium: https://rb.gy/zn1aiu

My YouTube: https://rb.gy/88iwdd

Reach out to me on LinkedIn. Let’s connect: https://rb.gy/m5ok2y

My Instagram: https://rb.gy/gmvuy9

My Twitter: https://twitter.com/Machine01776819


Written by Devansh

Writing about AI, Math, the Tech Industry and whatever else interests me. Join my cult to gain inner peace and to support my crippling chocolate milk addiction
