Qwen3-0.6B From Scratch
Hi, I'm Sidhant.
This is my first attempt at writing a technical blog. While the main goal here is to strengthen my own understanding, I hope you'll find it helpful too.
Over the past couple of months, I've been working with RAG and AI agents. Along the way, I realised my PyTorch skills had gotten a bit rusty. Recently, I came across this notebook by Sebastian Raschka, which inspired me to finally step into blogging—while also brushing up on some core concepts.
All the code for this blog will be at https://github.com/sidmanale643/Qwen-0.6B-from-scratch
Background
Today, we'll focus on the 0.6-billion-parameter dense model, with particular attention to its use of RoPE, RMSNorm and GQA. Thanks to its compact size and architectural choices, this model is well suited to educational purposes: it is easy to implement and can be trained from scratch on consumer hardware. Despite being one of the smallest open-source models available, it remains remarkably capable.
Key design features include:
- Grouped Query Attention (GQA)
- Rotary Positional Embeddings (RoPE)
- SiLU Activations
- Mixture-of-Experts (MoE) variant as an alternative to the standard Feed-Forward block
For this post we will dive into the Dense variant—and save the MoE architecture for another time.
Config
Positional Encodings
Self-attention on its own has no notion of token order. Strictly speaking, it is permutation-equivariant rather than permutation-invariant: if the order of the input tokens changes, the outputs are reordered in exactly the same way, but the values themselves remain unaffected. Positional information therefore has to be injected explicitly.
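The reordering behaviour described above can be checked numerically. The following is a minimal sketch and not part of the original code: `self_attention` is a hypothetical single-head attention with no mask and no positional information, and the weights and inputs are random.

```python
import torch

torch.manual_seed(0)

def self_attention(x, w_q, w_k, w_v):
    """Single-head attention with no mask and no positional information."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

seq_len, dim = 5, 8
x = torch.randn(seq_len, dim, dtype=torch.float64)
w_q, w_k, w_v = (torch.randn(dim, dim, dtype=torch.float64) for _ in range(3))

perm = torch.randperm(seq_len)  # a random reordering of the tokens
out = self_attention(x, w_q, w_k, w_v)
out_shuffled = self_attention(x[perm], w_q, w_k, w_v)

# Shuffling the inputs shuffles the outputs in exactly the same way,
# so the model cannot tell which token came first.
same_up_to_order = torch.allclose(out[perm], out_shuffled)
print(same_up_to_order)
```

Because the outputs match up to reordering, any sensitivity to word order has to come from positional information added elsewhere.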
With Absolute Positional Encodings (APE), each token is tagged with its absolute position in the sequence, so the model has to infer relative distances indirectly. In language, however, the relative distance between tokens usually carries more useful information than their absolute positions, and APE does not capture it effectively.
Before Rotary Positional Embeddings (RoPE), positional information was typically added directly to the token embeddings. This approach risks polluting the semantic representation of each token with positional signal.
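To make "added directly to the token embeddings" concrete, here is a small sketch of classic sinusoidal absolute encodings. It is an illustration only: `sinusoidal_positions`, the base of 10000 and the random embeddings are assumptions, not code from this post.

```python
import torch

def sinusoidal_positions(seq_len, dim, base=10000.0):
    """Classic sinusoidal absolute positional encodings, shape (seq_len, dim)."""
    positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)    # (seq_len, 1)
    freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim / 2,)
    angles = positions * freqs                                             # (seq_len, dim / 2)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(angles)  # even dimensions
    pe[:, 1::2] = torch.cos(angles)  # odd dimensions
    return pe

token_embeddings = torch.randn(5, 8)  # hypothetical embeddings
pe = sinusoidal_positions(5, 8)

# Position is summed into the same vector that carries the token's meaning.
ape_input = token_embeddings + pe
```

After the sum, meaning and position share the same coordinates, which is the mixing that RoPE avoids by rotating only the queries and keys.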
RoPE
RoPE was first introduced in the RoFormer paper. In RoPE, the angle between the rotated query and key vectors encodes the relative distance between their tokens within a sequence. Unlike APE, nothing is added to the embeddings: only the Query and Key matrices are rotated, which preserves the semantic meaning of the embeddings while still capturing relative positions.
Instead of encoding absolute position by adding a vector drawn from sinusoidal functions of slowly decreasing frequencies, RoPE encodes relative position by multiplying each pair of dimensions by a rotation matrix.
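The claim that rotation captures relative rather than absolute position can be verified on a single 2-D pair. The sketch below uses made-up values for `theta`, `q` and `k`; the helper names are hypothetical.

```python
import math

def rotate_2d(vec, angle):
    """Rotate a 2-D vector counter-clockwise by `angle` radians."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

theta = 0.1                     # hypothetical rotation frequency
q, k = (1.0, 2.0), (0.5, -1.5)  # hypothetical 2-D query and key

def score(m, n):
    """Attention score between a query at position m and a key at position n."""
    return dot(rotate_2d(q, m * theta), rotate_2d(k, n * theta))

near_start = score(3, 1)      # positions 3 and 1, gap of 2
far_along = score(10, 8)      # positions 10 and 8, same gap of 2
different_gap = score(3, 0)   # gap of 3

print(near_start, far_along, different_gap)
```

The first two scores agree because only the difference between the positions survives the dot product; changing the gap changes the score.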
Application
Suppose we have a matrix Q of shape [5, 8] (sequence_length, head_dimension) that we want to rotate.
To "rotate" a pair of embedding dimensions (x, y) means to find a rotation matrix M' such that:
$$\begin{bmatrix} x' \\ y' \end{bmatrix} = M' \begin{bmatrix} x \\ y \end{bmatrix}$$
The rotation matrix rotates a 2-D vector, and the angle θ by which it rotates depends on the position of the vector in the sequence:
$$M' = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}$$
Where the angle is:
$$\theta = \text{position} \times \text{frequency}$$
In code, the same rotation is applied to every pair of a (batch, heads, seq_len, head_dim) tensor at once:

    import torch

    def rotate(x, sin, cos):
        # x: (batch, heads, seq_len, head_dim); sin, cos: (max_seq_len, head_dim)
        batch_size, num_heads, seq_len, head_dim = x.shape
        # Split the last dimension in half; dimension j is paired with j + head_dim // 2
        x1 = x[..., :head_dim // 2]
        x2 = x[..., head_dim // 2:]
        # (-x2, x1) is the part of each pair that gets multiplied by sin
        rotated_half = torch.cat((-x2, x1), dim=-1)
        # Trim the tables to the current length and broadcast over batch and heads
        sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
        cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
        x_rotated = (x * cos) + (rotated_half * sin)
        return x_rotated
Step 1: Divide the 8 dimensions into head_dimension // 2 = 4 pairs
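There are two common ways to form the pairs, and it helps to know which one a given implementation uses. The sketch below lists both for `head_dim = 8`. The post's `rotate` function slices the last dimension into two halves, which corresponds to the half-split layout; the interleaved layout is the one written out in the RoFormer paper. The variable names are illustrative.

```python
head_dim = 8
half = head_dim // 2

# Half-split layout: dimension j is paired with dimension j + half.
# This matches slicing x[..., :half] and x[..., half:].
half_split_pairs = [(j, j + half) for j in range(half)]

# Interleaved layout: adjacent dimensions are paired.
interleaved_pairs = [(2 * j, 2 * j + 1) for j in range(half)]

print(half_split_pairs)   # [(0, 4), (1, 5), (2, 6), (3, 7)]
print(interleaved_pairs)  # [(0, 1), (2, 3), (4, 5), (6, 7)]
```

Either layout gives 4 pairs and therefore 4 frequencies; what matters is that the cos/sin tables are laid out to match the pairing the code uses.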
Step 2: Calculate Frequencies
i = 0: 1 / 10000^(0/8) = 1.0000
i = 2: 1 / 10000^(2/8) = 0.1000
i = 4: 1 / 10000^(4/8) = 0.0100
i = 6: 1 / 10000^(6/8) = 0.0010
frequencies = [1.0000, 0.1000, 0.0100, 0.0010]
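The same frequencies can be computed in a couple of lines. This is a sketch under one assumption: a base of 10000, which is the value that reproduces the list above for a head dimension of 8.

```python
head_dim = 8
base = 10000.0  # assumed base; it reproduces the frequencies listed above

# One frequency per pair of dimensions: i = 0, 2, 4, 6
frequencies = [1.0 / base ** (i / head_dim) for i in range(0, head_dim, 2)]

print([f"{f:.4f}" for f in frequencies])
```

Each successive pair rotates ten times more slowly than the previous one, so early pairs distinguish nearby tokens while later pairs change only over long distances.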
Step 3: Calculate Angles
Formula: angle = position × frequency
For each position, multiply by each frequency:
#### Position 0:
0 × 1.0000 = 0.0000
0 × 0.1000 = 0.0000
0 × 0.0100 = 0.0000
0 × 0.0010 = 0.0000
#### Position 1:
1 × 1.0000 = 1.0000
1 × 0.1000 = 0.1000
1 × 0.0100 = 0.0100
1 × 0.0010 = 0.0010
#### Position 2:
2 × 1.0000 = 2.0000
2 × 0.1000 = 0.2000
2 × 0.0100 = 0.0200
2 × 0.0010 = 0.0020
#### Position 3:
3 × 1.0000 = 3.0000
3 × 0.1000 = 0.3000
3 × 0.0100 = 0.0300
3 × 0.0010 = 0.0030
#### Position 4:
4 × 1.0000 = 4.0000
4 × 0.1000 = 0.4000
4 × 0.0100 = 0.0400
4 × 0.0010 = 0.0040
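All of the position-by-frequency products can be computed at once as an outer product. A minimal sketch, using the four frequencies from Step 2 and the five positions of the example; the variable names are illustrative:

```python
import torch

seq_len = 5
frequencies = torch.tensor([1.0, 0.1, 0.01, 0.001])
positions = torch.arange(seq_len, dtype=torch.float32)  # 0, 1, 2, 3, 4

# angles[m, i] = position m * frequency i, shape (5, 4)
angles = torch.outer(positions, frequencies)
print(angles)
```

Row `m` of `angles` holds the four rotation angles for the token at position `m`, so row 0 is all zeros and row 1 equals the frequencies themselves.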
Step 4: Create Rotation Matrices
#### Angles for position 1:
angles = [1.0000, 0.1000, 0.0100, 0.0010]
Since we have 4 pairs of dimensions, we need 4 rotation matrices:
#### Rotation Matrix Template:
$$M' = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}$$
#### For Position #0:
All four angles are 0 at position 0, so every rotation matrix is the identity:
$$
R_0 = R_1 = R_2 = R_3 = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}
$$
#### For Position #1:
With angles [1.0000, 0.1000, 0.0100, 0.0010] (values rounded to 4 decimal places):
$$
R_0 = \begin{bmatrix} 0.5403 & -0.8415 \\ 0.8415 & 0.5403 \end{bmatrix}, \quad
R_1 = \begin{bmatrix} 0.9950 & -0.0998 \\ 0.0998 & 0.9950 \end{bmatrix}
$$
$$
R_2 = \begin{bmatrix} 1.0000 & -0.0100 \\ 0.0100 & 1.0000 \end{bmatrix}, \quad
R_3 = \begin{bmatrix} 1.0000 & -0.0010 \\ 0.0010 & 1.0000 \end{bmatrix}
$$
#### And so on for Positions #2, #3 and #4
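The remaining positions follow the same recipe, so it can help to compute them in a loop. The sketch below is plain Python (no PyTorch needed) and assumes the toy setup implied by the frequencies listed earlier: `head_dim = 8` and `base = 10000`. The helper name `rotation_matrices` is made up for this illustration.

```python
import math

# Assumed toy setup, inferred from frequencies = [1.0, 0.1, 0.01, 0.001]:
# head_dim = 8 (4 dimension pairs) and base = 10000.
head_dim, base = 8, 10000.0
freqs = [base ** (-i / head_dim) for i in range(0, head_dim, 2)]

def rotation_matrices(position):
    """Return one 2x2 rotation matrix per dimension pair for a position."""
    matrices = []
    for f in freqs:
        theta = position * f  # angle = position * frequency
        c, s = math.cos(theta), math.sin(theta)
        matrices.append([[c, -s], [s, c]])
    return matrices

# Print the first (fastest-rotating) matrix for positions 0..4.
for pos in range(5):
    first = rotation_matrices(pos)[0]
    print(pos, [[round(v, 4) for v in row] for row in first])
```

The first pair rotates by a full radian per position, while the last pair barely moves. This spread of speeds is what lets RoPE encode both short and long distances.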
Step 5: Rotate the Q matrix by multiplying by rotation matrices
#### For Position #0:
Every rotation matrix at position 0 is the identity, so the position-0 row of Q is left unchanged.
#### For Position #1:
and so on for Positions #2, #3 and #4
Rotated Q
#### Q Original:
#### Q after RoPE:
Similarly, the Key matrix is rotated, and the rotated Q and K are then used in the attention mechanism.
NOTE: RoPE is applied after Q and K have been split into individual heads.
We calculate the RoPE parameters (the sine and cosine values used to rotate the embeddings) ahead of time. Instead of recomputing them for each sequence, we precompute the angles up to the maximum context length (4096 in our case) and store them as PyTorch buffers.
PyTorch buffers are persistent tensors that behave like constants: they are not trainable parameters but are saved and moved with the model.
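def calculate_sin_cos():
    # One frequency per dimension pair: base^(-i / head_dim) for i = 0, 2, 4, ...
    i = torch.arange(0, config.head_dim, 2).float()
    freqs = 1.0 / (config.base ** (i / config.head_dim))
    positions = torch.arange(config.max_ctx_length).float()
    # angles[p, j] = position p * frequency j, shape (max_ctx_length, head_dim // 2)
    angles = positions[:, None] * freqs[None, :]
    # Duplicate angles to match full head_dim for proper RoPE rotation
    angles = torch.cat([angles, angles], dim=1)
    return torch.sin(angles), torch.cos(angles)

def rotate(x, sin, cos):
    # x has shape (batch, heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    x1 = x[..., :head_dim // 2]
    x2 = x[..., head_dim // 2:]
    # Despite the name, this is the vector (-x2, x1): each (x1, x2) pair turned by 90 degrees
    rotation_matrix = torch.cat((-x2, x1), dim=-1)
    # Slice to the current sequence length and broadcast over batch and heads
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    # Pairwise 2D rotation: (x1*cos - x2*sin, x2*cos + x1*sin)
    x_rotated = (x * cos) + (rotation_matrix * sin)
    return x_rotated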
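A quick way to convince yourself that RoPE encodes relative position is to check that the dot product of a rotated query and key depends only on the distance between them. The sketch below is a plain-Python re-implementation of the same half-split rotation as `rotate` above, for a single head vector. The function names `rope` and `dot` and the example vectors are made up for this illustration, and `base = 10000` is an assumption.

```python
import math

def rope(vec, position, base=10000.0):
    """Rotate one head vector, pairing element i with element i + d/2."""
    d = len(vec)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        theta = position * base ** (-2 * i / d)  # angle for pair i
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Arbitrary example vectors with head_dim = 8.
q = [0.5, -1.0, 2.0, 0.3, 1.5, -0.7, 0.1, 0.9]
k = [1.0, 0.2, -0.4, 0.8, -1.1, 0.6, 0.3, -0.5]

near = dot(rope(q, 3), rope(k, 1))      # offset 2, early in the sequence
far = dot(rope(q, 103), rope(k, 101))   # offset 2, one hundred tokens later
print(abs(near - far) < 1e-9)           # True: only the offset matters
```

Because each pair is only rotated, the length of the vector is preserved too, so RoPE changes where a vector points without rescaling it.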
RMS Norm
RMSNorm stands for Root Mean Square Normalization. It is a normalization technique designed to stabilize and accelerate training by normalizing inputs based on their root mean square (RMS) values.
Normalization helps us avoid scenarios like:
- Exploding Gradients: When gradient values grow exponentially during backpropagation, causing training to become unstable
- Vanishing Gradients: When gradient values become too small, preventing effective learning in deeper layers
$$
\text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} + \beta
$$
#### Where:
- Gamma (Scale Parameter): a multiplicative parameter for scaling the distribution of the normalized values
- Beta (Shift Parameter): an additive parameter for shifting the distribution of the normalized values
- Epsilon (Stability Value): a small value (typically 1e-6) that prevents division by zero when the RMS approaches zero
class RMSNORM(nn.Module):
    def __init__(self, embed_dim, eps=1e-6):
        super().__init__()
        # Learnable scale, initialised to 1 (identity)
        self.gamma = nn.Parameter(torch.ones(embed_dim))
        # Learnable shift, initialised to 0; standard RMSNorm uses only the scale
        self.beta = nn.Parameter(torch.zeros(embed_dim))
        self.eps = eps

    def forward(self, x):
        # Mean of squares over the last dimension (no mean subtraction, unlike LayerNorm)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(variance + self.eps)
        return self.gamma * (x / rms) + self.beta
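A tiny worked example makes the arithmetic concrete. The sketch below mirrors the `forward` computation in plain Python for a single vector. The function `rms_norm` and the input values are made up for this illustration.

```python
import math

def rms_norm(x, gamma=None, beta=None, eps=1e-6):
    """Plain-Python RMS normalization of one vector."""
    d = len(x)
    gamma = gamma or [1.0] * d  # scale defaults to 1
    beta = beta or [0.0] * d    # shift defaults to 0
    rms = math.sqrt(sum(v * v for v in x) / d + eps)
    return [g * (v / rms) + b for v, g, b in zip(x, gamma, beta)]

x = [3.0, 4.0]
y = rms_norm(x)                      # mean of squares = 12.5, rms ~ 3.5355
print([round(v, 4) for v in y])      # [0.8485, 1.1314]

# Scaling the input by 10 gives (almost exactly) the same output.
y_scaled = rms_norm([30.0, 40.0])
print([round(v, 4) for v in y_scaled])
```

The output always has an RMS of about 1 regardless of the input's magnitude, which is what keeps activations in a stable range from layer to layer.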
SiLU Activation Function
SiLU stands for Sigmoid Linear Unit. Its parameterised generalisation is known as Swish. The function is non-monotonic: for negative inputs it dips slightly below zero (reaching a minimum near x ≈ -1.28) before increasing again.
$$
\text{SiLU}(x) = x \cdot \sigma(x)
$$
Where sigmoid is:
$$
\sigma(x) = \frac{1}{1 + e^{-x}}
$$
therefore:
$$
\text{SiLU}(x) = \frac{x}{1 + e^{-x}}
$$
Note: SiLU is the same as Swish when $\beta = 1$ in $\text{Swish}(x) = x \cdot \sigma(\beta x)$
SiLU vs ReLU
The key difference between ReLU and SiLU lies in how they handle negative values. ReLU sets all negative inputs to zero, which can lead to the dying neuron problem, where certain neurons stop activating altogether during training. In contrast, SiLU (Sigmoid Linear Unit) allows small negative values to pass through, maintaining a smoother gradient flow. This helps reduce the risk of inactive neurons and can improve training stability.
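To see the difference on negative inputs, here is a small plain-Python comparison of the two functions and their gradients. The helper names are made up for this illustration. The SiLU gradient uses the identity d/dx [x·σ(x)] = σ(x)·(1 + x·(1 − σ(x))).

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def silu(x):
    return x * sigmoid(x)

def relu(x):
    return max(0.0, x)

def silu_grad(x):
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def relu_grad(x):
    return 1.0 if x > 0 else 0.0

# Columns: x, ReLU(x), SiLU(x), ReLU'(x), SiLU'(x)
for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(x, round(relu(x), 4), round(silu(x), 4),
          relu_grad(x), round(silu_grad(x), 4))
```

For x = -1 ReLU outputs 0 with a gradient of exactly 0, while SiLU outputs about -0.269 with a small non-zero gradient, so the neuron can still receive a learning signal.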
SiLU vs GELU
Both GELU and SiLU are non-monotonic and smooth activation functions that combine linear and non-linear components. However, SiLU is cheaper to compute than exact GELU, which requires the Gaussian CDF. Interestingly, GELU can be approximated as a scaled variant of SiLU:
$$
\text{GELU}(x) \approx x \cdot \sigma(1.702x)
$$
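How close is that sigmoid-based approximation? A minimal check in plain Python, comparing exact GELU (written with `math.erf`) against x·σ(1.702x) on a grid of inputs. The function names and the grid range are choices made for this illustration.

```python
import math

def gelu_exact(x):
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_sigmoid_approx(x):
    # x * sigmoid(1.702 * x)
    return x / (1.0 + math.exp(-1.702 * x))

# Evaluate on x in [-6, 6] with a step of 0.01.
xs = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(x) - gelu_sigmoid_approx(x)) for x in xs)
print(round(max_err, 3))
```

On this grid the largest gap is roughly 0.02, so the two curves are hard to tell apart by eye even though they are not identical.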
Grouped Query Attention
Multi Head Attention (MHA)
In standard MHA each head has its own Query, Key and Value Matrices. Although MHA is the most accurate, having as many K and V matrices as the number of heads introduces a huge memory overhead. This increases both training and inference costs in terms of storage, speed, and efficiency.
Multi Query Attention (MQA)
Each head has a separate Query matrix, while a single Key matrix and a single Value matrix are shared across all heads. This cuts the K/V parameters and the KV cache by a factor equal to the number of heads, which makes MQA very fast and memory-efficient. The tradeoff is a drop in accuracy, since every head is forced to attend over the same keys and values.
Grouped Query Attention (GQA)
GQA represents an elegant middle ground between MHA and MQA. Instead of each head having its own K/V (MHA) or all heads sharing one K/V (MQA), GQA groups multiple query heads to share the same K/V projections. In the model printed at the end of this post, w_q is twice as wide as w_k and w_v, so two query heads share each K/V head.
#### Benefits of GQA:
Compared with MHA, GQA offers:
- Reduced memory usage - Fewer K/V heads mean a smaller KV cache
- Fewer parameters - Smaller K and V projection matrices
- Reduced computation - Fewer matrix operations in the K/V projections during forward and backward passes
- Faster training - Improved throughput during model training
- Faster inference - Quicker response times in production environments
GQA achieves these efficiency gains with only a small drop in quality compared to full MHA.
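The memory benefit is easy to quantify with back-of-the-envelope arithmetic. The sketch below estimates the KV cache size per token for the three schemes. The 28 layers come from the model printout in this post. The head counts (16 query heads, 8 K/V heads), `head_dim = 128` and 2-byte values are assumptions for this illustration, and the function name is made up.

```python
def kv_cache_bytes_per_token(n_layers, kv_heads, head_dim, bytes_per_value=2):
    # Two cached tensors (K and V) per layer,
    # each holding kv_heads * head_dim values for every token.
    return 2 * n_layers * kv_heads * head_dim * bytes_per_value

n_layers, head_dim = 28, 128  # head_dim is an assumption
schemes = (("MHA", 16), ("GQA", 8), ("MQA", 1))  # assumed K/V head counts

for name, kv_heads in schemes:
    size = kv_cache_bytes_per_token(n_layers, kv_heads, head_dim)
    print(f"{name}: {size:,} bytes per token")
# MHA: 229,376 bytes per token
# GQA: 114,688 bytes per token
# MQA: 14,336 bytes per token
```

Under these assumptions GQA halves the cache relative to MHA, and the saving grows linearly with context length and batch size.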
class GQA(nn.Module):
    def __init__(self):
        super().__init__()
        self.w_q = nn.Linear(config.embed_dim, config.n_heads * config.head_dim, bias=False)
        self.w_k = nn.Linear(config.embed_dim, config.kv_heads * config.head_dim, bias=False)
        self.w_v = nn.Linear(config.embed_dim, config.kv_heads * config.head_dim, bias=False)
        self.out_proj = nn.Linear(config.n_heads * config.head_dim, config.embed_dim, bias=False)
        # Q and K normalization layers
        self.q_norm = RMSNORM(config.n_heads * config.head_dim)
        self.k_norm = RMSNORM(config.kv_heads * config.head_dim)

    def forward(self, x, sin, cos):
        b, seq_len, _ = x.size()
        # Linear projections
        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_v(x)
        # Apply normalization to Q and K
        Q = self.q_norm(Q)
        K = self.k_norm(K)
        # Reshape and transpose to (batch, heads, seq_len, head_dim)
        Q = Q.view(b, seq_len, config.n_heads, config.head_dim).transpose(1, 2)
        K = K.view(b, seq_len, config.kv_heads, config.head_dim).transpose(1, 2)
        V = V.view(b, seq_len, config.kv_heads, config.head_dim).transpose(1, 2)
        Q = rotate(Q, sin, cos)
        K = rotate(K, sin, cos)
        # Repeat K and V to match Q heads (for grouped attention)
        K = K.repeat_interleave(config.n_heads // config.kv_heads, dim=1)
        V = V.repeat_interleave(config.n_heads // config.kv_heads, dim=1)
        # Compute attention scores
        attention_scores = Q @ K.transpose(-2, -1) / (config.head_dim ** 0.5)
        # Create causal mask
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        mask = mask.unsqueeze(0).unsqueeze(0)  # Add batch and head dimensions
        # Apply mask (set upper triangular to -inf)
        masked_attention_scores = attention_scores.masked_fill(mask, float('-inf'))
        # Apply softmax
        # softmax(-inf) = 0
        attention_weights = torch.softmax(masked_attention_scores, dim=-1)
        # Apply attention to values
        attention_output = attention_weights @ V
        # Reshape back to (batch, seq_len, n_heads * head_dim)
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            b, seq_len, config.n_heads * config.head_dim
        )
        # Final projection
        output = self.out_proj(attention_output)
        return output
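The `repeat_interleave` step is what ties each query head to its shared K/V head. The sketch below reproduces that mapping in plain Python so the grouping is visible. The local `repeat_interleave` helper is a stand-in written for this illustration (it mimics repeating each element consecutively along the head axis), and the head counts of 16 and 8 are assumptions.

```python
def repeat_interleave(items, repeats):
    """Repeat each element `repeats` times in a row: [a, b] -> [a, a, b, b]."""
    return [item for item in items for _ in range(repeats)]

n_heads, kv_heads = 16, 8            # assumed head counts
group_size = n_heads // kv_heads     # query heads per K/V head

# Index of the K/V head that each query head attends with.
kv_head_for_query = repeat_interleave(list(range(kv_heads)), group_size)
print(kv_head_for_query)
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
```

Query heads 0 and 1 share K/V head 0, heads 2 and 3 share K/V head 1, and so on. In other words, query head i uses K/V head `i // group_size`.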
Finally: The Model
QWEN(
  (embeddings): Embedding(151936, 1024)
  (blocks): ModuleList(
    (0-27): 28 x Block(
      (rms_1): RMSNORM()
      (mgqa): GQA(
        (w_q): Linear(in_features=1024, out_features=2048, bias=False)
        (w_k): Linear(in_features=1024, out_features=1024, bias=False)
        (w_v): Linear(in_features=1024, out_features=1024, bias=False)
        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (q_norm): RMSNORM()
        (k_norm): RMSNORM()
      )
      (rms_2): RMSNORM()
      (ffn): FFN(
        (layer_1): Linear(in_features=1024, out_features=3072, bias=False)
        (layer_2): Linear(in_features=1024, out_features=3072, bias=False)
        (layer_3): Linear(in_features=3072, out_features=1024, bias=False)
      )
    )
  )
  (rms_3): RMSNORM()
  (lm_head): Linear(in_features=1024, out_features=151936, bias=False)
)
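As a sanity check, the shapes in this printout can be turned into a rough parameter count. The sketch below counts only the `Embedding` and `Linear` weights shown above and ignores the comparatively tiny RMSNORM parameters. Whether `lm_head` shares its weights with the embedding cannot be read off the printout, so both cases are computed. The tied case is an assumption.

```python
# Shapes read from the printout above.
vocab, d_model, n_layers = 151936, 1024, 28
q_out, kv_out, ffn_hidden = 2048, 1024, 3072

embedding = vocab * d_model
# w_q + w_k + w_v + out_proj
attention = d_model * q_out + 2 * d_model * kv_out + q_out * d_model
# layer_1 + layer_2 + layer_3
ffn = 2 * d_model * ffn_hidden + ffn_hidden * d_model
blocks = n_layers * (attention + ffn)
lm_head = d_model * vocab

untied = embedding + blocks + lm_head   # separate output projection
tied = embedding + blocks               # assumes lm_head reuses the embedding weights

print(f"untied: {untied:,}")  # untied: 751,566,848
print(f"tied:   {tied:,}")    # tied:   595,984,384
```

With shared input and output embeddings the count lands at roughly 0.6 billion, which matches the model's name. The vocabulary matrix alone accounts for about 156 million of those parameters.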