Building a Custom Transformer Model with HuggingFace Integration and Generation Capability

In this comprehensive guide, we’ll explore how to build a custom Transformer-based language model from scratch using PyTorch and integrate it seamlessly with HuggingFace’s Transformers library. We’ll cover creating custom layers, attention mechanisms, and configuration classes, ultimately enabling the model to use HuggingFace’s powerful generate() functionality.

Overview

Building your Transformer model involves several components:

  • RMSNorm
  • FeedForward Network
  • Rotary Positional Embeddings (RoPE)
  • Multi-Head Attention
  • Transformer Blocks
  • Integration with HuggingFace’s PreTrainedModel and GenerationMixin

Let’s dive in!

Step-by-Step Implementation

1. RMSNorm

Root Mean Square Layer Normalization (RMSNorm) is an efficient alternative to LayerNorm used in modern transformer architectures. It's defined as:

x' = x / √((1/d) ∑ x_i² + ε) ⊙ W
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.eps = 1e-5
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # normalize by the root mean square over the last dimension, then apply the learned scale
        mean_square = (x ** 2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps)
        return x * self.weight
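As a quick sanity check (toy shapes, just for illustration), RMSNorm keeps the input shape and, since the weight starts as ones, produces activations with an RMS of roughly 1:

x = torch.randn(2, 8, 512)           # (batch, seq_len, dim)
norm = RMSNorm(512)
out = norm(x)
print(out.shape)                     # torch.Size([2, 8, 512])
print((out ** 2).mean(-1).sqrt())    # ≈ 1.0 everywhere, since weight is initialized to ones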

2. FeedForward Network

This is a SwiGLU-style feed-forward network: it introduces non-linearity through the SiLU activation and gates it with a second linear projection:

FFN(x) = W_3^T (silu(W_1^T x) ⊙ W_2^T x)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or dim * 4
        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection (W_1)
        self.fc2 = nn.Linear(dim, hidden_dim, bias=False)  # up projection (W_2)
        self.fc3 = nn.Linear(hidden_dim, dim, bias=False)  # down projection (W_3)

    def forward(self, x):
        # SwiGLU: silu(fc1(x)) gates fc2(x), then fc3 projects back to dim
        return self.fc3(F.silu(self.fc1(x)) * self.fc2(x))
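A quick shape check (toy numbers, for illustration only): the hidden dimension expands internally, but the input and output shapes match.

ff = FeedForward(512)                # hidden_dim defaults to 4 * dim = 2048
out = ff(torch.randn(2, 8, 512))
print(out.shape)                     # torch.Size([2, 8, 512])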

3. Rotary Positional Embeddings (RoPE)

RoPE encodes positional information by rotating each pair of query/key dimensions by a position-dependent angle, so relative positions fall out of the dot product naturally. You can check my other blog post for a deeper dive :D

Q'_i = Q_i cos θ_i − Q_j sin θ_i
Q'_j = Q_i sin θ_i + Q_j cos θ_i

and the same rotation is applied to K:
def precompute_rope(head_dim, seq_len, theta=10000):
    # one frequency per dimension pair, then one angle per (position, frequency)
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len)
    angles = positions[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=1)  # (seq_len, head_dim)
    return torch.cos(angles), torch.sin(angles)

def compute_rope(x, cos, sin):
    # rotate each (x1, x2) pair: x * cos + rotate_half(x) * sin
    x1, x2 = x[..., :x.size(-1)//2], x[..., x.size(-1)//2:]
    return (x * cos) + torch.cat((-x2, x1), dim=-1) * sin
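As a quick sanity check (toy shapes, for illustration only): the precomputed tables have shape (seq_len, head_dim), and because RoPE is a pure rotation it preserves the norm of each query/key vector.

cos, sin = precompute_rope(head_dim=64, seq_len=16)
print(cos.shape)                                    # torch.Size([16, 64])

q = torch.randn(1, 4, 16, 64)                       # (batch, heads, seq_len, head_dim)
q_rot = compute_rope(q, cos, sin)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True: rotation preserves norms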

4. Multi-Head Attention

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, head_dim, seq_len):
        super().__init__()
        self.n_heads = dim // head_dim
        self.head_dim = head_dim

        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

        # causal mask: -inf above the diagonal so tokens can't attend to the future
        self.register_buffer("mask",
                             torch.triu(
                                 torch.full((seq_len, seq_len), -float('inf')),
                                 1))
        cos, sin = precompute_rope(head_dim, seq_len)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x, attention_mask=None):
        batch, seq_len, _ = x.shape
        q, k, v = [proj(x).view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
                   for proj in (self.wq, self.wk, self.wv)]

        # apply RoPE, slicing the cached tables to the current sequence length
        q = compute_rope(q, self.cos[:seq_len], self.sin[:seq_len])
        k = compute_rope(k, self.cos[:seq_len], self.sin[:seq_len])
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.mask[:seq_len, :seq_len]

        if attention_mask is not None:
            # attention_mask is 1 for real tokens, 0 for padding; mask out the padding columns
            padding_mask = ~attention_mask[:, None, None, :].bool()
            attn = attn.masked_fill(padding_mask, float('-inf'))

        attn_output = torch.softmax(attn, dim=-1) @ v
        return self.wo(attn_output.transpose(1, 2).reshape(batch, seq_len, -1))
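A quick shape test (toy numbers, assuming the classes above): attention mixes information across positions but keeps the (batch, seq_len, dim) shape.

mha = MultiHeadAttention(dim=512, head_dim=64, seq_len=1024)
out = mha(torch.randn(2, 16, 512))
print(out.shape)                     # torch.Size([2, 16, 512])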

5. Transformer Block

The Transformer block ties the previous components together: RMSNorm in a pre-norm arrangement, attention, and the feed-forward network, each wrapped in a residual connection:

class TransformerBlock(nn.Module):
    def __init__(self, dim, hidden_dim, head_dim, seq_len):
        super().__init__()
        self.norm1, self.norm2 = RMSNorm(dim), RMSNorm(dim)
        self.attn = MultiHeadAttention(dim, head_dim, seq_len)
        self.ff = FeedForward(dim, hidden_dim)

    def forward(self, x, attention_mask=None):
        # pre-norm residual connections (avoid in-place += so autograd doesn't complain)
        x = x + self.attn(self.norm1(x), attention_mask)
        return x + self.ff(self.norm2(x))
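And one more quick check (toy values, for illustration only) that a block preserves the (batch, seq_len, dim) shape, so blocks can be stacked freely:

block = TransformerBlock(dim=512, hidden_dim=2048, head_dim=64, seq_len=1024)
out = block(torch.randn(2, 16, 512))
print(out.shape)                     # torch.Size([2, 16, 512])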

6. Now Let’s Integrate with HuggingFace

Now we’ll create the model config and the model class. Subclassing PretrainedConfig and PreTrainedModel (with GenerationMixin) is what gives us save_pretrained(), from_pretrained(), and generate() for free:

from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

class MainModelConfig(PretrainedConfig):
    model_type = "custom_transformer"

    def __init__(self, dim=512, hidden_dim=2048, head_dim=64, seq_len=1024,
                 num_layers=6, vocab_size=30522, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.num_layers = num_layers
        self.vocab_size = vocab_size

class MainModel(PreTrainedModel, GenerationMixin):
    config_class = MainModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config.dim,
                              config.hidden_dim,
                              config.head_dim,
                              config.seq_len)
             for _ in range(config.num_layers)])

        self.norm = RMSNorm(config.dim)
        self.out = nn.Linear(config.dim, config.vocab_size)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        x = self.token_emb(input_ids)

        for block in self.blocks:
            x = block(x, attention_mask)

        logits = self.out(self.norm(x))

        loss = None
        if labels is not None:
            # standard causal-LM loss: predict token t+1 from tokens up to t
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                                   labels[:, 1:].reshape(-1))

        return CausalLMOutput(loss=loss, logits=logits)
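For from_pretrained("path") in the next step to work, the weights first have to exist at that path. Here’s a minimal sketch ("path" is just a placeholder, and in practice you’d train the model before saving):

config = MainModelConfig()
model = MainModel(config)
# ... train the model here ...
model.save_pretrained("path")        # writes the config and weights to "path"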

Let’s Test the Generation :D

I’ll load my locally saved weights to test it:

from transformers import AutoTokenizer

config = MainModelConfig()
model = MainModel.from_pretrained("path", config=config).cuda()

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt = "best football team in madrid is"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

output = model.generate(input_ids, max_length=150, do_sample=True,
                        temperature=0.8, top_k=10, top_p=0.95)

print(tokenizer.decode(output[0], skip_special_tokens=True))