LoRA: Low-Rank Adaptation for Efficient Fine-Tuning of Large Language Models

LoRA (Low-Rank Adaptation of Large Language Models) is a technique used to fine-tune pre-trained large language models efficiently. Instead of updating all the parameters of the model during fine-tuning, LoRA freezes the pre-trained weights and injects trainable low-rank matrices (rank-limited approximations) that model the weight updates. This significantly reduces the computational resources and storage required for fine-tuning while maintaining performance, making it ideal for tasks where large-scale model adaptation is needed with limited data or compute.

\Delta W = B A
W_{new} = W_{old} + \Delta W

Where: A \in \mathbb{R}^{R \times K}, B \in \mathbb{R}^{D \times R}, W \in \mathbb{R}^{D \times K}, with the rank R \ll \min(D, K), so B A has the same shape as W but is built from far fewer trainable parameters.
The advantage of this low-rank update is that the frozen weights W_{old} keep the model's pre-trained knowledge intact while only the small matrices A and B are trained, making fine-tuning of large models much cheaper, as the quick calculation below illustrates.
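
To make the saving concrete, here is a quick back-of-the-envelope count (the dimensions are illustrative, not taken from any particular model):

# Trainable parameters for a single D x K weight with D = K = 4096 and rank R = 8
D, K, R = 4096, 4096, 8
full_update = D * K            # full fine-tuning: every entry of W is updated
lora_update = D * R + R * K    # LoRA: only B (D x R) and A (R x K) are trained
print(full_update, lora_update, full_update / lora_update)
# 16777216 65536 256.0  -> roughly 256x fewer trainable parameters per layer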

Implementation

Main LoRA class:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LoRA(nn.Module):
    def __init__(self, dim_in, dim_out, rank=1, alpha=1, device="cpu"):
        super().__init__()

        # Low-rank factors of the weight update:
        # dim_in  = D (first dimension of the weight matrix)
        # dim_out = K (second dimension of the weight matrix)
        # A has shape (R, K), B has shape (D, R), so B @ A has shape (D, K)
        self.lora_a = nn.Parameter(torch.zeros((rank, dim_out), device=device))
        self.lora_b = nn.Parameter(torch.zeros((dim_in, rank), device=device))

        # A is initialized from a normal distribution, B stays at zero,
        # so the update B @ A is zero at the start of training
        nn.init.normal_(self.lora_a, mean=0, std=1)

        # Scale factor applied to the update
        self.scale = alpha / rank
        self.enabled = True  # toggle the LoRA update on/off

    def forward(self, original_weights):
        if self.enabled:
            delta = torch.matmul(self.lora_b, self.lora_a).view_as(original_weights)
            return original_weights + delta * self.scale
        return original_weights
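
A quick sanity check of the class on a single nn.Linear layer (the layer sizes below are arbitrary): because lora_b is zero-initialized, the parametrized weight should equal the original weight before any training step.

layer = nn.Linear(16, 32)
original = layer.weight.detach().clone()

# layer.weight has shape (32, 16), so dim_in=32 and dim_out=16
parametrize.register_parametrization(
    layer, "weight",
    LoRA(dim_in=layer.weight.shape[0], dim_out=layer.weight.shape[1], rank=4, alpha=4)
)

# B is zero at init, so the effective weight is unchanged
assert torch.allclose(layer.weight, original)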

Helper functions:

def linear_param(layer, device, rank=1, alpha=1):
    # nn.Linear stores its weight as (out_features, in_features)
    dim_in, dim_out = layer.weight.shape
    return LoRA(dim_in, dim_out, rank=rank, alpha=alpha, device=device)


# Apply a LoRA parametrization to every nn.Linear layer in the model
def apply_lora(model, device, rank=1, alpha=1):
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            # register our LoRA module as a parametrization of the weight
            parametrize.register_parametrization(
                layer, "weight", linear_param(layer, device, rank=rank, alpha=alpha)
            )


# Remove the LoRA parametrization from every nn.Linear layer.
# leave_parametrized=False restores the original weights and discards the LoRA update;
# use leave_parametrized=True to merge the update into the weights instead.
def remove_lora(model):
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
            

# Enable/disable the LoRA update without removing the parametrization
def enable_disable_lora(model, enabled=True):
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            layer.parametrizations["weight"][0].enabled = enabled

# Freeze the original model weights; the LoRA matrices stay trainable
def freeze_model(model, freeze=True):
    for name, param in model.named_parameters():
        if "lora" not in name:
            param.requires_grad = not freeze
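
Putting the pieces together, a typical fine-tuning step looks roughly like the sketch below; the model, data, and hyperparameters are placeholders, not part of the original implementation.

import torch.optim as optim

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
device = "cpu"

apply_lora(model, device, rank=4, alpha=4)  # wrap every nn.Linear weight with a LoRA parametrization
freeze_model(model, freeze=True)            # freeze original weights, keep the LoRA matrices trainable

optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 784)                     # dummy batch
y = torch.randint(0, 10, (8,))
loss = criterion(model(x), y)
loss.backward()
optimizer.step()

enable_disable_lora(model, enabled=False)   # fall back to the original weights at any time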

Implementation based on this video: https://www.youtube.com/watch?v=PXWYUTMt-AU
Paper link: https://arxiv.org/abs/2106.09685