LLM Architecture

Qwen3-0.6B From Scratch

A walkthrough of the Qwen3-0.6B architecture, exploring RoPE, RMSNorm, and Grouped Query Attention (GQA).

2025-08-28 · PyTorch · LLM


Hi, I'm Sidhant.

This is my first attempt at writing a technical blog. While the main goal here is to strengthen my own understanding, I hope you'll find it helpful too.

Over the past couple of months, I've been working with RAG and AI agents. Along the way, I realised my PyTorch skills had gotten a bit rusty. Recently, I came across this notebook by Sebastian Raschka, which inspired me to finally step into blogging—while also brushing up on some core concepts.

All the code for this blog is available at https://github.com/sidmanale643/Qwen-0.6B-from-scratch

Background

The Qwen family of models is developed by Alibaba Cloud, part of the Chinese tech giant Alibaba Group. Qwen models are at the frontier of open-source LLMs and give proprietary models from OpenAI and Anthropic a serious run for their money.

Today, we'll focus on the 0.6-billion-parameter dense model, with particular attention to its use of RoPE, RMSNorm and GQA. Thanks to its compact size and architectural choices, this model is ideal for educational purposes: it is easy to implement and small enough to train from scratch on consumer hardware. Despite being one of the smallest open-source models available, it remains remarkably capable.

Key design features include:

  • Grouped Query Attention (GQA)
  • Rotary Positional Embeddings (RoPE)
  • SiLU activations (inside a SwiGLU feed-forward block)
  • Mixture-of-Experts (MoE) variant as an alternative to the standard feed-forward block

In this post, we'll dive into the dense variant and save the MoE architecture for another time.

Config

\[ \begin{array}{|l|l|} \hline \text{Parameter} & \text{Value} \\ \hline \text{vocabulary\_size} & 151936 \\ \hline \text{embedding\_dimension} & 1024 \\ \hline \text{hidden\_dim} & 3072 \\ \hline \text{n\_layers} & 28 \\ \hline \text{n\_heads} & 16 \\ \hline \text{kv\_heads} & 8 \\ \hline \text{head\_dim} & 128 \\ \hline \text{max\_ctx\_length} & 4096 \\ \hline \text{rope\_base} & 1.0 \times 10^6 \\ \hline \text{eps} & 1 \times 10^{-6} \\ \hline \text{dtype} & \text{torch.float32} \\ \hline \end{array} \]
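For reference, here is the same table written as a Python dataclass. This is a minimal sketch: the field names are my own assumption and may not match the repository exactly, but the values are copied from the table above. It also makes two GQA-related sizes explicit.

```python
from dataclasses import dataclass

import torch


@dataclass
class Config:
    # Values copied from the config table; field names are illustrative
    vocab_size: int = 151936
    embed_dim: int = 1024
    hidden_dim: int = 3072
    n_layers: int = 28
    n_heads: int = 16
    kv_heads: int = 8
    head_dim: int = 128
    max_ctx_length: int = 4096
    rope_base: float = 1.0e6
    eps: float = 1e-6
    dtype: torch.dtype = torch.float32


config = Config()

# Total width of all query heads: 16 * 128
print(config.n_heads * config.head_dim)   # 2048
# Number of query heads that share each K/V head (the GQA group size)
print(config.n_heads // config.kv_heads)  # 2
```

Note that the query projection (16 × 128 = 2048) is wider than the 1024-dimensional embedding, which is why head_dim is listed separately instead of being derived as embedding_dimension / n_heads.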

Positional Encodings

Self-attention on its own has no notion of order: it is permutation-equivariant. If the order of the input tokens changes, the outputs are reordered in exactly the same way, but their values stay the same. Positional information therefore has to be injected explicitly.
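A quick way to see this: the sketch below runs plain single-head attention on a sequence and on a shuffled copy of it. `self_attention` is a hypothetical helper written for illustration (no positional information, no causal mask), not code from the repo.

```python
import torch

torch.manual_seed(0)


def self_attention(x, w_q, w_k, w_v):
    # Single-head self-attention with no positional information and no causal mask
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v


d = 8
x = torch.randn(5, d)                               # 5 tokens with 8-dim embeddings
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))

perm = torch.tensor([2, 0, 4, 1, 3])                # shuffle the token order
out = self_attention(x, w_q, w_k, w_v)
out_shuffled = self_attention(x[perm], w_q, w_k, w_v)

# Same output rows, just shuffled in the same way: order is invisible to the model
print(torch.allclose(out[perm], out_shuffled, atol=1e-5))  # True
```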

Absolute Positional Encodings (APE) give each token a signal tied to its absolute index in the sequence. However, the relative distance between tokens usually carries more meaningful information than their absolute positions, and APE does not capture it directly.

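Before Rotary Positional Embeddings (RoPE), positional information was typically added directly to the token embeddings (fixed sinusoidal encodings in the original Transformer, learned position embeddings in GPT-2). Because position and meaning then share the same vector, this approach risks polluting the semantic representations of tokens.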
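For contrast, here is a minimal sketch of the additive approach, using the sinusoidal formula from the original Transformer. `sinusoidal_positions` is a hypothetical helper written for illustration; the important part is the second-to-last statement, where the position signal is summed into the token embedding itself.

```python
import torch


def sinusoidal_positions(seq_len, dim, base=10_000.0):
    # Additive encoding from the original Transformer:
    # PE[pos, 2j] = sin(pos / base^(2j/dim)), PE[pos, 2j+1] = cos(pos / base^(2j/dim))
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)             # (seq_len, 1)
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim // 2,)
    angles = pos * inv_freq                                                   # (seq_len, dim // 2)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe


token_embeddings = torch.randn(5, 8)    # toy embeddings: 5 tokens, 8 dims
x = token_embeddings + sinusoidal_positions(5, 8)

# Position and meaning now live in the same vector
print(x.shape)  # torch.Size([5, 8])
```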

RoPE

RoPE was first introduced in the RoFormer paper. In RoPE, the angle between a query and a key encodes their relative distance within the sequence. Unlike APE, nothing is added to the embeddings: only the query and key vectors are rotated, which preserves the semantic content of the embeddings while still capturing relative positions.

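Instead of encoding absolute position by adding a vector drawn from sinusoidal functions of slowly decreasing frequencies, RoPE splits each query and key vector into pairs of dimensions and multiplies each pair by a rotation matrix. Since the dot product of a query rotated by mθ and a key rotated by nθ depends only on (m - n)θ, the attention score depends on the relative position of the two tokens.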
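The relative-position property can be checked numerically in two dimensions. In the sketch below, the helper `rotate_2d`, the frequency 0.5 and the vectors are arbitrary choices for illustration. It rotates a query and a key by their positions and compares two query/key pairs that are the same distance apart.

```python
import math

import torch


def rotate_2d(vec, angle):
    # Standard counter-clockwise 2D rotation matrix
    c, s = math.cos(angle), math.sin(angle)
    rot = torch.tensor([[c, -s], [s, c]])
    return rot @ vec


theta = 0.5                         # illustrative frequency for a single pair
q = torch.tensor([1.0, 2.0])
k = torch.tensor([0.5, -1.0])

# Query at position 3, key at position 1 (distance 2)
score_a = rotate_2d(q, 3 * theta) @ rotate_2d(k, 1 * theta)
# Query at position 10, key at position 8 (same distance 2)
score_b = rotate_2d(q, 10 * theta) @ rotate_2d(k, 8 * theta)

print(torch.allclose(score_a, score_b, atol=1e-5))  # True
```

Shifting both positions by the same amount leaves the score unchanged, which is exactly the relative-position behaviour RoPE is designed for.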

Application

Suppose we have a query matrix Q of shape [5, 8] (sequence_length, head_dim) that we want to rotate:

\[ Q = \begin{bmatrix} 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 \\ 8 & 9 & 10 & 11 & 12 & 13 & 14 & 15 \\ 16 & 17 & 18 & 19 & 20 & 21 & 22 & 23 \\ 24 & 25 & 26 & 27 & 28 & 29 & 30 & 31 \\ 32 & 33 & 34 & 35 & 36 & 37 & 38 & 39 \end{bmatrix} \]
NOTE: Real embeddings are arbitrary floating-point values; I use sequential integers here only to make the example easier to follow.

To "rotate" a pair of embedding values (x_1, x_2) means to multiply it by a 2D rotation matrix M:

\[ \begin{bmatrix} x_1' \\ x_2' \end{bmatrix} = M \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} \]

The rotation matrix turns a vector by an angle θ without changing its length, and in RoPE that angle depends on the position of the token in the sequence:

\[ M = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix} \]

where the angle depends on the token's position m in the sequence and on the index i of the dimension pair being rotated:

\[ \theta = m \cdot \theta_i, \qquad \theta_i = \text{base}^{-i/d}, \qquad i = 0, 2, 4, \dots, d-2 \]

Here d is the head dimension and base is rope_base from the config (1,000,000 for Qwen3-0.6B). In the toy example below, I use base = 10,000 (the value used in the RoFormer paper) and d = 8.
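To connect the matrix form to code, the sketch below rotates one row of the toy Q at an assumed position m = 3 in two ways. The first applies an explicit 2x2 matrix to each pair. The second uses the vectorised "rotate half" expression x * cos + (-x2, x1) * sin. It assumes the split-halves pairing (dimension j paired with j + head_dim // 2), which is the layout the `rotate` function in this post uses. The RoFormer paper pairs adjacent dimensions instead, and both layouts are valid.

```python
import math

import torch

head_dim, base, m = 8, 10_000.0, 3                 # toy sizes; m is the token position
x = torch.arange(head_dim, dtype=torch.float32)    # one row of the toy Q
half = head_dim // 2

# Angles for each pair i = 0, 2, 4, 6
theta = m * base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)

# (a) Explicit 2x2 rotation of each pair (x[j], x[j + half])
explicit = torch.empty_like(x)
for j in range(half):
    c, s = math.cos(theta[j].item()), math.sin(theta[j].item())
    explicit[j] = c * x[j] - s * x[j + half]
    explicit[j + half] = s * x[j] + c * x[j + half]

# (b) Vectorised "rotate half" form: x * cos + (-x2, x1) * sin
cos = torch.cat((theta.cos(), theta.cos()))
sin = torch.cat((theta.sin(), theta.sin()))
vectorised = x * cos + torch.cat((-x[half:], x[:half])) * sin

print(torch.allclose(explicit, vectorised, atol=1e-5))  # True
```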

Step 1: Split each head's dimensions into two halves of size head_dim // 2 (dimension j is paired with dimension j + head_dim // 2)

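```python
import torch

def rotate(x, sin, cos):
    # x: (batch_size, num_heads, seq_len, head_dim)
    # sin, cos: (max_seq_len, head_dim), with the angles repeated for both halves
    batch_size, num_heads, seq_len, head_dim = x.shape

    # Step 1: split the head dimension into two halves;
    # dimension j is paired with dimension j + head_dim // 2
    x1 = x[..., :head_dim // 2]
    x2 = x[..., head_dim // 2:]

    # "Rotate half": (-x2, x1) supplies the -sin / +sin terms of the rotation matrix
    rotated_half = torch.cat((-x2, x1), dim=-1)

    # Keep the first seq_len positions and add batch/head dims for broadcasting
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # x' = x * cos(theta) + rotate_half(x) * sin(theta)
    x_rotated = (x * cos) + (rotated_half * sin)

    return x_rotated
```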
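Applied to the toy Q from above, the split in Step 1 looks like this. Q is reshaped to (batch, heads, seq_len, head_dim) because that is the shape `rotate` expects; batch = heads = 1 is an assumption made just for this example.

```python
import torch

# Toy Q from above, reshaped to (batch, heads, seq_len, head_dim)
Q = torch.arange(40, dtype=torch.float32).view(1, 1, 5, 8)
head_dim = Q.shape[-1]

x1 = Q[..., :head_dim // 2]   # first half of every row:  dims 0..3
x2 = Q[..., head_dim // 2:]   # second half of every row: dims 4..7

# Dimension j is paired with dimension j + head_dim // 2
print(x1[0, 0, 0].tolist(), x2[0, 0, 0].tolist())  # [0.0, 1.0, 2.0, 3.0] [4.0, 5.0, 6.0, 7.0]

# The "rotate half" term used inside rotate(): (-x2, x1)
rotated_half = torch.cat((-x2, x1), dim=-1)
print(rotated_half[0, 0, 0].tolist())  # [-4.0, -5.0, -6.0, -7.0, 0.0, 1.0, 2.0, 3.0]
```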

Step 2: Calculate Frequencies

For each pair index i (0, 2, 4, 6 when d = 8), the frequency is base raised to the power -i/d. With base = 10,000:

\[ \begin{aligned} i = 0 &: \quad \theta_0 = 10000^{-0/8} = 1.0 \\ i = 2 &: \quad \theta_2 = 10000^{-2/8} = 0.1 \\ i = 4 &: \quad \theta_4 = 10000^{-4/8} = 0.01 \\ i = 6 &: \quad \theta_6 = 10000^{-6/8} = 0.001 \end{aligned} \]

The first pair rotates quickly and the later pairs increasingly slowly, so different pairs encode position at different scales.

frequencies = [1.0000, 0.1000, 0.0100, 0.0010]
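The same frequencies in PyTorch, as a minimal sketch with the toy values base = 10,000 and head_dim = 8. For Qwen3-0.6B you would plug in rope_base = 1,000,000 and head_dim = 128 from the config.

```python
import torch

base, head_dim = 10_000.0, 8
i = torch.arange(0, head_dim, 2, dtype=torch.float32)   # i = 0, 2, 4, 6
frequencies = base ** (-i / head_dim)

print([round(f, 4) for f in frequencies.tolist()])  # [1.0, 0.1, 0.01, 0.001]
```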

Step 3: Calculate Angles

Formula:

\[ \text{angle}_{m,i} = m \cdot \theta_i \]

where m is the token's position (0 to seq_len - 1) and θ_i is the frequency from Step 2.
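In practice, the whole angle table is computed at once with an outer product and then turned into the sin and cos tensors that `rotate` consumes. The sketch below assumes the split-halves layout, so the angles are repeated for both halves to give shape (seq_len, head_dim). To stay self-contained, it applies the rotation to the toy Q inline instead of calling the post's function.

```python
import torch

seq_len, head_dim, base = 5, 8, 10_000.0

# Step 2: one frequency per dimension pair (i = 0, 2, 4, 6)
frequencies = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)

# Step 3: angles[m, k] = m * frequencies[k] (outer product over positions)
positions = torch.arange(seq_len, dtype=torch.float32)
angles = torch.outer(positions, frequencies)             # (5, 4)

# Repeat the angles for both halves to get (seq_len, head_dim)
angles = torch.cat((angles, angles), dim=-1)             # (5, 8)
sin, cos = angles.sin(), angles.cos()

# Rotate the toy Q (batch = 1, heads = 1) with the rotate-half formula
Q = torch.arange(40, dtype=torch.float32).view(1, 1, seq_len, head_dim)
half = head_dim // 2
Q_rot = Q * cos + torch.cat((-Q[..., half:], Q[..., :half]), dim=-1) * sin

print(torch.equal(Q_rot[0, 0, 0], Q[0, 0, 0]))  # True: position 0 is not rotated
```

Row 0 comes out unchanged, and every row keeps its length, since rotations only change direction.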

For each position, multiply by each frequency:

#### Position 0:

\[ 0 \times 1.0 = 0, \qquad 0 \times 0.1 = 0, \qquad 0 \times 0.01 = 0, \qquad 0 \times 0.001 = 0 \]

angles = [0.0000, 0.0000, 0.0000, 0.0000]

Every angle is 0 at position 0, so cos = 1 and sin = 0, and the first token is left unrotated.

#### Position 1:

\[ 1 \times 1.0 = 1.0, \qquad 1 \times 0.1 = 0.1, \qquad 1 \times 0.01 = 0.01, \qquad 1 \times 0.001 = 0.001 \]

angles = [1.0000, 0.1000, 0.0100, 0.0010]

#### Position 2:

\[ \theta_{2,0} = 2 \times 1.0000 = 2.0000 \]

\[ \theta_{2,1} = 2 \times 0.1000 = 0.2000 \]

\[ \theta_{2,2} = 2 \times 0.0100 = 0.0200 \]

\[ \theta_{2,3} = 2 \times 0.0010 = 0.0020 \]

#### Position 3:

\[ \theta_{3,0} = 3 \times 1.0000 = 3.0000 \]

\[ \theta_{3,1} = 3 \times 0.1000 = 0.3000 \]

\[ \theta_{3,2} = 3 \times 0.0100 = 0.0300 \]

\[ \theta_{3,3} = 3 \times 0.0010 = 0.0030 \]

#### Position 4:

\[ \theta_{4,0} = 4 \times 1.0000 = 4.0000 \]

\[ \theta_{4,1} = 4 \times 0.1000 = 0.4000 \]

\[ \theta_{4,2} = 4 \times 0.0100 = 0.0400 \]

\[ \theta_{4,3} = 4 \times 0.0010 = 0.0040 \]
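The per-position angles above can be reproduced with a few lines of plain Python. This minimal sketch assumes the toy setting implied by the frequencies in this walkthrough: head_dim = 8 (so 4 dimension pairs) and base = 10,000. It does not use Qwen3's real head_dim = 128 and rope_base = 1e6. The helper names are illustrative.

```python
import math

# Toy RoPE setting assumed from the walkthrough's frequencies:
# head_dim = 8 -> 4 dimension pairs, base = 10_000, positions 0..4
HEAD_DIM = 8
BASE = 10_000.0
NUM_POSITIONS = 5

def rope_frequencies(head_dim: int, base: float) -> list[float]:
    # freq_i = 1 / base^(2i / head_dim) for i = 0 .. head_dim/2 - 1
    return [1.0 / base ** (2 * i / head_dim) for i in range(head_dim // 2)]

def rope_angles(num_positions: int, freqs: list[float]) -> list[list[float]]:
    # theta_{m,i} = m * freq_i: one row per position, one column per pair
    return [[m * f for f in freqs] for m in range(num_positions)]

freqs = rope_frequencies(HEAD_DIM, BASE)
angles = rope_angles(NUM_POSITIONS, freqs)

for m, row in enumerate(angles):
    print(f"position {m}: " + ", ".join(f"{a:.4f}" for a in row))
```

Each row is simply the frequency vector scaled by the position index. Because of that, the angle grows linearly with position, and lower pairs rotate much faster than higher ones.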

Step 4: Create Rotation Matrices

#### Angles for position 1:

\[ \theta_1 = [\,1.0000,\ 0.1000,\ 0.0100,\ 0.0010\,] \]

Since we have 4 pairs of dimensions, we need 4 rotation matrices:

#### Rotation Matrix Template:

\[ R(\theta) = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix} \]

#### For Position #0:

\[ R(\theta_{0,i}) = R(0) = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}, \quad i = 0, 1, 2, 3 \]

#### For Position #1:

\[ R(1.0) = \begin{bmatrix} 0.5403 & -0.8415 \\ 0.8415 & 0.5403 \end{bmatrix}, \quad R(0.1) = \begin{bmatrix} 0.9950 & -0.0998 \\ 0.0998 & 0.9950 \end{bmatrix} \]

\[ R(0.01) = \begin{bmatrix} 1.0000 & -0.0100 \\ 0.0100 & 1.0000 \end{bmatrix}, \quad R(0.001) = \begin{bmatrix} 1.0000 & -0.0010 \\ 0.0010 & 1.0000 \end{bmatrix} \]

#### And so on for Positions #2, #3 and #4
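Here is a small plain-Python sketch of the 2×2 rotation used in Step 4. It prints the cos/sin entries for position 1, then rotates a hypothetical pair (1, 0). The helper names rotation_matrix and rotate_pair are illustrative only.

```python
import math

def rotation_matrix(theta: float) -> list[list[float]]:
    # 2x2 counter-clockwise rotation by angle theta (radians)
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]

def rotate_pair(pair: tuple[float, float], theta: float) -> tuple[float, float]:
    # Multiply the rotation matrix by the column vector (x, y)
    (r00, r01), (r10, r11) = rotation_matrix(theta)
    x, y = pair
    return (r00 * x + r01 * y, r10 * x + r11 * y)

# Angles for position 1 with the toy frequencies [1.0, 0.1, 0.01, 0.001]
position_1_angles = [1.0, 0.1, 0.01, 0.001]
for i, theta in enumerate(position_1_angles):
    print(f"pair {i}: cos={math.cos(theta):.4f}, sin={math.sin(theta):.4f}")

# Rotating the hypothetical pair (1, 0) by 1 rad lands on (cos 1, sin 1)
x_rot, y_rot = rotate_pair((1.0, 0.0), 1.0)
```

Rotations change direction but never length, so the query vector's magnitude is untouched and only its orientation encodes position.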

Step 5: Rotate the Q matrix by multiplying it with the rotation matrices

#### For Position #0:

\[ q_0^{\text{rot}} = \operatorname{diag}\big(R(0), R(0), R(0), R(0)\big)\, q_0 = I_8\, q_0 = q_0 \]

#### For Position #1:

\[ q_1^{\text{rot}} = \begin{bmatrix} R(1.0) & & & \\ & R(0.1) & & \\ & & R(0.01) & \\ & & & R(0.001) \end{bmatrix} q_1 \]

And so on for Positions #2, #3 and #4.

Rotated Q

#### Q Original:

\[ Q = \begin{bmatrix} q_0^\top \\ q_1^\top \\ q_2^\top \\ q_3^\top \\ q_4^\top \end{bmatrix} \in \mathbb{R}^{5 \times 8} \]

#### Q after RoPE:

\[ Q^{\text{rot}} = \begin{bmatrix} (q_0^{\text{rot}})^\top \\ (q_1^{\text{rot}})^\top \\ (q_2^{\text{rot}})^\top \\ (q_3^{\text{rot}})^\top \\ (q_4^{\text{rot}})^\top \end{bmatrix} \]

Similarly, the Keys matrix is also rotated, and both are then used in the attention mechanism.

NOTE: RoPE is applied after Q and K are split into individual heads.
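Rotating both Q and K is what makes RoPE encode relative position. The score between a query at position m and a key at position n depends only on the offset n - m. The sketch below checks this numerically with random (hypothetical) 8-dimensional vectors.

It pairs adjacent dimensions (2i, 2i + 1) as in the RoFormer paper. The rotate function in this post instead pairs dimension i with i + head_dim/2. Both pairings share the same relative-position property.

```python
import math
import random

def rope_rotate(vec: list[float], position: int, base: float = 10_000.0) -> list[float]:
    # Rotate each adjacent pair (vec[2i], vec[2i+1]) by position * freq_i
    d = len(vec)
    out = []
    for i in range(d // 2):
        theta = position / base ** (2 * i / d)
        x, y = vec[2 * i], vec[2 * i + 1]
        out += [x * math.cos(theta) - y * math.sin(theta),
                x * math.sin(theta) + y * math.cos(theta)]
    return out

def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

random.seed(0)
q = [random.uniform(-1, 1) for _ in range(8)]  # hypothetical query vector
k = [random.uniform(-1, 1) for _ in range(8)]  # hypothetical key vector

# Same offset (n - m = 2) at two different absolute positions
score_a = dot(rope_rotate(q, 1), rope_rotate(k, 3))
score_b = dot(rope_rotate(q, 10), rope_rotate(k, 12))
print(score_a, score_b)
```

The two scores agree up to floating-point error. Attention therefore sees how far apart two tokens are, not where they sit in absolute terms.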

We precompute the RoPE parameters, i.e. the sine and cosine values used to rotate the embeddings. Instead of recomputing them for every sequence, we compute the angles once, up to the maximum context length (4096 in our case), and store them as PyTorch buffers.

PyTorch buffers are persistent tensors that behave like constants: they are not trainable parameters but are saved and moved with the model.

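import torch

def calculate_sin_cos():
    # One frequency per dimension pair: freq_i = 1 / base^(2i / head_dim)
    # (config.base is the RoPE base, listed as rope_base = 1e6 in the config table)
    i = torch.arange(0, config.head_dim, 2).float()
    freqs = 1.0 / (config.base ** (i / config.head_dim))
    positions = torch.arange(config.max_ctx_length).float()
    
    # angles[m, i] = m * freq_i, shape (max_ctx_length, head_dim // 2)
    angles = positions[:, None] * freqs[None, :]
    # Duplicate angles to match the full head_dim: dimensions i and i + head_dim // 2
    # form one pair and share the same angle (see rotate below)
    angles = torch.cat([angles, angles], dim=1)
    
    return torch.sin(angles), torch.cos(angles)


def rotate(x, sin, cos):
    # x: (batch, num_heads, seq_len, head_dim)
    _, _, seq_len, head_dim = x.shape
    
    # Split into halves; pair j is (x1[..., j], x2[..., j])
    x1 = x[..., :head_dim // 2]
    x2 = x[..., head_dim // 2:]
    
    # "Rotate half": (x1, x2) -> (-x2, x1), so that
    # x * cos + rotated_half * sin = (x1*cos - x2*sin, x2*cos + x1*sin)
    rotated_half = torch.cat((-x2, x1), dim=-1)
    
    # Keep the first seq_len positions and broadcast over batch and heads
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    
    x_rotated = (x * cos) + (rotated_half * sin)
    
    return x_rotated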
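Here is a sketch of how the precomputed tables can be stored as buffers, under two assumptions:

  • calculate_sin_cos takes the config values as arguments instead of reading a global config.
  • RoPECache is a hypothetical holder module; in the repository the buffers may live elsewhere.

Buffers registered with register_buffer appear in state_dict() and follow .to(device). They are not returned by parameters(), so the optimizer never updates them.

```python
import torch
import torch.nn as nn

def calculate_sin_cos(head_dim: int, base: float, max_ctx_length: int):
    # Same recipe as above, with config values passed in explicitly
    i = torch.arange(0, head_dim, 2).float()
    freqs = 1.0 / (base ** (i / head_dim))
    positions = torch.arange(max_ctx_length).float()
    angles = positions[:, None] * freqs[None, :]
    angles = torch.cat([angles, angles], dim=1)
    return torch.sin(angles), torch.cos(angles)

def rotate(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, seq_len, head_dim); sin/cos: (max_ctx_length, head_dim)
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
    rotated_half = torch.cat((-x2, x1), dim=-1)
    # sin/cos slices of shape (seq_len, head_dim) broadcast over batch and heads
    return x * cos[:seq_len] + rotated_half * sin[:seq_len]

class RoPECache(nn.Module):
    """Hypothetical module that holds the RoPE tables as (non-trainable) buffers."""
    def __init__(self, head_dim: int = 128, base: float = 1e6, max_ctx_length: int = 4096):
        super().__init__()
        sin, cos = calculate_sin_cos(head_dim, base, max_ctx_length)
        self.register_buffer("sin", sin)
        self.register_buffer("cos", cos)

cache = RoPECache()
x = torch.randn(2, 16, 10, 128)  # (batch, heads, seq_len, head_dim)
x_rot = rotate(x, cache.sin, cache.cos)
```

Two quick properties follow from the math. Position 0 is left unchanged (cos = 1, sin = 0), and every pair is rotated rather than scaled, so vector norms are preserved.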

RMS Norm

RMSNorm stands for Root Mean Square Normalization. It is a normalization technique designed to stabilize and accelerate training by normalizing inputs based on their root mean square (RMS) values.

Normalization helps us avoid scenarios like:

  • Exploding Gradients: When gradient values grow exponentially during backpropagation, causing training to become unstable
  • Vanishing Gradients: When gradient values become too small, preventing effective learning in deeper layers
\[ \text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} + \beta \]

#### Where:

  • Gamma (γ, scale parameter): a learnable multiplicative parameter that scales the normalized values
  • Beta (β, shift parameter): a learnable additive parameter that shifts the normalized values
  • Epsilon (ε, stability value): a small constant (typically 1e-6) that prevents division by zero when the RMS approaches zero
  • d: the size of the last dimension being normalized (e.g. the embedding dimension)
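import torch
import torch.nn as nn

class RMSNORM(nn.Module):
    def __init__(self, embed_dim, eps=1e-6):
        super().__init__()
        # Learnable scale (gamma) and shift (beta), one value per feature.
        # Qwen3's own RMSNorm only has the scale; beta is an optional extra here.
        self.gamma = nn.Parameter(torch.ones(embed_dim))
        self.beta = nn.Parameter(torch.zeros(embed_dim))
        self.eps = eps
    
    def forward(self, x):
        # Mean of squares over the last dimension (no mean subtraction, unlike LayerNorm)
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_sq + self.eps)
        
        return self.gamma * (x / rms) + self.beta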
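A quick plain-Python check of the formula shows two properties. With γ = 1 and β = 0 the output has an RMS of (almost exactly) 1, and scaling the input leaves the output essentially unchanged. Unlike LayerNorm, no mean is subtracted. The function name rms_norm is illustrative.

```python
import math

def rms_norm(x: list[float], gamma: list[float], beta: list[float],
             eps: float = 1e-6) -> list[float]:
    # Mean of squares (no mean subtraction, unlike LayerNorm)
    mean_sq = sum(v * v for v in x) / len(x)
    rms = math.sqrt(mean_sq + eps)
    return [g * (v / rms) + b for v, g, b in zip(x, gamma, beta)]

x = [1.0, 2.0, 3.0, 4.0]
ones, zeros = [1.0] * 4, [0.0] * 4

y = rms_norm(x, ones, zeros)
y_rms = math.sqrt(sum(v * v for v in y) / len(y))
print(y, y_rms)
```

The scale invariance is what keeps activations in a stable range from layer to layer. An all-zero input maps to zeros thanks to ε instead of producing a division by zero.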

SiLU Activation Function

SiLU stands for Sigmoid Linear Unit; its parameterised form is known as Swish. The function is non-monotonic. For negative inputs it first decreases, reaching a minimum of about -0.28 near x ≈ -1.28, and then increases again.

\[ \text{SiLU}(x) = x \cdot \sigma(x) \]

Where sigmoid is:

\[ \sigma(x) = \frac{1}{1 + e^{-x}} \]

therefore:

\[ \text{SiLU}(x) = \frac{x}{1 + e^{-x}} \]

Note: SiLU is the same as Swish when \( \beta = 1 \) in

\[ \text{Swish}(x) = x \cdot \sigma(\beta x) \]
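In the model, SiLU is used inside the feed-forward block. The sketch below does two things:

  • It checks that x · σ(x) matches PyTorch's built-in torch.nn.functional.silu.
  • It shows a SwiGLU-style FFN with the same layer names and shapes as the model printout at the end of this post.

Which of layer_1 and layer_2 acts as the gate is an assumption; the repository's FFN may assign them differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# 1) SiLU(x) = x * sigmoid(x) agrees with PyTorch's built-in implementation
x = torch.linspace(-4.0, 4.0, steps=9)
silu_manual = x * torch.sigmoid(x)
silu_builtin = F.silu(x)

# 2) SwiGLU-style feed-forward block using the printout's layer names and shapes
class FFN(nn.Module):
    def __init__(self, embed_dim: int = 1024, hidden_dim: int = 3072):
        super().__init__()
        self.layer_1 = nn.Linear(embed_dim, hidden_dim, bias=False)  # gate branch (assumed)
        self.layer_2 = nn.Linear(embed_dim, hidden_dim, bias=False)  # linear "up" branch (assumed)
        self.layer_3 = nn.Linear(hidden_dim, embed_dim, bias=False)  # projection back to embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Element-wise gating: SiLU(gate) * up, then the down-projection
        return self.layer_3(F.silu(self.layer_1(x)) * self.layer_2(x))

ffn = FFN()
out = ffn(torch.randn(2, 5, 1024))  # (batch, seq_len, embed_dim)
```

The gate lets the network scale each hidden unit smoothly. The SiLU dip below zero (e.g. SiLU(-1) < SiLU(-2)) is the non-monotonic behaviour described above.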

SiLU vs ReLU

\[ \text{ReLU}(x) = \max(0, x) \]

The key difference between ReLU and SiLU lies in how they handle negative values. ReLU sets every negative input to zero, and its gradient there is also zero. This can lead to the "dying ReLU" problem, where certain neurons stop activating altogether during training. SiLU, in contrast, lets small negative values pass through and keeps a non-zero gradient for negative inputs, so gradients continue to flow. This reduces the risk of inactive neurons and can improve training stability.
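The gradient difference can be made concrete with the analytic derivatives, in plain Python. For negative inputs ReLU's gradient is exactly 0. SiLU's derivative σ(x)(1 + x(1 - σ(x))) stays non-zero, and it even turns slightly negative below x ≈ -1.28. A central finite difference serves as a sanity check. The helper names are illustrative.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def silu(x: float) -> float:
    return x * sigmoid(x)

def relu_grad(x: float) -> float:
    # Derivative of max(0, x); we take 0 at x == 0
    return 1.0 if x > 0 else 0.0

def silu_grad(x: float) -> float:
    # d/dx [x * sigma(x)] = sigma(x) * (1 + x * (1 - sigma(x)))
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

for x in (-3.0, -1.0, -0.5, 0.5, 2.0):
    print(f"x = {x:+.1f}   ReLU' = {relu_grad(x):.3f}   SiLU' = {silu_grad(x):+.3f}")
```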

SiLU vs GELU

\[ \text{GELU}(x) = x \cdot \Phi(x) \]

where \( \Phi \) is the cumulative distribution function of the standard normal distribution.

Both GELU and SiLU are smooth, non-monotonic activation functions that combine linear and non-linear behaviour. SiLU is cheaper to compute, since it only needs a sigmoid rather than the Gaussian CDF (or its tanh-based approximation). In fact, the original GELU paper approximates GELU as SiLU applied to a scaled input:

\[ \text{GELU}(x) \approx x \cdot \sigma(1.702\,x) \]
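Here is a plain-Python check of how close this approximation is, using math.erf for the exact GELU. On [-6, 6] the largest absolute error should come out at around 0.02; the test only asserts that it stays below 0.03.

```python
import math

def gelu_exact(x: float) -> float:
    # x * Phi(x), with Phi the standard normal CDF written via erf
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_sigmoid_approx(x: float) -> float:
    # x * sigma(1.702 x): SiLU evaluated on a scaled input
    return x / (1.0 + math.exp(-1.702 * x))

xs = [i / 100 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(x) - gelu_sigmoid_approx(x)) for x in xs)
print(f"max |GELU(x) - x*sigmoid(1.702x)| on [-6, 6]: {max_err:.4f}")
```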

Grouped Query Attention

Grouped Query Attention (GQA) is a variant of the standard Multi-Head Attention mechanism that strikes a balance between computational efficiency and model quality. Instead of computing separate Key and Value matrices for every Query head, GQA lets a group of Query heads share the same Key-Value pair. This removes redundancy while still allowing the different queries in a group to attend to a common set of Keys and Values.

Multi Head Attention (MHA)

In standard MHA, each head has its own Query, Key and Value matrices. MHA is generally the most expressive option, but keeping as many K and V heads as Query heads is expensive. Both the K/V projections and, above all, the KV cache that must be stored during autoregressive inference grow with the number of heads. This increases memory usage and slows down inference.

Multi Query Attention (MQA)

Each head has its own Query matrix, while a single pair of Key and Value matrices is shared across all heads. This shrinks the K/V projections and the KV cache by a factor equal to the number of heads, which makes MQA fast and memory-efficient. The trade-off is some loss of accuracy, because every head must work with the same keys and values.

GQA

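GQA represents a middle ground between MHA and MQA. Instead of each head having its own K/V (MHA) or all heads sharing one K/V (MQA), GQA groups multiple query heads to share the same K/V projections. In Qwen3-0.6B (n_heads = 16, kv_heads = 8), the query heads form 8 groups of 2, and each group shares one K/V head.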
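How query heads map to K/V heads can be sketched in a few lines. Using repeat_interleave along the head axis (as the GQA code in this post does) gives query head h the K/V head h // (n_heads // kv_heads). Setting kv_heads = n_heads recovers MHA, and kv_heads = 1 recovers MQA. The function kv_head_for is illustrative.

```python
N_HEADS, KV_HEADS = 16, 8          # Qwen3-0.6B values from the config table
GROUP_SIZE = N_HEADS // KV_HEADS   # query heads per shared K/V head

def kv_head_for(query_head: int, n_heads: int = N_HEADS, kv_heads: int = KV_HEADS) -> int:
    # repeat_interleave(n_heads // kv_heads, dim=1) lays K/V heads out as
    # [k0, k0, k1, k1, ...], so query head h reads K/V head h // group_size
    return query_head // (n_heads // kv_heads)

gqa = [kv_head_for(h) for h in range(N_HEADS)]
mha = [kv_head_for(h, 16, 16) for h in range(16)]  # every head has its own K/V
mqa = [kv_head_for(h, 16, 1) for h in range(16)]   # one K/V head for everyone
print(gqa)
```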

#### Benefits of GQA:

Compared with full MHA, GQA offers:

  • Reduced memory usage - Fewer K/V heads mean a smaller KV cache
  • Fewer parameters - The w_k and w_v projections shrink by a factor of n_heads / kv_heads
  • Reduced computation - Fewer matrix operations to produce keys and values in the forward and backward passes
  • Faster training - Somewhat higher throughput during model training
  • Faster inference - Less KV-cache memory traffic per generated token, hence quicker responses in production

GQA achieves these efficiency gains with only a small drop in quality compared to full MHA.
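To put rough numbers on the memory savings, this sketch estimates the KV-cache size for one full 4096-token sequence with the Qwen3-0.6B config (28 layers, head_dim = 128). It assumes batch size 1 and float32 storage, as in the config table. Deployments commonly use 16-bit types, which halves every figure.

```python
def kv_cache_bytes(kv_heads: int, head_dim: int = 128, n_layers: int = 28,
                   seq_len: int = 4096, batch: int = 1, bytes_per_elem: int = 4) -> int:
    # Two tensors (K and V) per layer, each of shape (batch, kv_heads, seq_len, head_dim)
    return 2 * n_layers * batch * kv_heads * seq_len * head_dim * bytes_per_elem

for name, kv_heads in (("MHA", 16), ("GQA", 8), ("MQA", 1)):
    print(f"{name} ({kv_heads:>2} K/V heads): {kv_cache_bytes(kv_heads) / 2**20:,.0f} MiB")
```

Under these assumptions the cache shrinks from 1,792 MiB (MHA) to 896 MiB (GQA) and 112 MiB (MQA). It scales linearly with the number of K/V heads.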

import torch
import torch.nn as nn

class GQA(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.w_q = nn.Linear(config.embed_dim, config.n_heads * config.head_dim, bias=False)
        self.w_k = nn.Linear(config.embed_dim, config.kv_heads * config.head_dim, bias=False)
        self.w_v = nn.Linear(config.embed_dim, config.kv_heads * config.head_dim, bias=False)
        
        self.out_proj = nn.Linear(config.n_heads * config.head_dim, config.embed_dim, bias=False)
        
        # Q and K normalization layers (QK-Norm). As in Qwen3, they normalize each
        # head's vector separately, so their size is head_dim (not n_heads * head_dim)
        self.q_norm = RMSNORM(config.head_dim)
        self.k_norm = RMSNORM(config.head_dim)
    
    def forward(self, x, sin, cos):
        # x: (batch, seq_len, embed_dim)
        b, seq_len, _ = x.size()
        
        # Linear projections
        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_v(x)
        
        # Reshape and transpose to (batch, heads, seq_len, head_dim)
        Q = Q.view(b, seq_len, config.n_heads, config.head_dim).transpose(1, 2)
        K = K.view(b, seq_len, config.kv_heads, config.head_dim).transpose(1, 2)
        V = V.view(b, seq_len, config.kv_heads, config.head_dim).transpose(1, 2)
        
        # Per-head RMSNorm on Q and K (RMSNORM normalizes over the last dim, head_dim)
        Q = self.q_norm(Q)
        K = self.k_norm(K)
        
        # Apply RoPE to Q and K after the split into heads
        Q = rotate(Q, sin, cos)
        K = rotate(K, sin, cos)
        
        # Repeat K and V so each group of n_heads // kv_heads query heads sees its shared K/V head
        K = K.repeat_interleave(config.n_heads // config.kv_heads, dim=1)
        V = V.repeat_interleave(config.n_heads // config.kv_heads, dim=1)
        
        # Scaled dot-product scores: (batch, n_heads, seq_len, seq_len)
        attention_scores = Q @ K.transpose(-2, -1) / (config.head_dim ** 0.5)
        
        # Causal mask: True above the diagonal (future positions)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        mask = mask.unsqueeze(0).unsqueeze(0)  # Add batch and head dimensions
        
        # Apply mask (set future positions to -inf)
        masked_attention_scores = attention_scores.masked_fill(mask, float('-inf'))
        
        # Softmax over keys; exp(-inf) = 0, so masked positions get zero weight
        attention_weights = torch.softmax(masked_attention_scores, dim=-1)
        
        # Apply attention to values
        attention_output = attention_weights @ V
        
        # Reshape back to (batch, seq_len, n_heads * head_dim)
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            b, seq_len, config.n_heads * config.head_dim
        )
        
        # Final projection
        output = self.out_proj(attention_output)
        
        return output
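The masking step in the GQA code can be checked in isolation with dummy scores. After masked_fill with -inf and a softmax, every weight above the diagonal is exactly zero, and each row still sums to 1. The first query can only attend to itself.

```python
import torch

seq_len = 4
scores = torch.randn(1, 1, seq_len, seq_len)  # dummy (batch, heads, query, key) scores

# Same causal mask construction as in the GQA forward pass
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(weights[0, 0])
```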

Finally: The Model

QWEN(
  (embeddings): Embedding(151936, 1024)
  (blocks): ModuleList(
    (0-27): 28 x Block(
      (rms_1): RMSNORM()
      (mgqa): GQA(
        (w_q): Linear(in_features=1024, out_features=2048, bias=False)
        (w_k): Linear(in_features=1024, out_features=1024, bias=False)
        (w_v): Linear(in_features=1024, out_features=1024, bias=False)
        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (q_norm): RMSNORM()
        (k_norm): RMSNORM()
      )
      (rms_2): RMSNORM()
      (ffn): FFN(
        (layer_1): Linear(in_features=1024, out_features=3072, bias=False)
        (layer_2): Linear(in_features=1024, out_features=3072, bias=False)
        (layer_3): Linear(in_features=3072, out_features=1024, bias=False)
      )
    )
  )
  (rms_3): RMSNORM()
  (lm_head): Linear(in_features=1024, out_features=151936, bias=False)
)
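As a sanity check on the model's name, the parameter count can be added up from the shapes in the printout. This sketch rests on two assumptions:

  • Norm weights are γ-only, with per-head q_norm/k_norm as in Qwen3. The RMSNORM class in this post also carries a β.
  • lm_head shares its weight with the embedding (tied embeddings). The Qwen3 technical report lists tied embeddings for this model size.

```python
# Shapes from the printout / config table
VOCAB, EMBED, HIDDEN = 151_936, 1024, 3072
N_LAYERS, N_HEADS, KV_HEADS, HEAD_DIM = 28, 16, 8, 128

def linear(in_features: int, out_features: int) -> int:
    # Weight count of nn.Linear(in_features, out_features, bias=False)
    return in_features * out_features

attention = (linear(EMBED, N_HEADS * HEAD_DIM)          # w_q
             + 2 * linear(EMBED, KV_HEADS * HEAD_DIM)    # w_k and w_v
             + linear(N_HEADS * HEAD_DIM, EMBED))        # out_proj
ffn = 2 * linear(EMBED, HIDDEN) + linear(HIDDEN, EMBED)  # layer_1, layer_2, layer_3
# Assumption: gamma-only norms; rms_1/rms_2 over EMBED, q_norm/k_norm over HEAD_DIM
norms = 2 * EMBED + 2 * HEAD_DIM
per_block = attention + ffn + norms

embedding = VOCAB * EMBED
non_embedding = N_LAYERS * per_block + EMBED    # + final rms_3
total_tied = embedding + non_embedding          # assumption: lm_head reuses the embedding matrix
total_untied = total_tied + VOCAB * EMBED       # if lm_head had its own weights

print(f"per block:      {per_block:,}")
print(f"non-embedding:  {non_embedding:,}")
print(f"total (tied):   {total_tied:,}")
print(f"total (untied): {total_untied:,}")
```

Under these assumptions the total is 596,049,920 parameters (about 0.6B), of which about 0.44B sit outside the embedding matrix. Without weight tying, lm_head would add another ~156M.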

References

  1. Qwen3 Technical Report
  2. Sebastian Raschka's Qwen3 Notebook
  3. Designing Positional Encoding - Hugging Face
  4. RoFormer Paper
  5. GQA Paper
  6. Grouped Query Attention - KLU.ai