RWKV Models: Exploring v5.2, v6, and v7 with Code Snippets

In this blog I will discuss the RWKV models (v5.2, v6, and v7). I will use code snippets from https://github.com/SmerkyG/RWKV_Explained

We will build the model starting from the main class and introduce the other classes as we encounter them.

RWKV v5.2

RWKV class

# Main class of the model
class RWKV(torch.nn.Module):
    def __init__(self, cfg:Config = Config()):
        super().__init__()
        self.cfg = cfg # configuration
        self.embed = nn.Parameter(torch.empty(cfg.vocab_size, cfg.d_model))
        self.embed_norm = nn.LayerNorm(cfg.d_model)
        # the embedding is a matrix of size [vocab_size, d_model]
        # (equivalent to a linear layer applied to one-hot token vectors, i.e. a lookup table)
        # its output is then normalized

        self.layers = nn.ModuleList([Layer(cfg, layer_id) for layer_id in range(cfg.n_layers)])
        # creates a ModuleList of Layers (the Layer class is defined below)
        # each Layer also receives its layer_id

        self.lm_head_norm = nn.LayerNorm(cfg.d_model)
        # normalizes the final hidden states before the language-model head
        self.lm_head_unembed = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # then a linear layer projects them to logits over the whole vocabulary

    # input tensor dimensions:
    #   x (B,T) -> [batch_size, sequence_length]
    #   example: [2,20] -> a batch of 2 sequences, each 20 tokens long
    def forward(self, x : Tensor, s : list[LayerState]|None = None):
        # calculate embeddings for each incoming token, then normalize them
        # see https://github.com/BlinkDL/SmallInitEmb for details on why we do this normalization
        # if you look at some literature on pre-generated embeddings, you'll see that they are 
        #  often ideally considered to become unit length vectors around the hypersphere 
        #  so starting it as small noise while using a normalized version instantly causes this layout, 
        #  allowing them to initially move rapidly around the surface of that hypersphere, and later slowly
       
        x = self.embed_norm(nn.functional.embedding(x, self.embed))
        # Embedding the data

        s = s or [LayerState(x, self.cfg) for _ in range(self.cfg.n_layers)]
        # if no state was passed in, create a fresh LayerState for each of the n_layers layers
        # we will look at the LayerState class next to analyze it

        # run each layer in succession, passing in the RNN state for that layer
        for layer_id, block in enumerate(self.layers):  # run each rwkv block
            x, s[layer_id] = block(x, s[layer_id]) 

        # normalize the output
        x = self.lm_head_norm(x)

        # unembed back to dictionary indices
        x = self.lm_head_unembed(x)

        return x, s
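
Before moving on, here is a rough usage sketch of this forward pass. It is not from the repo: it assumes Config can be constructed with the fields referenced above (vocab_size, d_model, n_layers, n_heads, plus a default for d_ffn), and note that the snippets shown here create parameters with torch.empty and omit the initialization code, so the outputs are meaningless until the weights are properly initialized.

# hedged usage sketch (hypothetical Config values, weights assumed initialized)
cfg = Config(vocab_size=1000, d_model=128, n_layers=2, n_heads=4)
model = RWKV(cfg)

tokens = torch.randint(0, cfg.vocab_size, (2, 20))  # (B=2, T=20) token ids
logits, state = model(tokens)                       # logits: (2, 20, 1000), state: list of LayerState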

Now let's first look at the LayerState class to understand the structure of our model and how the data is processed compared to a "GPT"-style transformer. Keep in mind that this architecture is, in essence, an RNN (that framing may help you understand it better).

LayerState

class LayerState:
    # the recurrent neural network (RNN) state for a layer of RWKV5.2 
    def __init__(self, x, cfg:Config):
        # input tensor dimensions:
        # x (B,T,C) -> the already-embedded input
        # C = d_model -> embedding dimension of the model (think of the size of the Q/K/V matrices in GPT :D)
        # H = n_heads -> number of 'attention' heads
        # K = d_model//n_heads -> the embedding size of each head :D
        B, T, C, H, K = x.size(0), x.size(1), cfg.d_model, cfg.n_heads, cfg.d_model // cfg.n_heads
        V = K

        # a (B,C) size tensor representing latest time mixer token embedding processed
        # responsible for capturing relationship between tokens
        self.time_mixer_x_state = torch.zeros(B,C,dtype=x.dtype,device=x.device)

        # an (B,H,K,V) size tensor representing a decaying token embedding memory for each head, where H=number_of_heads, K=key_dim_per_head, V=value_dim_per_head 
        self.kv_state = torch.zeros(B,H,K,V,dtype=torch.float32,device=x.device)
        
        # a (B,C) size tensor representing latest channel mixer token embedding processed
        self.channel_mixer_x_state = torch.zeros(B,C,dtype=x.dtype,device=x.device)
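
As a quick illustration of these state shapes (not from the repo), we can pass a stand-in object that carries only the two config fields LayerState actually reads:

import torch
from types import SimpleNamespace

# stand-in for Config with just the fields LayerState reads (hypothetical sizes)
cfg = SimpleNamespace(d_model=64, n_heads=4)
x = torch.zeros(2, 7, 64)                    # (B=2, T=7, C=64) embedded input
s = LayerState(x, cfg)
print(s.time_mixer_x_state.shape)            # torch.Size([2, 64])
print(s.kv_state.shape)                      # torch.Size([2, 4, 16, 16])
print(s.channel_mixer_x_state.shape)         # torch.Size([2, 64])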

Now let's go deeper into the model.

Layer class

This class applies time mixing and then channel mixing. ~ add picture ~

class Layer(nn.Module):
    def __init__(self, cfg:Config, layer_id:int):
        super().__init__()
        self.time_mixer = TimeMixer(cfg, layer_id) # pass the configuration + layer_id
        self.channel_mixer = ChannelMixer(cfg, layer_id) # same as above

    def forward(self, x : Tensor, s : LayerState):
        # x has shape [batch_size, seq_len, d_model]
        x, s.time_mixer_x_state, s.kv_state = self.time_mixer(x, s.time_mixer_x_state, s.kv_state)
        # this updates x, the time-mixer x state, and the kv state
        # (we will look at these below)

        x, s.channel_mixer_x_state = self.channel_mixer(x, s.channel_mixer_x_state)
        # We update our x and channel-mixer-state
        
        # return x and the state
        return x, s

Let’s start with the time_mixer class

Time-Mixer class

The time mixer implements the following equations, where □ stands for each of r, k, v, g:

□_t = lerp_□(x_t, x_{t-1}) · W_□,  □ ∈ {r, k, v, g}

Here lerp is a learned linear interpolation, lerp_□(a, b) = a + (b - a) ⊙ μ_□, with μ_□ ∈ R^D a learnable parameter.

w = exp(-exp(ω))

This squashes the learnable decay ω so that the resulting values sit in (0, 1).

wkv_t = diag(u) · k_t^T · v_t + ∑_{i=1}^{t-1} diag(w)^{t-1-i} · k_i^T · v_i

Say diag(u) is a 3x3 diagonal matrix and A is a full 3x3 matrix. Then diag(u) · A is A with its rows scaled by [u1, u2, u3] (row-wise scaling). diag(w) is a multiplicative decay, so the further back a timestep lies, the less it contributes to the sum. The first term is the contribution of the current timestep (boosted by u), and the summation is the contribution of the previous timesteps with the time-decay factor applied.

o_t = concat(SiLU(g_t) ⊙ LayerNorm(r_t · wkv_t)) · W_o

Starting from the inside: the product r_t · wkv_t can be viewed as a gate that decides how much of the accumulated "attention" state is passed on to the next layer. We normalize it, multiply element-wise by SiLU(g_t), concatenate the heads, and apply the output projection W_o.
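
To make the wkv recurrence concrete, here is a minimal single-head sketch (not from the repo, with made-up sizes) that computes it exactly as written above, carrying the decayed sum of k_i^T·v_i as a running state:

import torch

T, K, V = 6, 4, 4                          # timesteps, key dim, value dim for one head
r = torch.randn(T, K)                      # receptance per timestep
k = torch.randn(T, K)                      # keys per timestep
v = torch.randn(T, V)                      # values per timestep
u = torch.randn(K)                         # per-channel bonus for the current timestep
w = torch.exp(-torch.exp(torch.randn(K)))  # per-channel decay, squashed into (0, 1)

state = torch.zeros(K, V)                  # running sum of decayed k_i^T v_i
outs = []
for t in range(T):
    kv = k[t].unsqueeze(1) @ v[t].unsqueeze(0)  # (K,1) @ (1,V) = (K,V)
    wkv = state + u.unsqueeze(1) * kv           # diag(u)·k_t^T·v_t plus the decayed past
    outs.append(r[t].unsqueeze(0) @ wkv)        # (1,K) @ (K,V) = (1,V)
    state = w.unsqueeze(1) * state + kv         # decay the past, then add the current k_t^T·v_t
out = torch.cat(outs)                           # (T, V)
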
class TimeMixer(nn.Module):
    def __init__(self, cfg:Config, layer_id:int):
        super().__init__()
        self.cfg = cfg # saving the config

        d_model = cfg.d_model # dimension of the embedding
        d_head = d_model // cfg.n_heads # dimension of each head

        self.prenorm = nn.LayerNorm(d_model) # Pre Normalisation

        self.tokenshift_receptance = nn.Parameter(torch.empty(1, 1, d_model)) # [1,1,d_model]
        self.tokenshift_key = nn.Parameter(torch.empty(1, 1, d_model)) # [1,1,d_model]
        self.tokenshift_value = nn.Parameter(torch.empty(1, 1, d_model)) # [1,1,d_model]
        self.tokenshift_gate = nn.Parameter(torch.empty(1, 1, d_model)) # [1,1,d_model]

        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.gate = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        # the r, k, v, gate, and output projection matrices

        # per-channel boost for current embedding
        self.bonus = nn.Parameter(torch.ones(cfg.n_heads, d_head))

        # per-channel decay multipliers applied to kv_state at each timestep
        self.decay = nn.Parameter(torch.ones(cfg.n_heads, d_head))
        
        # Norm
        self.group_norm = nn.GroupNorm(cfg.n_heads, d_model, eps=64e-5)

Let's dive into the forward pass.

    def forward(self, hidden_state_in : Tensor, x_state : Tensor, kv_state : Tensor):
        # x (B,T,C), x_state (B,C), kv_state (B,H,K,V)
        x = self.prenorm(hidden_state_in) # prenorm the input
        x_state_out = x[:, -1] # last token embedding saved

        B, T, C, H, K = x.size(0), x.size(1), self.cfg.d_model, self.cfg.n_heads, self.cfg.d_model // self.cfg.n_heads

        # we want the token embeddings shifted over by one towards the past
        # to get this, we take the last token embedding processed and append all but one of the current token embeddings to it
        # (the last token embedding processed is what's stored in the x_state)
        x_shifted_one_to_the_past = torch.cat((x_state.unsqueeze(-2), x[:,:-1]), dim=1)
        # x_state.unsqueeze(-2) reshapes x_state to (batch_size, 1, d_model),
        # i.e. a single timestep,
        # which we concatenate with x[:,:-1], all current tokens except the last one

        # token shift the incoming token embeddings for the receptance, key, value, and gate
        # PLEASE NOTE THAT THE DIRECTION OF THE LERP CHANGED IN RWKV-6
        
        x_receptance = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_receptance)
        x_key = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_key)
        x_value = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_value)
        x_gate = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_gate)
        # run the linear interpolation:

        # torch.lerp(start, end, weight):
        # out[i] = start[i] + weight[i] * (end[i] - start[i])

        # So, for example, in our case:
        # x         = [1, 2, 3, 4, 5]
        # x_shifted = [x_state, 1, 2, 3, 4]
        # and we interpolate element-wise between the two
        # each self.tokenshift_* has shape [1, 1, d_model]
        # while x has shape [batch_size, tokens, d_model]
        # so the learned interpolation amount is per channel/embedding dimension
        # the same is done for receptance, key, value, and gate

                                           
        # the extra dimensions are being added here to enable matrix multiplications per timestep
        r = self.receptance(x_receptance).view(B,T,H,1,K) # BTH1K
        k = self.key(x_key).view(B,T,H,K,1) # BTHK1 
        v = self.value(x_value).view(B,T,H,1,K) # BTH1K
        gate = self.gate(x_gate) # BTC

        # this forces the decays to end up in the range 0...1 using a nicely differentiable function
        decay = torch.exp(-torch.exp(self.decay.float())) # HK

        out = torch.empty(B, T, H, K, dtype=x.dtype, device=x.device)
        # out has shape [batch_size, seq_len, n_heads, head_dim]
        for t in range(T):
            out[:,t], kv_state = TimeMixer.single_timestep(r[:,t], k[:,t], v[:,t], 
                                                           self.bonus,
                                                           decay,
                                                           kv_state)
        # We will look into the single_timestep processing function in a second
        # So we iterate with t from [0->nr_timesteps]
        # out[:,t] -> out[all_batches,specific_timestep]
        # kv_state is updated
        # single_timestep(r[:,t],k[:,t],v[:,t],self.bonus,decay,kv_state)

        # r[:,t] -> [all_batches, n_heads, 1, head_dim] (indexing removes the timestep dimension)
        # example: with B=1, H=16, head_dim=64, r[:,3] has shape [1, 16, 1, 64]


        # apply group normalization to each head and recombine the heads
        out = self.group_norm(out.view(B*T, C)).view(B, T, C) # BTC

        # apply silu gate to the output
        out = out * nn.functional.silu(gate) # BTC

        # project the output
        out = self.output(out) # BTC

        return hidden_state_in + out, x_state_out, kv_state

Now let's look at the actual single_timestep processing function.

    @staticmethod
    def single_timestep(r, k, v, u, w, kv_state): 
        original_dtype = r.dtype

        B, H, K, V = kv_state.shape
        # get the shapes

        # transform inputs from BHK and put everything in float format for higher precision
        r = r.float().view(B, H, 1, K)
        w = w.float().view(1, H, K, 1)
        k = k.float().view(B, H, K, 1)
        v = v.float().view(B, H, 1, V)
        u = u.float().view(1, H, K, 1)

        kv = k @ v # BHK1 @ BH1V = BHKV
        # matrix multiplication between keys and values
        # K: [batch,Nr_heads,Head_dim,1]
        # V: [batch,Nr_heads,1,Head_dim]

        # start with the existing kv state
        y = kv_state        # BHKV

        # apply the u boost to the current k @ v and add it to that
        y = y + kv * u # BHKV + BHKV * 1HK1 = BHKV

        # apply receptance to that whole result
        out = r @ y         # BH1K @ BHKV = BH1V
        
        # remove an extra useless dimension from the output
        out = out.squeeze(-2).to(original_dtype) # BHV

        # finally, decay the kv state and add in the latest k @ v
        kv_state = kv_state * w    # BHKV * BHK1 = BHKV
        kv_state = kv_state + kv   # BHKV + BHKV = BHKV

        return out, kv_state # BHV, BHKV
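
A quick shape check of a single call to this static method (not from the repo, made-up sizes, with the classes above in scope and torch imported):

B, H, K = 2, 4, 8                              # batches, heads, key/value dim per head
r = torch.randn(B, H, 1, K)
k = torch.randn(B, H, K, 1)
v = torch.randn(B, H, 1, K)
u = torch.randn(H, K)                          # bonus
w = torch.exp(-torch.exp(torch.randn(H, K)))   # decay in (0, 1)
kv_state = torch.zeros(B, H, K, K)
out, kv_state = TimeMixer.single_timestep(r, k, v, u, w, kv_state)
print(out.shape, kv_state.shape)               # torch.Size([2, 4, 8]) torch.Size([2, 4, 8, 8])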

Channel-Mixer class

The channel-mixing formulas are not that hard and are easy to follow: a lerp, a ReLU-squared feed-forward, and a sigmoid gate:

r'_t = lerp_{r'}(x'_t, x'_{t-1}) · W_{r'}
k'_t = lerp_{k'}(x'_t, x'_{t-1}) · W_{k'}
v'_t = ReLU(k'_t)^2 · W_{v'}
o'_t = σ(r'_t) ⊙ v'_t

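Written out for a single token, these four lines look roughly like this (a sketch with made-up sizes, not from the repo; the class below does the same thing batched over the sequence):

import torch

D, D_ffn = 8, 28                                 # hypothetical model and FFN sizes
x_t, x_prev = torch.randn(D), torch.randn(D)     # current and previous token embeddings
mu_r, mu_k = torch.rand(D), torch.rand(D)        # learned lerp amounts
W_r = torch.randn(D, D)
W_k = torch.randn(D_ffn, D)
W_v = torch.randn(D, D_ffn)

r = W_r @ torch.lerp(x_prev, x_t, mu_r)          # receptance
k = W_k @ torch.lerp(x_prev, x_t, mu_k)          # key, projected up to d_ffn
v = W_v @ torch.relu(k).square()                 # ReLU^2, then project back down to D
o = torch.sigmoid(r) * v                         # sigmoid(r) gates the output
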
class ChannelMixer(nn.Module):
    def __init__(self, cfg:Config, layer_id:int):
        super().__init__()
        self.cfg = cfg
        self.prenorm = nn.LayerNorm(cfg.d_model)
        self.tokenshift_in = nn.Parameter(torch.empty(1, 1, cfg.d_model))
        self.tokenshift_gate = nn.Parameter(torch.empty(1, 1, cfg.d_model))
        
        d_ffn = cfg.d_ffn or int(cfg.d_model * 3.5)//32*32
        # FFN hidden dimension: roughly 3.5x d_model, rounded down to a multiple of 32

        self.W_in = nn.Linear(cfg.d_model, d_ffn, bias=False)
        self.W_out = nn.Linear(d_ffn, cfg.d_model, bias=False)
        self.gate = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

    def forward(self, hidden_state_in : Tensor, x_state : Tensor): # x (B,T,C), x_state (B,C)
        x = self.prenorm(hidden_state_in)
        x_state_out = x[:, -1]
        # token shift the incoming token embeddings for both the input projection and gate

        # token shift is like a very efficient 1D convolution with kernel size 2, similar to undilated causal conv in WaveNet
        # this gives each head the ability to choose which parts of the time-series to pay attention to
        # it acts like a vertical forget gate between layers, choosing which parts of the recent past to accrue and which to ignore

        # we want the token embeddings shifted over by one towards the past
        # to get this, we take the last token embedding processed and append all but one of the current token embeddings to it
        # (the last token embedding processed is what's stored in the x_state)
        x_shifted_one_to_the_past = torch.cat((x_state.unsqueeze(-2), x[:,:-1]), dim=1)

        # token shift is just a learned linear interpolation between the current and previous token embeddings in the sequence
        # this is done by lerping between x and the shifted x we just calculated
        # note that it is a per-channel learned interpolation amount, not just a single value per head
        # PLEASE NOTE THAT THE DIRECTION OF THE LERP CHANGED IN RWKV-6
        x_in = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_in)
        x_gate = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_gate)

        # project to 3.5x larger hidden dimension
        # this is 4x for vanilla transformers FFN, but it's typical to reduce it when adding new parameters 
        #  to allow comparison models with the same number of total parameters for a given d_model, n_layer
        # if you drop it down to 3.5x you end up making up for the extra parameters in the gate
        hidden = self.W_in(x_in)

        # relu^2 activation function
        hidden = torch.square(torch.relu(hidden))

        # project back out to d_model
        out = self.W_out(hidden)

        # apply sigmoid gate
        gate = self.gate(x_gate)
        out = out * torch.sigmoid(gate)

        return hidden_state_in + out, x_state_out
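
Putting it all together, the states returned by the model make RNN-style, token-by-token inference straightforward. The sketch below is not from the repo: it again assumes Config is constructible with the fields used above, and that the parameters (created with torch.empty in these snippets) have been initialized, otherwise the logits are meaningless.

# hedged inference sketch (hypothetical Config values, greedy decoding)
cfg = Config(vocab_size=1000, d_model=128, n_layers=2, n_heads=4)
model = RWKV(cfg)

prompt = torch.tensor([[5, 42, 7]])                  # (B=1, T=3) token ids
logits, state = model(prompt)                        # process the prompt in parallel

generated = []
token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick, shape (1, 1)
for _ in range(10):
    generated.append(token.item())
    logits, state = model(token, state)              # feed one token at a time, reusing the state
    token = logits[:, -1].argmax(dim=-1, keepdim=True)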