LoRA (Low-Rank Adaptation of Large Language Models) is a technique for fine-tuning pre-trained large language models efficiently. Instead of updating all of the model's parameters during fine-tuning, LoRA freezes the original weights and learns the weight updates as a pair of trainable low-rank matrices. This significantly reduces the compute and storage required for fine-tuning while largely preserving performance, which makes it well suited to adapting large models with limited data or hardware.
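Concretely, a frozen weight matrix W of shape (d, k) is augmented with an update ΔW = B·A, where B has shape (d, r), A has shape (r, k), and the rank r is much smaller than d and k; the update is scaled by alpha/rank, so the effective weight is W + (alpha/rank)·B·A. Only B and A are trained, so the number of trainable parameters drops from d·k to r·(d + k). A rough back-of-the-envelope sketch (the numbers are illustrative, not from the source):

d, k, r = 4096, 4096, 8
full_update = d * k        # 16,777,216 trainable values for a dense update
lora_update = r * (d + k)  # 65,536 trainable values for the low-rank factors
print(full_update // lora_update)  # -> 256, i.e. ~256x fewer trainable parameters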
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LoRA(nn.Module):
    def __init__(self, dim_in, dim_out, rank=1, alpha=1, device="cpu"):
        super().__init__()
        # Create the low-rank factors A and B
        # dim_in  = rows of the weight matrix being parametrized
        # dim_out = columns of the weight matrix being parametrized
        self.lora_a = nn.Parameter(torch.zeros((rank, dim_out)).to(device))
        self.lora_b = nn.Parameter(torch.zeros((dim_in, rank)).to(device))
        # Initialise A from a normal distribution; B stays at zero, so the
        # parametrized weight starts out identical to the original weight
        nn.init.normal_(self.lora_a, mean=0, std=1)
        # Scale factor applied to the low-rank update
        self.scale = alpha / rank
        self.enabled = True  # toggle to enable/disable the update

    def forward(self, original_weights):
        if self.enabled:
            # W' = W + (alpha / rank) * B @ A
            return original_weights + torch.matmul(self.lora_b, self.lora_a).view_as(original_weights) * self.scale
        return original_weights
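# A quick, hypothetical sanity check of the class above (shapes and values are
# illustrative, not from the source). Because lora_b starts at zero, the
# parametrized weight is identical to the original weight until B is trained:
#
#   lora = LoRA(dim_in=256, dim_out=128, rank=4, alpha=8)
#   w = torch.randn(256, 128)
#   assert torch.equal(lora(w), w)  # B @ A == 0 at initialisation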
def linear_param(layer, device, rank=1, alpha=1):
    # nn.Linear stores its weight as (out_features, in_features)
    dim_in, dim_out = layer.weight.shape
    return LoRA(dim_in, dim_out, rank=rank, alpha=alpha, device=device)


# Apply LoRA to every nn.Linear layer of the model
def apply_lora(model, device, rank=1, alpha=1):
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            # Register our LoRA module as a parametrization of the weight:
            # from now on, layer.weight returns W + (alpha / rank) * B @ A
            parametrize.register_parametrization(
                layer, "weight", linear_param(layer, device, rank=rank, alpha=alpha)
            )
# Remove LoRA by dropping the parametrization and restoring the original weights
def remove_lora(model):
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
# Enable/disable the LoRA update without removing the parametrization
def enable_disable_lora(model, enabled=True):
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            layer.parametrizations["weight"][0].enabled = enabled
# Freeze the original model weights so that only the LoRA parameters are trained
def freeze_model(model, freeze=True):
    for name, param in model.named_parameters():
        if "lora" not in name:  # keep lora_a / lora_b trainable
            param.requires_grad = not freeze
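Putting the helpers together, a minimal end-to-end sketch could look like the following (the toy model, hyperparameters, and optimizer are hypothetical, not from the source):

# Hypothetical usage of the helpers above
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
device = "cpu"

freeze_model(model, freeze=True)            # freeze the pre-trained weights
apply_lora(model, device, rank=4, alpha=8)  # inject trainable low-rank updates

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)  # only lora_a / lora_b get optimized

# ... fine-tuning loop goes here ...

enable_disable_lora(model, enabled=False)  # temporarily fall back to the original weights
remove_lora(model)                         # or strip the parametrization entirely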
Implementation based on this video: https://www.youtube.com/watch?v=PXWYUTMt-AU
Paper link: https://arxiv.org/abs/2106.09685