Attention is a key component of modern LLMs because it enables the model to weigh the importance of each token relative to others in a sequence. This allows the model to capture dependencies between words, even if they are far apart in the text. One of the core operations behind this is the scaled dot product, a fundamental operation from linear algebra that measures the similarity between vectors.
In the attention mechanism, we compute the dot product between the query and key vectors for each token. The result is then scaled by dividing by the square root of the dimension of these vectors, ensuring that the values do not become too large and affect the softmax function. This process helps the model to attend to relevant tokens while ignoring irrelevant ones, allowing it to focus on meaningful relationships within the data.
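To make this concrete, here is a minimal sketch of scaled dot-product attention in PyTorch. The function name and shapes are my own choices for illustration, not taken from any particular library:

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (seq_len, d_k) -- one attention "view" of the sequence
    d_k = Q.size(-1)
    # Similarity between every query and every key, scaled by sqrt(d_k)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (seq_len, seq_len)
    # Softmax over the key dimension turns scores into attention weights
    weights = F.softmax(scores, dim=-1)
    # Each output row is a weighted average of the value vectors
    return weights @ V                               # (seq_len, d_k)

# Toy usage: 5 tokens, 16-dimensional queries/keys/values
x = torch.randn(5, 16)
print(scaled_dot_product_attention(x, x, x).shape)  # torch.Size([5, 16])
```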
In MultiHead Attention, the model splits the input into multiple attention heads. Each head computes attention independently, allowing the model to focus on different parts of the input simultaneously. The results are then combined, helping the model capture diverse patterns and relationships within the data.
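And a rough sketch of how the split into heads might look, assuming a single unbatched sequence; the class name and dimensions are arbitrary choices of mine:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Rough sketch: project, split into heads, attend per head, recombine."""
    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):                      # x: (seq_len, d_model)
        T, _ = x.shape
        # Project, then reshape so each head gets its own d_head-sized slice
        def split(t):
            return t.view(T, self.n_heads, self.d_head).transpose(0, 1)  # (heads, T, d_head)
        q, k, v = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x))
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5            # (heads, T, T)
        weights = F.softmax(scores, dim=-1)
        heads = weights @ v                                              # (heads, T, d_head)
        # Concatenate the heads back together and mix them with a final projection
        merged = heads.transpose(0, 1).reshape(T, -1)                    # (T, d_model)
        return self.out_proj(merged)

mha = MultiHeadAttention()
print(mha(torch.randn(10, 64)).shape)  # torch.Size([10, 64])
```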
Linear Attention improves the efficiency of traditional attention mechanisms by reducing the computational complexity. Instead of computing attention for all pairs of tokens (which scales quadratically with the sequence length), Linear Attention approximates the process so that it scales linearly. This is done by expressing the similarity between queries and keys through a kernel feature map, which lets us reorder the matrix products and avoid ever forming the quadratic attention matrix.
We start from the self-attention formula:

V′ = softmax(QKᵗ / √D) · V (equation 1)
In this equation the softmax function is applied row-wise to QKᵗ. We can rewrite a generalized attention equation, where subscripting a matrix with i returns its i-th row as a vector:

V′ᵢ = Σⱼ sim(Qᵢ, Kⱼ) · Vⱼ / Σⱼ sim(Qᵢ, Kⱼ) (equation 2)

Equation 1 is equivalent to equation 2 if we substitute the similarity function with:

sim(q, k) = exp(qᵗk / √D)
Now let’s start building the Linearized Attention. The only constraint we need to impose on the sim() function, in order for equation 2 to define an attention function, is that it be non-negative. This includes all kernels k(x, y) that map pairs of vectors to non-negative values.
Given a kernel with a feature representation Φ(x), we can rewrite equation 2:

V′ᵢ = Σⱼ Φ(Qᵢ)ᵗ Φ(Kⱼ) · Vⱼ / Σⱼ Φ(Qᵢ)ᵗ Φ(Kⱼ)
Using the associative property of matrix multiplication we get:

V′ᵢ = Φ(Qᵢ)ᵗ Σⱼ Φ(Kⱼ) Vⱼᵗ / Φ(Qᵢ)ᵗ Σⱼ Φ(Kⱼ)

We can rewrite it in matrix form like this:

(Φ(Q) Φ(K)ᵗ) V = Φ(Q) (Φ(K)ᵗ V)

where we applied Φ() row-wise to the Q and K matrices. The key point is that the right-hand side computes Φ(K)ᵗV first, so the cost grows linearly instead of quadratically with the sequence length.
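As an illustration, here is a small sketch of this non-causal linearized attention, assuming the feature map Φ(x) = elu(x) + 1 proposed in the Linear Transformers paper (function and variable names are mine):

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    # phi(x) = elu(x) + 1 keeps all features positive, as required for sim() >= 0
    return F.elu(x) + 1

def linear_attention(Q, K, V):
    # Q, K: (T, d_k), V: (T, d_v)
    Qp, Kp = feature_map(Q), feature_map(K)
    # Associativity: compute phi(K)^T V once -- (d_k, d_v) -- instead of a T x T matrix
    KV = Kp.transpose(0, 1) @ V            # (d_k, d_v)
    Z = Kp.sum(dim=0)                      # (d_k,) normalizer terms
    numerator = Qp @ KV                    # (T, d_v)
    denominator = (Qp @ Z).unsqueeze(-1)   # (T, 1)
    return numerator / denominator

Q, K, V = (torch.randn(8, 16) for _ in range(3))
print(linear_attention(Q, K, V).shape)  # torch.Size([8, 16])
```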
The transformer architecture can be used to efficiently train autoregressive models by masking the attention computation such that the i-th position can only be influenced by position j if j ≤ i. With this causal masking we get:

V′ᵢ = Φ(Qᵢ)ᵗ Σⱼ₌₁ⁱ Φ(Kⱼ) Vⱼᵗ / Φ(Qᵢ)ᵗ Σⱼ₌₁ⁱ Φ(Kⱼ)

We now introduce S and Z as follows:

Sᵢ = Σⱼ₌₁ⁱ Φ(Kⱼ) Vⱼᵗ
Zᵢ = Σⱼ₌₁ⁱ Φ(Kⱼ)

so that V′ᵢ = Φ(Qᵢ)ᵗ Sᵢ / Φ(Qᵢ)ᵗ Zᵢ. We can compute Sᵢ from Sᵢ₋₁ and Zᵢ from Zᵢ₋₁ in constant time, hence making the computational complexity of linear transformers with causal masking linear with respect to the sequence length.
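Before moving on, here is a minimal sketch of this causal, recurrent form with the running sums Sᵢ and Zᵢ, again using the same assumed feature map; it is written as an explicit loop for clarity, not for speed:

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    return F.elu(x) + 1  # same non-negative feature map as before

def causal_linear_attention(Q, K, V):
    # Q, K: (T, d_k), V: (T, d_v) -- process the sequence step by step
    T, d_k = Q.shape
    d_v = V.shape[-1]
    Qp, Kp = feature_map(Q), feature_map(K)
    S = torch.zeros(d_k, d_v)   # S_i = sum_{j<=i} phi(K_j) V_j^T
    Z = torch.zeros(d_k)        # Z_i = sum_{j<=i} phi(K_j)
    outputs = []
    for i in range(T):
        # Constant-time update of the running sums
        S = S + Kp[i].unsqueeze(-1) @ V[i].unsqueeze(0)
        Z = Z + Kp[i]
        outputs.append((Qp[i] @ S) / (Qp[i] @ Z))
    return torch.stack(outputs)  # (T, d_v)

Q, K, V = (torch.randn(6, 8) for _ in range(3))
print(causal_linear_attention(Q, K, V).shape)  # torch.Size([6, 8])
```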
We will now define the Attention Free Transformer (AFT), which is a plug-in replacement for Multi-Head Attention (MHA) without the need to change other architectural aspects of Transformers.

*NOTE: I changed the t′ from the original paper to i for better understanding.*

For each target position t, AFT performs a weighted average of the values, and the result is combined with the query through element-wise multiplication. As we can see, we have a new variable in the game: the W matrix, which is a learned pair-wise position bias.
AFT-Full - the basic version of AFT, defined by the following equation:

Yₜ = σ(Qₜ) ⊙ Σᵢ₌₁ᵀ exp(Kᵢ + wₜ,ᵢ) ⊙ Vᵢ / Σᵢ₌₁ᵀ exp(Kᵢ + wₜ,ᵢ)
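A small sketch of how I read the AFT-Full equation above; the learned pairwise bias w is just a placeholder tensor here:

```python
import torch

def aft_full(Q, K, V, w):
    # Q, K, V: (T, d); w: (T, T) learned pairwise position bias
    # exp(K_i + w_{t,i}) gives, for each target position t, a per-channel
    # weight for every source position i -- no dot products involved
    weights = torch.exp(K.unsqueeze(0) + w.unsqueeze(-1))       # (T, T, d)
    numerator = (weights * V.unsqueeze(0)).sum(dim=1)           # (T, d)
    denominator = weights.sum(dim=1)                            # (T, d)
    return torch.sigmoid(Q) * numerator / denominator           # (T, d)

T, d = 6, 16
Q, K, V = (torch.randn(T, d) for _ in range(3))
w = torch.zeros(T, T)  # stand-in for the learned pairwise bias
print(aft_full(Q, K, V, w).shape)  # torch.Size([6, 16])
```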
AFT-Local - We see that the learned relative attention maps demonstrate strong local patterns, especially in the lower layers. This motivates a variant of AFT, AFT-Local, where we apply the learned set of relative position biases only locally:

wₜ,ᵢ = wₜ,ᵢ if |t − i| < s, else 0

Here s ≤ T is a local window size (T is the sequence length).

AFT-Simple - An extreme form of AFT-Local when s = 0, so no position bias is learned. This gives rise to an extremely simple version of AFT:

Yₜ = σ(Qₜ) ⊙ Σᵢ₌₁ᵀ (softmax(K) ⊙ V)ᵢ
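And a sketch of AFT-Simple as I read it: a softmax over the keys along the sequence dimension, one global weighted sum of the values, then the query gate:

```python
import torch

def aft_simple(Q, K, V):
    # Q, K, V: (T, d). With s = 0 there is no position bias at all:
    # softmax over the sequence dimension, then one global weighted sum of values
    weights = torch.softmax(K, dim=0)            # (T, d), per-channel weights over positions
    context = (weights * V).sum(dim=0)           # (d,), the same context vector for every t
    return torch.sigmoid(Q) * context            # (T, d), modulated per position by the query

Q, K, V = (torch.randn(6, 16) for _ in range(3))
print(aft_simple(Q, K, V).shape)  # torch.Size([6, 16])
```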
AFT-Simple gets rid of the need for dot-product operations.

NOTE: I will go deeper into the RWKV model in this section, as it is important to understand the time-mixing and the channel-mixing blocks in order to get a strong understanding of the mechanism used in this architecture. More details can be found in the official paper :D
RWKV is a type of recurrent neural network (RNN) that aims to combine the efficiency of RNNs with the performance of Transformer models. It’s designed to be highly scalable, memory-efficient, and capable of handling long sequences while avoiding some of the drawbacks of traditional RNNs, such as vanishing gradients.
The model is composed of stacked residual blocks. Each block is formed with a time-mixing and a channel-mixing sub-block.
The linear projection vectors involved in the computation (the receptance r, key k and value v) are produced by linear interpolation between the current and the previous timestep inputs.
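A toy sketch of that interpolation (the "token shift"); in the real model the mixing coefficients μ and the projection W are learned parameters, here they are just stand-in tensors:

```python
import torch

def token_shift_projection(x, W, mu):
    # x: (T, d) inputs; W: (d, d) projection; mu: (d,) mixing coefficients
    # x_prev is the sequence shifted by one step (zeros for the first token)
    x_prev = torch.cat([torch.zeros(1, x.shape[-1]), x[:-1]], dim=0)
    # Linear interpolation between the current and previous token, then project
    mixed = mu * x + (1 - mu) * x_prev
    return mixed @ W

T, d = 5, 8
x = torch.randn(T, d)
r = token_shift_projection(x, torch.randn(d, d), torch.full((d,), 0.5))
print(r.shape)  # torch.Size([5, 8])
```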
The WKV operator (Weighted Key-Value operator) plays a role parallel to the attention method used in the Attention Free Transformer (AFT). However, where in AFT W is a pairwise position bias matrix, the RWKV architecture treats it as a channel-wise vector that is modified by the relative position.
An output gate is implemented in both the time-mixing and the channel-mixing blocks using the sigmoid of the receptance, σ(r).
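Putting the pieces together, here is a deliberately naive sketch of the WKV recurrence followed by the σ(r) gate, as I understand it from the paper; w and u stand for the learned per-channel decay and the current-token bonus, and I skip the numerical-stability tricks a real implementation needs:

```python
import torch

def wkv_sequential(k, v, w, u):
    # k, v: (T, d); w: (d,) per-channel decay; u: (d,) bonus for the current token
    T, d = k.shape
    num = torch.zeros(d)   # running sum of e^{k_i} * v_i, decayed by e^{-w} each step
    den = torch.zeros(d)   # running sum of e^{k_i}, decayed the same way
    outputs = []
    for t in range(T):
        # The current token gets the learned bonus u instead of the decay
        cur = torch.exp(u + k[t])
        outputs.append((num + cur * v[t]) / (den + cur))
        # Decay the history by one step and fold in the current key/value
        num = torch.exp(-w) * num + torch.exp(k[t]) * v[t]
        den = torch.exp(-w) * den + torch.exp(k[t])
    return torch.stack(outputs)    # (T, d)

T, d = 5, 8
k, v, r = (torch.randn(T, d) for _ in range(3))
w, u = torch.rand(d), torch.randn(d)
wkv = wkv_sequential(k, v, w, u)
out = torch.sigmoid(r) * wkv       # receptance gate sigma(r) controls what gets through
print(out.shape)  # torch.Size([5, 8])
```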
And now let’s see the complexity :D RWKV scales linearly with the sequence length, O(Td) in time, whereas standard self-attention scales quadratically, O(T²d).
Note: this blog is still under construction, please be patient, the code implementation will come soooooon :D