In this comprehensive guide, we’ll explore how to build a custom Transformer-based language model from scratch using PyTorch and integrate it seamlessly with HuggingFace’s Transformers library. We’ll cover creating custom layers, attention mechanisms, and configuration classes, ultimately enabling the model to use HuggingFace’s powerful generate() functionality.
Building the Transformer model involves several components: an RMSNorm layer, a gated feed-forward network, rotary positional embeddings (RoPE), multi-head attention with a causal mask, a Transformer block that ties them together, and finally a HuggingFace-compatible configuration and model class so that generate() works out of the box.
Let’s dive in!
Root Mean Square Layer Normalization (RMSNorm) is an efficient alternative to LayerNorm used in modern transformer architectures: it skips the mean subtraction and simply rescales each hidden vector by its root mean square.
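Concretely, with a learnable gain $g$ and a small $\epsilon$ for numerical stability (this just restates what the code below computes):

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g$$

In PyTorch: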
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.eps = 1e-5
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-dimension gain

    def forward(self, x):
        # Rescale each vector by the root mean square of its last dimension
        mean_square = (x ** 2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps)
        return x * self.weight
The feed-forward network introduces non-linearity through a gated SiLU activation: the output is fc3(SiLU(fc1(x)) * fc2(x)), with the hidden dimension defaulting to 4× the model dimension:
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or dim * 4
        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.fc2 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
        self.fc3 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

    def forward(self, x):
        # SiLU-gated feed-forward: fc3(SiLU(fc1(x)) * fc2(x))
        return self.fc3(F.silu(self.fc1(x)) * self.fc2(x))
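A quick sanity check (the sizes here are just illustrative, assuming the class above): the module maps (batch, seq, dim) back to (batch, seq, dim).

ff = FeedForward(dim=512)
x = torch.randn(2, 10, 512)
print(ff(x).shape)  # torch.Size([2, 10, 512])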
Rotary Positional Embeddings (RoPE) encode positional information by rotating the query and key vectors. You can check out my other blog post to understand it better :D
def precompute_rope(head_dim, seq_len, theta=10000):
    # One rotation frequency per pair of dimensions, then one angle per (position, frequency)
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len)
    angles = positions[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=1)  # (seq_len, head_dim)
    return torch.cos(angles), torch.sin(angles)

def compute_rope(x, cos, sin):
    # Rotate-half formulation: dimension i is paired with dimension i + head_dim // 2
    x1, x2 = x[..., :x.size(-1) // 2], x[..., x.size(-1) // 2:]
    return (x * cos) + torch.cat((-x2, x1), dim=-1) * sin
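A quick way to convince yourself the helpers behave (a small check assuming the functions above, with made-up sizes): RoPE is a pure rotation, so it should leave the norm of every head vector unchanged.

head_dim, seq_len = 64, 16
cos, sin = precompute_rope(head_dim, seq_len)
x = torch.randn(2, 8, seq_len, head_dim)  # (batch, heads, seq, head_dim)
x_rot = compute_rope(x, cos, sin)
print(x_rot.shape)  # torch.Size([2, 8, 16, 64])
print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))  # True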
Multi-head self-attention ties RoPE and the causal mask together:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, head_dim, seq_len):
        super().__init__()
        self.n_heads = dim // head_dim
        self.head_dim = head_dim
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)
        # Causal mask: -inf above the diagonal, 0 elsewhere (added to the scores)
        self.register_buffer("mask",
                             torch.triu(torch.full((seq_len, seq_len), -float('inf')), 1))
        cos, sin = precompute_rope(head_dim, seq_len)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x, attention_mask=None):
        batch, seq_len, _ = x.shape
        q, k, v = [proj(x).view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
                   for proj in (self.wq, self.wk, self.wv)]
        # Apply RoPE to queries and keys, slicing the tables to the current sequence length
        q = compute_rope(q, self.cos[:seq_len], self.sin[:seq_len])
        k = compute_rope(k, self.cos[:seq_len], self.sin[:seq_len])
        # Scaled dot-product scores plus the causal mask
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.mask[:seq_len, :seq_len]
        if attention_mask is not None:
            # attention_mask is 1 for real tokens, 0 for padding
            padding_mask = ~attention_mask[:, None, None, :].bool()
            attn = attn.masked_fill(padding_mask, float('-inf'))
        attn_output = torch.softmax(attn, dim=-1) @ v
        return self.wo(attn_output.transpose(1, 2).reshape(batch, seq_len, -1))
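A short smoke test (purely illustrative values, assuming the class above): feed in a random batch together with a padding mask and confirm the output shape matches the input.

attn = MultiHeadAttention(dim=512, head_dim=64, seq_len=1024)
x = torch.randn(2, 10, 512)
attention_mask = torch.ones(2, 10, dtype=torch.long)
attention_mask[1, 7:] = 0  # pretend the last three tokens of the second sample are padding
out = attn(x, attention_mask)
print(out.shape)  # torch.Size([2, 10, 512])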
The Transformer block combines all the previous components with pre-normalization and residual connections:
class TransformerBlock(nn.Module):
    def __init__(self, dim, hidden_dim, head_dim, seq_len):
        super().__init__()
        self.norm1, self.norm2 = RMSNorm(dim), RMSNorm(dim)
        self.attn = MultiHeadAttention(dim, head_dim, seq_len)
        self.ff = FeedForward(dim, hidden_dim)

    def forward(self, x, attention_mask=None):
        # Out-of-place residual adds (in-place x += ... can trip up autograd)
        x = x + self.attn(self.norm1(x), attention_mask)
        return x + self.ff(self.norm2(x))
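And a final shape check before stacking (again just a sketch with made-up sizes): one block maps (batch, seq, dim) to (batch, seq, dim), so any number of them can be chained.

block = TransformerBlock(dim=512, hidden_dim=2048, head_dim=64, seq_len=1024)
x = torch.randn(2, 10, 512)
print(block(x).shape)  # torch.Size([2, 10, 512])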
We will create the model config and the model class now:
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

class MainModelConfig(PretrainedConfig):
    model_type = "custom_transformer"

    def __init__(self, dim=512, hidden_dim=2048, head_dim=64, seq_len=1024,
                 num_layers=6, vocab_size=30522, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.num_layers = num_layers
        self.vocab_size = vocab_size  # must match the tokenizer you train and generate with

class MainModel(PreTrainedModel, GenerationMixin):
    config_class = MainModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config.dim,
                              config.hidden_dim,
                              config.head_dim,
                              config.seq_len)
             for _ in range(config.num_layers)])
        self.norm = RMSNorm(config.dim)
        self.out = nn.Linear(config.dim, config.vocab_size)
        self.post_init()  # HuggingFace weight initialization hook

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # **kwargs absorbs any extra arguments generate() passes along
        x = self.token_emb(input_ids)
        for block in self.blocks:
            x = block(x, attention_mask)
        logits = self.out(self.norm(x))
        loss = None
        if labels is not None:
            # Standard causal LM loss: predict token t+1 from tokens <= t
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                                   labels[:, 1:].reshape(-1))
        return CausalLMOutput(loss=loss, logits=logits)
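One way to produce the local weights used below (a minimal sketch of an assumed workflow, not part of the original code): build the model, train it, then save it so from_pretrained can find both the config and the weights.

config = MainModelConfig()
model = MainModel(config)
# ... training loop goes here ...
model.save_pretrained("path")  # writes config.json and the model weights to "path"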
I will use my local weights to test it:
from transformers import AutoTokenizer

config = MainModelConfig()
model = MainModel.from_pretrained("path", config=config).cuda()
# Note: the GPT-2 tokenizer has 50257 tokens, so vocab_size in the config must match it.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

prompt = "best football team in madrid is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
# do_sample=True is needed for temperature / top_k / top_p to take effect
output = model.generate(input_ids, max_length=150, do_sample=True,
                        temperature=0.8, top_k=10, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))
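As an optional extra (my own addition, not required for the snippet above): the custom classes can be registered with the Auto* factories so that checkpoints with model_type "custom_transformer" load through the standard loaders.

from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("custom_transformer", MainModelConfig)
AutoModelForCausalLM.register(MainModelConfig, MainModel)

model = AutoModelForCausalLM.from_pretrained("path")  # now resolves to MainModel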