RoPE (Rotary Position Embedding) is a method for encoding positional information in transformers. Unlike traditional approaches such as absolute or additive relative position encodings, RoPE applies a rotation matrix to the query and key vectors in self-attention, encoding positional information directly in the attention mechanism. This allows for smoother extrapolation to longer sequences and more flexible handling of positional information.
In transformer models, self-attention scores are usually computed from the token embeddings alone. To better capture distance-based relationships, we want these scores to also reflect the relative positions of tokens (i.e., how far apart they are). So instead of using just the token embeddings, we compute the scores with a function that takes both the embeddings and the token positions into account, which helps the model reason about the relationships between tokens more effectively.
Concretely, we want the inner product of a query at position m and a key at position n to encode position information only in relative form: ⟨f_q(x_m, m), f_k(x_n, n)⟩ = g(x_m, x_n, m − n), i.e. the positions enter only through their difference m − n.
Consult the paper for the full math proof and the computationally efficient rotary matrix: https://arxiv.org/abs/2104.09864
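For intuition, here is the 2-dimensional case as a tiny sketch of my own (rotate_2d is not from the paper or any library, just an illustration): each query/key vector gets rotated by an angle proportional to its position, and the dot product of two rotated vectors then depends only on the difference between their positions.
import math
import torch

def rotate_2d(x: torch.Tensor, m: int, theta: float = 1.0) -> torch.Tensor:
    # rotate the 2-d vector x by the angle m * theta (RoPE for a single pair of dimensions)
    angle = m * theta
    rot = torch.tensor([[math.cos(angle), -math.sin(angle)],
                        [math.sin(angle),  math.cos(angle)]])
    return rot @ x

q, k = torch.randn(2), torch.randn(2)
print(torch.dot(rotate_2d(q, 3), rotate_2d(k, 1)))    # positions 3 and 1 -> distance 2
print(torch.dot(rotate_2d(q, 10), rotate_2d(k, 8)))   # positions 10 and 8 -> distance 2, same score
Full RoPE just applies this idea to head_dim/2 independent pairs of dimensions, each with its own frequency, which is exactly what the code below implements.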
We will create the function to precompute the thetas:
import torch

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    # head_dim = dimension of a single attention head (dim // n_heads)
    # seq_len = number of token positions to precompute
    # theta = base used to build the rotation frequencies (10000.0 in the paper)
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    # this is required because RoPE splits the dimension into pairs of 2 values for the rotation
theta_numerator = torch.arange(0,head_dim,2).float()
    # array of [0., 2., 4., ..., head_dim - 2] -> head_dim/2 values in total
theta = 1.0/(theta ** (theta_numerator / head_dim)).to(device)
# we calculate the thetas
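    # worked example (my own numbers): for head_dim = 8,
    # theta_numerator = [0, 2, 4, 6] and theta = [1.0, 0.1, 0.01, 0.001]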
m = torch.arange(seq_len,device=device)
    # array of the integers 0 -> seq_len - 1, i.e. the token positions
freqs = torch.outer(m,theta).float()
    # torch.outer(a, b) multiplies every element of the 1-D vector a by every element of the 1-D vector b,
    # so freqs has shape [seq_len, head_dim/2]: every row is a different position m and every
    # column is the angle m * theta_i for a different dimension pair of the rotary embedding
    # example:
    # torch.outer(torch.tensor([1., 2.]), torch.tensor([1., 2., 3., 4.]))
    #   = [[1*1, 1*2, 1*3, 1*4],
    #      [2*1, 2*2, 2*3, 2*4]]
freqs_complex = torch.polar(torch.ones_like(freqs),freqs)
    # torch.polar(abs, angle) | out = abs*cos(angle) + abs*sin(angle)*j
    # we pass torch.ones_like(freqs) as the magnitude so that abs = 1 everywhere
    # and each complex number carries only the angle (the rotation)
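    # e.g. torch.polar(torch.tensor([1., 1.]), torch.tensor([0.0, 1.5708]))
    #   ≈ tensor([1.+0.j, 0.+1.j])   (angle 0 and angle pi/2)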
return freqs_complex
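A quick sanity check of the output (the sizes here are arbitrary numbers I picked for the example):
freqs = precompute_theta_pos_frequencies(head_dim=64, seq_len=1024, device="cpu")
print(freqs.shape)  # torch.Size([1024, 32]) -> [seq_len, head_dim/2]
print(freqs.dtype)  # torch.complex64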
Now we need to create a function to apply them to our embeddings!
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
# x = input tensor , x.shape = [batch,seq_len,num_heads,head_dim]
# freqs_complex = precalculated frequencies
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2))
    # Let's break it down as this operation might look confusing at first:
    # x.float() converts x to float32
    # x.reshape(*x.shape[:-1], -1, 2)
    #   *x.shape[:-1] unpacks all dimensions except the last one,
    #   so it expands to batch, seq_len, num_heads
    #   and the call becomes x.reshape(batch, seq_len, num_heads, -1, 2)
    # now the confusing part is that we end up with 2 extra dimensions:
    #   [batch, seq_len, num_heads, ?, 2], where the "?" is calculated automatically
    #   if we do the math we get [batch, seq_len, num_heads, head_dim/2, 2]
    # so you can think of it as splitting head_dim into head_dim/2 pairs (a, b),
    # each pair forming a complex number a + bi
    # torch.view_as_complex then returns the tensor in complex form
# from pytorch website:
# "where the last dimension of the input tensor is expected to represent the real and imaginary components of complex numbers."
freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
# freqs_complex initial size = [seq_len,head_dim/2]
    # .unsqueeze(0) -> [1, seq_len, head_dim/2] so the same freqs broadcast across the whole batch
    # .unsqueeze(2) -> [1, seq_len, 1, head_dim/2] so they also broadcast across all heads
    # now it can be broadcast against the query/key tensor of shape [batch, seq_len, num_heads, head_dim/2]
x_rotated = x_complex * freqs_complex
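    # multiplying by a unit complex number e^(i * m * theta_i) rotates each pair by the angle m * theta_i
    # this element-wise product is the actual "rotary" part of RoPE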
x_out = torch.view_as_real(x_rotated)
    # the inverse operation of .view_as_complex
    # from a + ib back to the pair [a, b], giving shape [batch, seq_len, num_heads, head_dim/2, 2]
x_out = x_out.reshape(*x.shape)
# we reshape our tensor to be the exact shape as the original x tensor
return x_out.type_as(x).to(device)
    # .type_as(x) casts the result back to the exact dtype of x
    # (the intermediate math was done in float32, but x might be e.g. float16 and we don't want to silently change that)
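A quick shape check (again, the sizes are just numbers I picked; it assumes both functions above are in scope):
x = torch.randn(2, 16, 8, 64)   # [batch, seq_len, n_heads, head_dim]
freqs = precompute_theta_pos_frequencies(64, 16, device="cpu")
out = apply_rotary_embeddings(x, freqs, device="cpu")
print(out.shape)                # torch.Size([2, 16, 8, 64]) -- same shape as the input
print(out.dtype)                # torch.float32 -- same dtype as the input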
Let’s create a simple attention mechanism to showcase the whole code :D
import math
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
def __init__(self,args):
super().__init__()
        self.n_heads = args.n_heads # number of attention heads
self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False) # queries
self.wk = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False) # keys
self.wv = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False) # values
self.wo = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False) # output projection
def forward(self,x,start_pos:int,freqs_complex):
        batch_size,seq_len,_ = x.shape # unpack the shape, we will need these below
xq = self.wq(x).view(batch_size,seq_len,self.n_heads,self.head_dim)
xk = self.wk(x).view(batch_size,seq_len,self.n_heads,self.head_dim)
xv = self.wv(x).view(batch_size,seq_len,self.n_heads,self.head_dim)
        # We project x into queries, keys and values and split them into heads:
        # from [batch_size,seq_len,dim] -> [batch_size,seq_len,n_heads,head_dim]
# Now we apply the rotary positional embeddings
xq = apply_rotary_embeddings(xq,freqs_complex,device=x.device)
xk = apply_rotary_embeddings(xk,freqs_complex,device=x.device)
# Transpose
xq = xq.transpose(1,2)
xk = xk.transpose(1,2)
xv = xv.transpose(1,2)
# Calculate the Attention:
scores = torch.matmul(xq,xk.transpose(2,3)) / math.sqrt(self.head_dim) # we can precalculate the denominator
scores = F.softmax(scores.float(),dim=-1).type_as(x)
        # we apply the softmax over the last dimension (the key positions), so each query's scores sum to 1
        output = torch.matmul(scores,xv)
output = output.transpose(1,2).contiguous().view(batch_size,seq_len,-1)
        # we transpose back to [batch_size,seq_len,n_heads,head_dim] and merge the heads into one dimension
# .contiguous() makes sure that the tensor is in a contiguous chunk of memory
# maybe after the calculations the tensor may not be contiguous
return self.wo(output)
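To run it end to end, here is a minimal sketch (SimpleNamespace is just a stand-in I'm using for the real config/args object):
from types import SimpleNamespace

args = SimpleNamespace(dim=512, n_heads=8)
attn = Attention(args)
freqs = precompute_theta_pos_frequencies(512 // 8, 16, device="cpu")
x = torch.randn(2, 16, 512)              # [batch, seq_len, dim]
out = attn(x, start_pos=0, freqs_complex=freqs)
print(out.shape)                         # torch.Size([2, 16, 512])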
But where do we precompute the thetas? Well, we don't need to recalculate them every time, so computing them once at the beginning is fine. Example of a transformer body:
class Transformer(nn.Module):
def __init__(self,args):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size # vocab size
# .... other code
self.freqs_complex = precompute_theta_pos_frequencies(args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device)
        # args.dim // args.n_heads = embedding dim of each attention head
def forward(self,tokens,start_pos):
batch_size ,seq_len = tokens.shape
# ... other code
freqs_complex = self.freqs_complex[start_pos:start_pos+seq_len]
# ... other code
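To make the flow concrete, here is a rough, hypothetical sketch of how the elided forward pass could hand the sliced freqs_complex to the attention block (self.tok_embeddings and self.attention are names I'm assuming for illustration, they are not defined in the code above):
        # hypothetical continuation of forward(), assuming __init__ also built
        # self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) and self.attention = Attention(args)
        h = self.tok_embeddings(tokens)                  # [batch, seq_len, dim]
        h = self.attention(h, start_pos, freqs_complex)  # RoPE is applied to the queries/keys inside
        # ... followed by the rest of the block (norms, feed-forward, output head)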
thanks for reading :D