In this blog post I will discuss the RWKV models (v5.2, v6, v7), using code snippets from https://github.com/SmerkyG/RWKV_Explained
We will build the model starting from its main class and introduce the other classes as we encounter them.
# Main class of the model
class RWKV(torch.nn.Module):
def __init__(self, cfg:Config = Config()):
super().__init__()
self.cfg = cfg # Configuration
self.embed = nn.Parameter(torch.empty(cfg.vocab_size, cfg.d_model))
self.embed_norm = nn.LayerNorm(cfg.d_model)
# The embedding table is a matrix of size [vocab_size, d_model]; looking a token up in it is equivalent to multiplying its one-hot vector by a linear layer
# After the lookup we normalize the embeddings with embed_norm
self.layers = nn.ModuleList([Layer(cfg, layer_id) for layer_id in range(cfg.n_layers)])
# Create a ModuleList of Layers (the Layer class is defined below)
# each layer also receives its layer_id
self.lm_head_norm = nn.LayerNorm(cfg.d_model)
# LayerNorm applied to the final hidden states before unembedding
self.lm_head_unembed = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
# then a linear projection maps each hidden state to logits over the whole vocabulary
# input tensor dimensions:
# x (B,T) -> [batch_size, seq_len]
# example: [2,20] -> a batch of 2 sequences, each 20 tokens long
def forward(self, x : Tensor, s : list[LayerState]|None = None):
# calculate embeddings for each incoming token, then normalize them
# see https://github.com/BlinkDL/SmallInitEmb for details on why we do this normalization
# if you look at some literature on pre-generated embeddings, you'll see that they are
# often ideally considered to become unit length vectors around the hypersphere
# so starting it as small noise while using a normalized version instantly causes this layout,
# allowing them to initially move rapidly around the surface of that hypersphere, and later slowly
x = self.embed_norm(nn.functional.embedding(x, self.embed))
# Embedding the data
s = s or [LayerState(x, self.cfg) for _ in range(self.cfg.n_layers)]
# If no state is passed in, we create a fresh LayerState for each layer and collect them in a list
# We will look at the LayerState class next to analyze it
# run each layer in succession, passing in the RNN state for that layer
for layer_id, block in enumerate(self.layers): # run each rwkv block
x, s[layer_id] = block(x, s[layer_id])
# normalize the output
x = self.lm_head_norm(x)
# unembed back to dictionary indices
x = self.lm_head_unembed(x)
return x, s
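Before moving on, here is a minimal usage sketch of the class above. It assumes the repo's default Config (the field names vocab_size, d_model, etc. come from the snippets in this post) and uses freshly created, untrained weights, so it only illustrates shapes and the RNN-style state passing, not real predictions:
cfg = Config()                                      # assumes the repo's default Config
model = RWKV(cfg)
tokens = torch.randint(0, cfg.vocab_size, (2, 20))  # B=2 sequences, T=20 tokens each
logits, state = model(tokens)                       # logits: (2, 20, vocab_size)
# the returned per-layer states can be fed back in to continue right where we stopped,
# e.g. processing one new token at a time during generation
next_logits, state = model(tokens[:, -1:], state)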
Now let’s look first at the LayerState class to understand the structure of our model and how the data is processed compared to a GPT-style transformer. Keep in mind that this architecture is in essence an RNN (that framing may make it easier to follow).
class LayerState:
# the recurrent neural network (RNN) state for a layer of RWKV5.2
def __init__(self, x, cfg:Config):
# input tensor dimensions:
# x (B,T,C) - only B and T are read here
# C = d_model -> the model (embedding) dimension; think of the width of the Q/K/V matrices in GPT :D
# H = n_heads -> number of 'attention' heads
# K = d_model//n_heads -> the key dimension of each head :D
B, T, C, H, K = x.size(0), x.size(1), cfg.d_model, cfg.n_heads, cfg.d_model // cfg.n_heads
V = K
# a (B,C) size tensor representing latest time mixer token embedding processed
# responsible for capturing relationship between tokens
self.time_mixer_x_state = torch.zeros(B,C,dtype=x.dtype,device=x.device)
# an (B,H,K,V) size tensor representing a decaying token embedding memory for each head, where H=number_of_heads, K=key_dim_per_head, V=value_dim_per_head
self.kv_state = torch.zeros(B,H,K,V,dtype=torch.float32,device=x.device)
# a (B,C) size tensor representing latest channel mixer token embedding processed
self.channel_mixer_x_state = torch.zeros(B,C,dtype=x.dtype,device=x.device)
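To make the sizes concrete, here is a small sketch (again assuming the repo's default Config; the printed shapes are expressed in terms of its fields):
cfg = Config()                            # assumes the repo's default Config
x = torch.randn(2, 20, cfg.d_model)       # (B, T, C)
s = LayerState(x, cfg)
print(s.time_mixer_x_state.shape)         # (2, d_model)
print(s.kv_state.shape)                   # (2, n_heads, d_model//n_heads, d_model//n_heads)
print(s.channel_mixer_x_state.shape)      # (2, d_model)
Note that the per-layer state has a fixed size no matter how many tokens have been processed, which is the key difference from a transformer KV cache that grows with the sequence length.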
Now let’s go deeper into the model.
This class computes time-mixing followed by channel-mixing. ~ add picture ~
class Layer(nn.Module):
def __init__(self, cfg:Config, layer_id:int):
super().__init__()
self.time_mixer = TimeMixer(cfg, layer_id) # pass the configuration + layer_id
self.channel_mixer = ChannelMixer(cfg, layer_id) # same as above
def forward(self, x : Tensor, s : LayerState):
# x has shape [batch_size, seq_len, d_model]
x, s.time_mixer_x_state, s.kv_state = self.time_mixer(x, s.time_mixer_x_state, s.kv_state)
# We update x, the time-mixer x_state, and the kv_state
# We will introduce them below
x, s.channel_mixer_x_state = self.channel_mixer(x, s.channel_mixer_x_state)
# We update our x and channel-mixer-state
# return x and the state
return x, s
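If you want to poke at a single block in isolation, a quick sketch (untrained weights, shapes only, default Config assumed):
cfg = Config()
block = Layer(cfg, layer_id=0)
x = torch.randn(2, 20, cfg.d_model)       # (B, T, C)
s = LayerState(x, cfg)
x, s = block(x, s)                        # x keeps its (2, 20, d_model) shape, s is updated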
Let’s start with the TimeMixer class.
class TimeMixer(nn.Module):
def __init__(self, cfg:Config, layer_id:int):
super().__init__()
self.cfg = cfg # saving the config
d_model = cfg.d_model # dimension of the embedding
d_head = d_model // cfg.n_heads # dimension of each head
self.prenorm = nn.LayerNorm(d_model) # Pre Normalisation
self.tokenshift_receptance = nn.Parameter(torch.empty(1, 1, d_model)) # [1,1,d_model]
self.tokenshift_key = nn.Parameter(torch.empty(1, 1, d_model)) # [1,1,d_model]
self.tokenshift_value = nn.Parameter(torch.empty(1, 1, d_model)) # [1,1,d_model]
self.tokenshift_gate = nn.Parameter(torch.empty(1, 1, d_model)) # [1,1,d_model]
self.receptance = nn.Linear(d_model, d_model, bias=False) #
self.key = nn.Linear(d_model, d_model, bias=False)
self.value = nn.Linear(d_model, d_model, bias=False)
self.gate = nn.Linear(d_model, d_model, bias=False)
self.output = nn.Linear(d_model, d_model, bias=False)
# the receptance, key, value, gate, and output projection matrices
# per-channel boost for current embedding
self.bonus = nn.Parameter(torch.ones(cfg.n_heads, d_head))
# per-channel decay multipliers applied to kv_state at each timestep
self.decay = nn.Parameter(torch.ones(cfg.n_heads, d_head))
# Norm
self.group_norm = nn.GroupNorm(cfg.n_heads, d_model, eps=64e-5)
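One detail worth calling out before the forward pass: self.decay is stored as a free parameter and only squashed into the range (0, 1) inside forward() via a double exponential. A tiny numeric sketch of that mapping (values made up for illustration):
p = torch.tensor([-3.0, 0.0, 3.0])
print(torch.exp(-torch.exp(p)))   # roughly 0.95, 0.37 and ~0: a more positive parameter means faster forgetting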
Let’s dive into the forward pass.
def forward(self, hidden_state_in : Tensor, x_state : Tensor, kv_state : Tensor):
# x (B,T,C), x_state (B,C), kv_state (B,H,K,V)
x = self.prenorm(hidden_state_in) # prenorm the input
x_state_out = x[:, -1] # last token embedding saved
B, T, C, H, K = x.size(0), x.size(1), self.cfg.d_model, self.cfg.n_heads, self.cfg.d_model // self.cfg.n_heads
# we want the token embeddings shifted over by one towards the past
# to get this, we take the last token embedding processed and append all but one of the current token embeddings to it
# (the last token embedding processed is what's stored in the x_state)
x_shifted_one_to_the_past = torch.cat((x_state.unsqueeze(-2), x[:,:-1]), dim=1)
# x_state.unsqueeze(-2) turns its shape from (batch_size, C) into (batch_size, 1, C)
# so it acts as a single extra timestep prepended to the sequence
# we concatenate with x[:,:-1] which will be all tokens without the last one
# token shift the incoming token embeddings for the receptance, key, value, and gate
# PLEASE NOTE THAT THE DIRECTION OF THE LERP CHANGED IN RWKV-6
x_receptance = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_receptance)
x_key = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_key)
x_value = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_value)
x_gate = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_gate)
# run the linear interpolation:
# torch.lerp(start, end, weight):
# out[i] = start[i] + weight[i]*(end[i]-start[i])
# So for example, say x = [1,2,3,4,5] and the previous x_state was 10,
# then x_shifted_one_to_the_past = [10,1,2,3,4]
# and we interpolate element-wise between the two
# (a small numeric example of this is shown right after the forward pass)
# self.tokenshift_* has shape [1,1,d_model]
# while x has shape [batch_size, seq_len, d_model]
# so the learned interpolation weight is per channel, broadcast over batch and time
# This is done for key, value, receptance, and gate
# the extra dimensions are being added here to enable matrix multiplications per timestep
r = self.receptance(x_receptance).view(B,T,H,1,K) # BTH1K
k = self.key(x_key).view(B,T,H,K,1) # BTHK1
v = self.value(x_value).view(B,T,H,1,K) # BTH1K
gate = self.gate(x_gate) # BTC
# this forces the decays to end up in the range 0...1 using a nicely differentiable function
decay = torch.exp(-torch.exp(self.decay.float())) # HK
out = torch.empty(B, T, H, K, dtype=x.dtype, device=x.device)
# out has shape [batch_size, seq_len, n_heads, head_dim]
for t in range(T):
out[:,t], kv_state = TimeMixer.single_timestep(r[:,t], k[:,t], v[:,t],
self.bonus,
decay,
kv_state)
# We will look at the single_timestep processing function in a moment
# We iterate t over the timesteps 0 .. T-1
# out[:,t] -> out[all_batches, timestep t]
# kv_state is updated at every step
# single_timestep(r[:,t], k[:,t], v[:,t], self.bonus, decay, kv_state)
# r[:,t] -> [all_batches, n_heads, 1, head_dim]
# example: with B=2, H=16, K=64, r[:,3] has shape [2, 16, 1, 64]
# apply group normalization to each head and recombine the heads
out = self.group_norm(out.view(B*T, C)).view(B, T, C) # BTC
# apply silu gate to the output
out = out * nn.functional.silu(gate) # BTC
# project the output
out = self.output(out) # BTC
return hidden_state_in + out, x_state_out, kv_state
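Token shift itself is easy to see in isolation. A toy example with one batch, one channel and a made-up interpolation weight of 0.25 (so the mix keeps 75% of the previous token and 25% of the current one):
x = torch.tensor([[1., 2., 3., 4., 5.]]).unsqueeze(-1)          # (B=1, T=5, C=1)
x_state = torch.tensor([[10.]])                                 # (B=1, C=1), the previous token
x_shifted = torch.cat((x_state.unsqueeze(-2), x[:, :-1]), dim=1)
print(x_shifted.squeeze(-1))                                    # tensor([[10., 1., 2., 3., 4.]])
mix = torch.lerp(x_shifted, x, 0.25)
print(mix.squeeze(-1))                                          # tensor([[7.7500, 1.2500, 2.2500, 3.2500, 4.2500]])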
Now let’s look at the single_timestep processing function.
@staticmethod
def single_timestep(r, k, v, u, w, kv_state):
original_dtype = r.dtype
B, H, K, V = kv_state.shape
# get the shapes
# reshape the inputs and cast everything to float32 for higher precision
r = r.float().view(B, H, 1, K)
w = w.float().view(1, H, K, 1)
k = k.float().view(B, H, K, 1)
v = v.float().view(B, H, 1, V)
u = u.float().view(1, H, K, 1)
kv = k @ v # BHK1 @ BH1V = BHKV
# matrix multiplication between keys and values
# K: [batch,Nr_heads,Head_dim,1]
# V: [batch,Nr_heads,1,Head_dim]
# start with the existing kv state
y = kv_state # BHKV
# apply the u boost to the current k @ v and add it to that
y = y + kv * u # BHKV + BHKV * 1HK1 = BHKV
# apply receptance to that whole result
out = r @ y # BH1K @ BHKV = BH1V
# remove an extra useless dimension from the output
out = out.squeeze(-2).to(original_dtype) # BHV
# finally, decay the kv state and add in the latest k @ v
kv_state = kv_state * w # BHKV * BHK1 = BHKV
kv_state = kv_state + kv # BHKV + BHKV = BHKV
return out, kv_state # BHV, BHKV
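A quick shape check against the comments above, using made-up sizes (B=1, H=2, K=V=4) and a constant decay of 0.9 per channel. Note in the code that the output is computed from the pre-decay state plus the bonus-weighted current k @ v, and only afterwards is the state decayed and updated:
B, H, K, V = 1, 2, 4, 4
r = torch.randn(B, H, 1, K)
k = torch.randn(B, H, K, 1)
v = torch.randn(B, H, 1, V)
u = torch.ones(H, K)                  # the per-channel "bonus" for the current token
w = torch.full((H, K), 0.9)           # decay values already squashed into (0, 1)
kv_state = torch.zeros(B, H, K, V)
out, kv_state = TimeMixer.single_timestep(r, k, v, u, w, kv_state)
print(out.shape, kv_state.shape)      # torch.Size([1, 2, 4]) torch.Size([1, 2, 4, 4])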
The channel-mixing formulas are not hard and are easy to follow: a lerp (token shift), a ReLU-squared feed-forward, and a sigmoid gate applied to the output:
class ChannelMixer(nn.Module):
def __init__(self, cfg:Config, layer_id:int):
super().__init__()
self.cfg = cfg
self.prenorm = nn.LayerNorm(cfg.d_model) #
self.tokenshift_in = nn.Parameter(torch.empty(1, 1, cfg.d_model))
self.tokenshift_gate = nn.Parameter(torch.empty(1, 1, cfg.d_model))
d_ffn = cfg.d_ffn or int(cfg.d_model * 3.5)//32*32
# FFN hidden size: 3.5x d_model, rounded down to a multiple of 32
self.W_in = nn.Linear(cfg.d_model, d_ffn, bias=False)
self.W_out = nn.Linear(d_ffn, cfg.d_model, bias=False)
self.gate = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
def forward(self, hidden_state_in : Tensor, x_state : Tensor): # x (B,T,C), x_state (B,C)
x = self.prenorm(hidden_state_in)
x_state_out = x[:, -1]
# token shift the incoming token embeddings for both the input projection and gate
# token shift is like a very efficient 1D convolution with kernel size 2, similar to the undilated causal conv in WaveNet
# this gives each head the ability to choose which parts of the time-series to pay attention to
# it acts like a vertical forget gate between layers, choosing which parts of the recent past to accrue and which to ignore
# we want the token embeddings shifted over by one towards the past
# to get this, we take the last token embedding processed and append all but one of the current token embeddings to it
# (the last token embedding processed is what's stored in the x_state)
x_shifted_one_to_the_past = torch.cat((x_state.unsqueeze(-2), x[:,:-1]), dim=1)
# token shift is just a learned linear interpolation between the current and previous token embeddings in the sequence
# this is done by lerping between x and the shifted x we just calculated
# note that it is a per-channel learned interpolation amount, not just a single value per head
# PLEASE NOTE THAT THE DIRECTION OF THE LERP CHANGED IN RWKV-6
x_in = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_in)
x_gate = torch.lerp(x_shifted_one_to_the_past, x, self.tokenshift_gate)
# project to 3.5x larger hidden dimension
# this is 4x for vanilla transformers FFN, but it's typical to reduce it when adding new parameters
# to allow comparison models with the same number of total parameters for a given d_model, n_layer
# if you drop it down to 3.5x you end up making up for the extra parameters in the gate
hidden = self.W_in(x_in)
# relu^2 activation function
hidden = torch.square(torch.relu(hidden))
# project back out to d_model
out = self.W_out(hidden)
# apply sigmoid gate
gate = self.gate(x_gate)
out = out * torch.sigmoid(gate)
return hidden_state_in + out, x_state_out
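And a matching sketch for the channel mixer on its own (default Config assumed, untrained weights, shapes only):
cfg = Config()                            # assumes the repo's default Config
cm = ChannelMixer(cfg, layer_id=0)
x = torch.randn(2, 20, cfg.d_model)       # (B, T, C)
x_state = torch.zeros(2, cfg.d_model)     # (B, C), the last token of the previous chunk
out, x_state = cm(x, x_state)
print(out.shape, x_state.shape)           # (2, 20, d_model) and (2, d_model)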