RoPE: Rotary Position Embedding for Flexible Positional Encoding in Transformers

RoPE (Rotary Position Embedding) is a method for encoding positional information in transformers. Unlike traditional approaches like absolute or relative position encodings, RoPE applies a rotation matrix to the query and key vectors in self-attention, encoding positional information directly in the attention mechanism. This allows for smoother extrapolation to longer sequences and more flexible handling of positional data.

Formulation (from the paper)

In transformer models, self-attention scores are usually computed from the token embeddings alone. To better capture distance-based relationships, we want these scores to also reflect the relative positions of tokens (i.e., how far apart they are). So instead of using just the token embeddings, we compute the scores with a function that takes both the embeddings and the token positions. This lets the model capture relationships between tokens more effectively.

We hope that the inner product encodes position information only in the relative form:

\langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, m - n)

At position 0 the query and key functions reduce to the usual linear projections:

q = f_q(x_m, 0) = W_q x_m
k = f_k(x_n, 0) = W_k x_n

and for positions m and n we get:

q = f_q(x_m, m) = (W_q x_m) e^{imθ}
k = f_k(x_n, n) = (W_k x_n) e^{inθ}
g(x_m, x_n, m - n) = Re[(W_q x_m)(W_k x_n)^* e^{i(m-n)θ}]

Consult the paper for the full proof and the computationally efficient rotary matrix formulation: https://arxiv.org/abs/2104.09864
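
Before jumping into the implementation, here is a small numeric sanity check of the relative-position property (a minimal sketch with made-up 2-D query/key values treated as complex numbers; the frequency 0.5 is arbitrary):

import torch

theta = 0.5  # an arbitrary rotation frequency for this check

# treat 2-D query/key vectors (a, b) as complex numbers a + bi
q = torch.complex(torch.tensor(1.0), torch.tensor(2.0))
k = torch.complex(torch.tensor(3.0), torch.tensor(-1.0))

def score(m: int, n: int) -> torch.Tensor:
    # rotate q by m*theta and k by n*theta, then take Re[(q e^{imθ}) (k e^{inθ})^*]
    q_m = q * torch.polar(torch.tensor(1.0), torch.tensor(m * theta))
    k_n = k * torch.polar(torch.tensor(1.0), torch.tensor(n * theta))
    return (q_m * k_n.conj()).real

print(score(5, 2))    # the score depends only on m - n = 3 ...
print(score(13, 10))  # ... so this prints the same value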

Implementation:

We will start by creating the function that precomputes the thetas:

import torch

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    # head_dim = dimension of the embeddings of each attention head
    # seq_len = number of tokens in a sequence
    # theta = scaling factor (the base of the frequencies)
    
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    # This is required because RoPE splits the dimension into pairs, one pair per rotation

    theta_numerator = torch.arange(0,head_dim,2).float()
    # array [0., 2., 4., ..., head_dim - 2] with head_dim / 2 elements

    theta = 1.0/(theta ** (theta_numerator / head_dim)).to(device)
    # we calculate the thetas

    m = torch.arange(seq_len,device=device)
    # integers from 0 to seq_len - 1, i.e. the token positions

    freqs = torch.outer(m,theta).float()
    # outer product: freqs[i][j] = m[i] * theta[j]
    # every row is a different position and every column a different rotary frequency,
    # so freqs has shape [seq_len, head_dim / 2]
    # example:
    #   torch.outer(torch.tensor([1, 2, 3]), torch.tensor([10, 20]))
    #   = [[10, 20],
    #      [20, 40],
    #      [30, 60]]

    freqs_complex = torch.polar(torch.ones_like(freqs),freqs)
    # torch.polar(abs,angle) | out = abs*cos(angle) + abs*sin(angle)*j
    # we pass ones_like(freqs) so every magnitude is 1; only the angles (freqs) carry information

    return freqs_complex
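
Here is a quick check of what the function returns (the sizes below are just a small illustrative configuration):

freqs_complex = precompute_theta_pos_frequencies(head_dim=8, seq_len=4, device="cpu")
print(freqs_complex.shape)  # torch.Size([4, 4]) -> [seq_len, head_dim // 2]
print(freqs_complex.dtype)  # torch.complex64
print(freqs_complex[0])     # position 0: all angles are 0, so every entry is 1 + 0j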

Now we need to create a function to apply them to our embeddings!

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # x = input tensor , x.shape = [batch,seq_len,num_heads,head_dim]
    # freqs_complex = precalculated frequencies

    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2))
    # Let's break it down as this operation might look confusing at first:
    # x.float() converts x to a float type
    # x.reshape(*x.shape[:-1],-1,2)
    # *x.shape[:-1] unpacks every dimension except the last one,
    # so this is equivalent to x.reshape(batch, seq_len, num_heads, -1, 2)
    # the reshape gives [batch,seq_len,num_heads,?,2], where the "?" dimension is calculated automatically
    # if we do the math we get [batch,seq_len,num_heads,head_dim/2,2]
    # so you can think of it as splitting head_dim into head_dim/2 pairs,
    # and each pair (a, b) gives the real and imaginary parts of a complex number a + bi
    # torch.view_as_complex then returns the tensor viewed as complex numbers
    # from pytorch website:
    # "where the last dimension of the input tensor is expected to represent the real and imaginary components of complex numbers."

    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # freqs_complex initial size = [seq_len,head_dim/2]
    # .unsqueeze(0) -> [1,seq_len,head_dim/2] we use it so that the freqs are across batches
    # .unsqueeze(2) -> [1,seq_len,1,head_dim/2] we use it so that can be used across all heads
    # it will match the size of query/key

    x_rotated = x_complex * freqs_complex
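    # element-wise complex multiplication rotates every pair by its position-dependent angle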
    
    x_out = torch.view_as_real(x_rotated)
    # the inverse operation of .view_as_complex:
    # each complex number a + ib becomes the pair [a, b] in a new last dimension

    x_out = x_out.reshape(*x.shape)
    # we reshape our tensor to be the exact shape as the original x tensor

    return x_out.type_as(x).to(device)
    # .type_as(x) casts the result back to the exact dtype of x
    # (the .float() conversion above produced float32, and we don't want to silently
    #  change the dtype if x was, say, float16)
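
A small usage sketch tying the two functions together (the sizes are hypothetical; note that the shape is preserved and position 0 is left untouched because its rotation angle is 0):

batch, seq_len, n_heads, head_dim = 2, 4, 3, 8
x = torch.randn(batch, seq_len, n_heads, head_dim)

freqs_complex = precompute_theta_pos_frequencies(head_dim, seq_len, device="cpu")
x_rot = apply_rotary_embeddings(x, freqs_complex, device="cpu")

print(x_rot.shape)                           # torch.Size([2, 4, 3, 8]) -- same shape as x
print(torch.allclose(x_rot[:, 0], x[:, 0]))  # True: position 0 gets a zero-angle rotation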

Let’s create a simple attention mechanism to showcase the whole code :D

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self,args):
        super().__init__()

        self.n_heads = args.n_heads # number of attention heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False) # queries
        self.wk = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False) # keys
        self.wv = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False) # values
        self.wo = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False) # output projection

    def forward(self,x,start_pos:int,freqs_complex):
        batch_size,seq_len,_ = x.shape # unpack the shape we will use them

        xq = self.wq(x).view(batch_size,seq_len,self.n_heads,self.head_dim)
        xk = self.wk(x).view(batch_size,seq_len,self.n_heads,self.head_dim)
        xv = self.wv(x).view(batch_size,seq_len,self.n_heads,self.head_dim)
        # We apply the linear projections and then split the result into heads:
        # from [batch_size,seq_len,emb_dim] -> [batch_size,seq_len,n_heads,head_dim]

        # Now we apply the rotary positional embeddings
        xq = apply_rotary_embeddings(xq,freqs_complex,device=x.device)
        xk = apply_rotary_embeddings(xk,freqs_complex,device=x.device)

        # Transpose 
        xq = xq.transpose(1,2)
        xk = xk.transpose(1,2)
        xv = xv.transpose(1,2)

        # Calculate the Attention:
        scores = torch.matmul(xq,xk.transpose(2,3)) / math.sqrt(self.head_dim) # we can precalculate the denominator
        scores = F.softmax(scores.float(),dim=-1).type_as(x)
        # we apply the softmax over the last dimension (the key positions)

        output = torch.matmul(scores,xv)
        output = output.transpose(1,2).contiguous().view(batch_size,seq_len,-1)
        # we transpose it back in initial shape
        # .contiguous() makes sure that the tensor is in a contiguous chunk of memory
        # maybe after the calculations the tensor may not be contiguous 
        return self.wo(output)
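
A quick smoke test of the attention block (the args object here is a hypothetical stand-in; any object with dim and n_heads attributes works):

from types import SimpleNamespace

args = SimpleNamespace(dim=64, n_heads=4)
attn = Attention(args)

batch_size, seq_len = 2, 16
x = torch.randn(batch_size, seq_len, args.dim)
freqs_complex = precompute_theta_pos_frequencies(args.dim // args.n_heads, seq_len, device="cpu")

out = attn(x, start_pos=0, freqs_complex=freqs_complex)
print(out.shape)  # torch.Size([2, 16, 64])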

But where do we precompute the thetas? We don't need to recalculate them every time, so computing them once at initialization is enough. Example of a transformer body:

class Transformer(nn.Module):
    def __init__(self,args):
        super().__init__()
        
        self.args = args
        self.vocab_size = args.vocab_size # vocab size

        # .... other code
        self.freqs_complex = precompute_theta_pos_frequencies(args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device)
        # args.dim // args.n_heads = embedding dim of each attention head

    def forward(self,tokens,start_pos):
        batch_size, seq_len = tokens.shape
        # ... other code

        freqs_complex = self.freqs_complex[start_pos:start_pos+seq_len]

        # ... other code
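
To see what the slice in forward gives us, here is a small illustration of the two typical inference cases (a sketch assuming a Llama-style setup where the prompt is processed in one pass and each later decoding step feeds a single token):

max_seq_len, head_dim = 128, 16
freqs_complex = precompute_theta_pos_frequencies(head_dim, max_seq_len * 2, device="cpu")

# processing a whole prompt of 10 tokens at once (start_pos = 0, seq_len = 10)
prompt_freqs = freqs_complex[0:0 + 10]
print(prompt_freqs.shape)  # torch.Size([10, 8])

# decoding the token at position 10 (start_pos = 10, seq_len = 1)
step_freqs = freqs_complex[10:10 + 1]
print(step_freqs.shape)    # torch.Size([1, 8])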

thanks for reading :D