From Scaled Dot Product to RWKV v7 Attention: Understanding the Math Behind Attention Mechanisms

Attention is a key component of modern LLMs because it enables the model to weigh the importance of each token relative to others in a sequence. This allows the model to capture dependencies between words, even if they are far apart in the text. One of the core operations behind this is the scaled dot product, a fundamental operation from linear algebra that measures the similarity between vectors.

Scaled Dot Product

In the attention mechanism, we compute the dot product between the query and key vectors for each token. The result is then scaled by dividing by the square root of the dimension of these vectors, ensuring that the values do not become too large and affect the softmax function. This process helps the model to attend to relevant tokens while ignoring irrelevant ones, allowing it to focus on meaningful relationships within the data.

Attention(Q,K,V) = softmax\left(\frac{QK^T}{\sqrt{d}}\right)V

Time \space Complexity = O(N^2d) \qquad Space \space Complexity = O(N^2+Nd)
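To make this concrete, here is a minimal PyTorch sketch of scaled dot-product attention (the function name and tensor shapes are my own choices, not from any particular library):

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    """q, k, v: (batch, N, d). Returns (batch, N, d)."""
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)  # QK^T / sqrt(d), shape (batch, N, N)
    weights = F.softmax(scores, dim=-1)              # row-wise softmax over the keys
    return weights @ v                               # weighted sum of the values
```

The (batch, N, N) score matrix is exactly what gives the quadratic time and space cost above.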

MultiHead Attention

In MultiHead Attention, the model splits the input into multiple attention heads. Each head computes attention independently, allowing the model to focus on different parts of the input simultaneously. The results are then combined, helping the model capture diverse patterns and relationships within the data.

MultiHead(Q,K,V) = Concat(head_1, head_2, ..., head_h)W^O

head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
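A minimal PyTorch sketch of the idea, assuming d_model divides evenly into n_heads (the class and parameter names are mine, not from a specific library):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    # d_model is split across n_heads; each head runs scaled dot-product
    # attention independently, then the heads are concatenated and passed
    # through an output projection W_o.
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v):
        b, n, _ = q.shape
        split = lambda x: x.view(b, n, self.n_heads, self.d_head).transpose(1, 2)
        q, k, v = split(self.w_q(q)), split(self.w_k(k)), split(self.w_v(v))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)  # per-head QK^T / sqrt(d_head)
        out = F.softmax(scores, dim=-1) @ v                        # (b, heads, n, d_head)
        out = out.transpose(1, 2).reshape(b, n, -1)                # concat the heads
        return self.w_o(out)
```

Each head only sees d_model / n_heads channels, so the total cost is roughly that of single-head attention while letting the heads specialize.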

Linear Attention

Linear Attention improves the efficiency of traditional attention mechanisms by reducing the computational complexity. Instead of computing attention for all pairs of tokens (which scales quadratically with sequence length), Linear Attention approximates the process, allowing it to scale linearly. This is done by breaking down the attention calculation into kernel functions, which are much faster to compute.

We start from the self-attention formula:

Attention(Q,K,V) = softmax\left(\frac{QK^T}{\sqrt{d}}\right)V \quad (1)

In this equation the softmax function is applied row-wise to QK^T. We can rewrite a generalized attention equation, where subscripting a matrix with i returns its i-th row:

V'_i = \frac{\sum_{j=1}^N sim(Q_i,K_j)V_j}{\sum_{j=1}^N sim(Q_i,K_j)} \quad (2)

Equation 1 is equivalent to equation 2 if we substitute the similarity function with:

sim(q,k) = exp\left(\frac{q^Tk}{\sqrt{d}}\right)

Now let’s start building the linearized attention. For equation 2 to define an attention function, we need to impose that the sim() function is non-negative. This includes all kernels:

k(x,y): \R^{2×F} \to \R_+

Given a kernel with a feature representation Φ(x), we can rewrite equation 2:

V'_i = \frac{\sum_{j=1}^N Φ(Q_i)^T Φ(K_j)V_j}{\sum_{j=1}^N Φ(Q_i)^T Φ(K_j)} \quad (3)

Using the associative property of matrix multiplication we get:

V'_i = \frac{Φ(Q_i)^T\sum_{j=1}^N Φ(K_j)V_j^T}{Φ(Q_i)^T\sum_{j=1}^N Φ(K_j)} \quad (4)

In matrix form this is:

(Φ(Q)Φ(K)^T)V = Φ(Q)(Φ(K)^TV) \quad (5)

where Φ() is applied row-wise to the Q and K matrices.
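As a sketch, here is non-causal linear attention using Φ(x) = elu(x) + 1 as the non-negative feature map (this choice of feature map and the small eps for numerical stability are assumptions of the example, not the only option):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """q, k: (batch, N, d); v: (batch, N, d_v). Non-causal linearized attention."""
    phi_q = F.elu(q) + 1                              # Φ applied row-wise to Q
    phi_k = F.elu(k) + 1                              # Φ applied row-wise to K
    kv = torch.einsum('bnd,bne->bde', phi_k, v)       # Σ_j Φ(K_j) V_j^T, shape (batch, d, d_v)
    z = phi_k.sum(dim=1)                              # Σ_j Φ(K_j), shape (batch, d)
    num = torch.einsum('bnd,bde->bne', phi_q, kv)     # Φ(Q_i)^T (Σ_j Φ(K_j) V_j^T)
    den = torch.einsum('bnd,bd->bn', phi_q, z) + eps  # Φ(Q_i)^T Σ_j Φ(K_j)
    return num / den.unsqueeze(-1)
```

Because Φ(K)^T V is computed once and reused for every query, the cost is linear in N instead of quadratic.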

Causal Masking for Linearized Attention

The transformer architecture can be used to efficiently train autoregressive models by masking the attention computation so that position i can only be influenced by position j if j ≤ i. This gives the formula:

V'_i = \frac{Φ(Q_i)^T\sum_{j=1}^i Φ(K_j)V_j^T}{Φ(Q_i)^T\sum_{j=1}^i Φ(K_j)} \quad (6)

We now introduce S and Z as follows:

S_i = \sum_{j=1}^i Φ(K_j)V_j^T \qquad Z_i = \sum_{j=1}^i Φ(K_j)

V'_i = \frac{Φ(Q_i)^TS_i}{Φ(Q_i)^TZ_i}

Since we can compute S_i from S_{i-1} and Z_i from Z_{i-1} in constant time, linear transformers with causal masking scale linearly with respect to the sequence length.

Time \space Complexity = O(Nd^2) \qquad Space \space Complexity = O(Nd+d^2)
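A sketch of this causal, recurrent form, keeping running sums S and Z and again assuming Φ(x) = elu(x) + 1 (the Python loop is written for clarity, not speed):

```python
import torch
import torch.nn.functional as F

def causal_linear_attention(q, k, v, eps=1e-6):
    """q, k: (batch, N, d); v: (batch, N, d_v). Recurrent linearized attention."""
    b, n, d = q.shape
    d_v = v.size(-1)
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
    S = q.new_zeros(b, d, d_v)  # running Σ_{j<=i} Φ(K_j) V_j^T
    Z = q.new_zeros(b, d)       # running Σ_{j<=i} Φ(K_j)
    out = []
    for i in range(n):
        S = S + phi_k[:, i].unsqueeze(-1) * v[:, i].unsqueeze(1)  # rank-1 update
        Z = Z + phi_k[:, i]
        num = torch.einsum('bd,bde->be', phi_q[:, i], S)          # Φ(Q_i)^T S_i
        den = torch.einsum('bd,bd->b', phi_q[:, i], Z) + eps      # Φ(Q_i)^T Z_i
        out.append(num / den.unsqueeze(-1))
    return torch.stack(out, dim=1)
```

Each step updates S and Z in O(d·d_v) time, which is where the linear-in-N complexity comes from.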

Attention Free Transformer (AFT)

We now define the Attention Free Transformer (AFT), a plug-in replacement for multi-head attention (MHA) that does not require changing any other architectural aspects of the Transformer.

Y = f(X)

Y_t = σ_q(Q_t)⊙\frac{\sum_{i=1}^T exp(K_i + w_{t,i})⊙V_i}{\sum_{i=1}^T exp(K_i + w_{t,i})}

NOTE: I changed the t' from the original paper to i for better readability.

For each target position t, AFT performs a weighted average of the values, and the result is combined with the query through element-wise multiplication. As we can see, there is a new variable in the game: the matrix W, a learned pair-wise position bias.

AFT-Full - the basic form given by the equation above.

Time \space Complexity = O(T^2d) \qquad Space \space Complexity = O(Td)
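Here is a small, numerically naive sketch of AFT-Full for a single sequence (the names and shapes are my own; a real implementation would subtract a max before the exp for stability):

```python
import torch

def aft_full(q, k, v, w):
    """q, k, v: (T, d); w: (T, T) learned pairwise position bias."""
    weights = torch.exp(k.unsqueeze(0) + w.unsqueeze(-1))  # (T, T, d): exp(K_i + w_{t,i})
    num = (weights * v.unsqueeze(0)).sum(dim=1)            # Σ_i exp(K_i + w_{t,i}) ⊙ V_i
    den = weights.sum(dim=1)                               # Σ_i exp(K_i + w_{t,i})
    return torch.sigmoid(q) * (num / den)                  # gate each position by σ(Q_t)
```

The (T, T, d) weight tensor is what makes AFT-Full quadratic in the sequence length.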

AFT-Local - The relative attention maps of trained Transformers show strong local patterns, especially in the lower layers. This motivates a variant of AFT, AFT-Local, where the learned set of relative position biases is applied only locally:

w_{t,i} = w_{t,i} for |t-i| < s, else 0, where s ≤ T is a local window size (T is the sequence length).

Time \space Complexity = O(Tsd) \qquad Space \space Complexity = O(Td)
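A sketch of how the local window could be applied to the pairwise bias before reusing the aft_full sketch above (the helper name is hypothetical):

```python
import torch

def local_position_bias(w, s):
    """w: (T, T) learned pairwise bias; s: window size. Zero w_{t,i} when |t - i| >= s."""
    T = w.size(0)
    t = torch.arange(T)
    mask = (t.unsqueeze(1) - t.unsqueeze(0)).abs() < s  # True inside the local window
    return w * mask
```

Outside the window the bias is 0 (not -inf), so those positions still contribute exp(K_i) to the weighted average.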

AFT-Simple - An extreme form of AFT-Local with s = 0, so no position bias is learned. This gives rise to an extremely simple version of AFT:

Y_t = σ_q(Q_t)⊙ \sum_{i=1}^T (softmax(K)⊙V)_i

AFT-Simple gets rid of the need for dot-product operations entirely.

Time \space Complexity = O(Td) \qquad Space \space Complexity = O(Td)
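AFT-Simple is short enough to sketch in a few lines (single sequence, names are mine):

```python
import torch

def aft_simple(q, k, v):
    """q, k, v: (T, d). AFT with no position bias."""
    weights = torch.softmax(k, dim=0)     # softmax(K) over the sequence positions
    context = (weights * v).sum(dim=0)    # Σ_i softmax(K)_i ⊙ V_i, shape (d,)
    return torch.sigmoid(q) * context     # gate each Y_t by σ(Q_t)
```

Note that the context vector is the same for every t; only the gate σ(Q_t) changes per position, which is why the cost drops to O(Td).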

RWKV v1

Beginning NOTE: I will go deeper into the RWKV model in this section, as it is important to understand the time-mixing and channel-mixing blocks in order to get a strong understanding of the mechanism used in this architecture. More details can be found in the official paper :D

RWKV is a type of recurrent neural network (RNN) that aims to combine the efficiency of RNNs with the performance of Transformer models. It’s designed to be highly scalable, memory-efficient, and capable of handling long sequences while avoiding some of the drawbacks of traditional RNNs, such as vanishing gradients.

RWKV is defined by 4 fundamental elements:

R - Receptance vector - the receiver of past information
W - Weight vector - positional weight decay
K - Key vector - analogous to K in traditional attention
V - Value vector - analogous to V in traditional attention

Architecture

The model is composed of stacked residual blocks. Each block is formed with a time-mixing and a channel-mixing sub-block.

The linear projection vectors involved in the computation are produced by linear interpolation between the current and previous timestep inputs.

Time-Mixing:

r_t = W_r (μ_r⊙x_t + (1-μ_r)⊙x_{t-1})

k_t = W_k (μ_k⊙x_t + (1-μ_k)⊙x_{t-1})

v_t = W_v (μ_v⊙x_t + (1-μ_v)⊙x_{t-1})
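A sketch of these time-mixing projections for a single timestep (the parameter names mirror the formulas; the channel-mixing projections below follow the same pattern with the primed parameters):

```python
import torch

def time_mix_projections(x, x_prev, mu_r, mu_k, mu_v, W_r, W_k, W_v):
    """x, x_prev: (d,) current/previous inputs; mu_*: (d,) mixing coefficients; W_*: (d, d)."""
    r = W_r @ (mu_r * x + (1 - mu_r) * x_prev)  # receptance projection
    k = W_k @ (mu_k * x + (1 - mu_k) * x_prev)  # key projection
    v = W_v @ (mu_v * x + (1 - mu_v) * x_prev)  # value projection
    return r, k, v
```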

Channel-Mixing:

r'_t = W'_r (μ'_r⊙x_t + (1-μ'_r)⊙x_{t-1})

k'_t = W'_k (μ'_k⊙x_t + (1-μ'_k)⊙x_{t-1})

WKV Operator

The WKV (Weighted Key-Value) operator parallels the method used in the Attention Free Transformer (AFT). However, whereas in AFT W is a pairwise matrix, the RWKV architecture treats it as a channel-wise vector that is modified by relative position.

wkv_t = \frac{\sum_{i=1}^{t-1} exp(-(t-1-i)w+k_i)⊙v_i + exp(u+k_t)⊙v_t}{\sum_{i=1}^{t-1} exp(-(t-1-i)w+k_i) + exp(u+k_t)}
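A direct, numerically naive sketch of the wkv_t formula over a whole sequence (0-indexed; a real implementation uses a recurrent, numerically stabilized form):

```python
import torch

def wkv(w, u, k, v):
    """k, v: (T, d); w, u: (d,) channel-wise decay and current-token bonus."""
    T, d = k.shape
    out = torch.zeros_like(v)
    for t in range(T):
        if t > 0:
            # past terms: exp(-(t-1-i) * w + k_i) for i = 0 .. t-1
            decay = -(t - 1 - torch.arange(t)).unsqueeze(-1) * w  # (t, d)
            past = torch.exp(decay + k[:t])                       # (t, d)
            num, den = (past * v[:t]).sum(dim=0), past.sum(dim=0)
        else:
            num, den = torch.zeros_like(k[0]), torch.zeros_like(k[0])
        cur = torch.exp(u + k[t])                                 # current-token term exp(u + k_t)
        out[t] = (num + cur * v[t]) / (den + cur)
    return out
```

Keeping running numerator and denominator states instead of recomputing the sums is what gives the constant per-step cost quoted at the end of this section.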

An output gate is implemented in both the time-mixing and channel-mixing blocks using the sigmoid of the receptance, σ(r).

o_t = W_o(σ(r_t)⊙wkv_t)

o'_t = σ(r'_t)⊙(W'_o \max(k'_t,0)^2)
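And a sketch of the two output gates (names again mirror the formulas; W'_o is the channel-mixing output projection):

```python
import torch

def time_mixing_output(r, wkv_t, W_o):
    """o_t = W_o (σ(r_t) ⊙ wkv_t)"""
    return W_o @ (torch.sigmoid(r) * wkv_t)

def channel_mixing_output(r_prime, k_prime, W_o_prime):
    """o'_t = σ(r'_t) ⊙ (W'_o max(k'_t, 0)^2) -- squared ReLU on the key"""
    return torch.sigmoid(r_prime) * (W_o_prime @ (torch.relu(k_prime) ** 2))
```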

And now let’s see the complexity :D

Time \space Complexity = O(Td) Space \space Complexity = O(d)

Note: this blog is still under construction, please be patient, the code implementation will come soooooon :D