""" MiniMind Max2 Model Components Core building blocks: RMSNorm, RoPE, GQA Attention, MoE """ import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from configs.model_config import Max2Config class Max2RMSNorm(nn.Module): """Root Mean Square Layer Normalization (faster than LayerNorm).""" def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: input_dtype = x.dtype x = x.to(torch.float32) variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) return self.weight * x.to(input_dtype) class Max2RotaryEmbedding(nn.Module): """Rotary Position Embedding (RoPE) for efficient position encoding.""" def __init__(self, dim: int, max_position_embeddings: int = 8192, base: float = 10000.0): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._set_cos_sin_cache(max_position_embeddings) def _set_cos_sin_cache(self, seq_len: int): self.max_seq_len_cached = seq_len t = torch.arange(seq_len, dtype=torch.float32) freqs = torch.outer(t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos(), persistent=False) self.register_buffer("sin_cached", emb.sin(), persistent=False) def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len) return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype) def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate half the hidden dims of the input.""" x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Apply rotary position embeddings to query and key tensors.""" cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class Max2Attention(nn.Module): """Grouped Query Attention (GQA) - fewer KV heads than Q heads for memory efficiency.""" def __init__(self, config: Max2Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_kv_heads self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = Max2RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta) self.attention_dropout = config.attention_dropout def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: if n_rep == 1: return hidden_states bs, num_kv_heads, seq_len, head_dim = hidden_states.shape hidden_states = hidden_states[:, :, None, :, :].expand(bs, num_kv_heads, n_rep, seq_len, head_dim) return hidden_states.reshape(bs, num_kv_heads * n_rep, seq_len, head_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: batch_size, seq_len, _ = hidden_states.shape query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None key_states = self._repeat_kv(key_states, self.num_key_value_groups) value_states = self._repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, past_key_value class Max2MLP(nn.Module): """SwiGLU Feed-Forward Network.""" def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) class Max2Expert(nn.Module): """Single expert in the Mixture of Experts layer.""" def __init__(self, hidden_size: int, expert_hidden_size: int): super().__init__() self.mlp = Max2MLP(hidden_size, expert_hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.mlp(x) class Max2MoE(nn.Module): """ Mixture of Experts (MoE) layer. Efficient parameter activation - only top-k experts are used per token. Inspired by MiniMax M2's efficient activated parameters design. """ def __init__(self, config: Max2Config): super().__init__() self.hidden_size = config.hidden_size self.num_experts = config.num_experts self.num_experts_per_tok = config.num_experts_per_tok self.expert_hidden_size = config.expert_hidden_size self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False) self.experts = nn.ModuleList([ Max2Expert(self.hidden_size, self.expert_hidden_size) for _ in range(self.num_experts) ]) self.router_aux_loss_coef = config.router_aux_loss_coef def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, seq_len, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states_flat) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) router_weights, selected_experts = torch.topk(router_probs, self.num_experts_per_tok, dim=-1) router_weights = router_weights.to(hidden_states.dtype) router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True) final_hidden_states = torch.zeros_like(hidden_states_flat) expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) for expert_idx in range(self.num_experts): expert = self.experts[expert_idx] for top_k_idx in range(self.num_experts_per_tok): token_indices = expert_mask[expert_idx, top_k_idx].nonzero(as_tuple=True)[0] if token_indices.numel() > 0: expert_input = hidden_states_flat[token_indices] expert_output = expert(expert_input) weights = router_weights[token_indices, top_k_idx].unsqueeze(-1) final_hidden_states[token_indices] += weights * expert_output final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim) num_tokens = router_probs.shape[0] expert_mask_float = F.one_hot(selected_experts, num_classes=self.num_experts).float() tokens_per_expert = expert_mask_float.sum(dim=(0, 1)) / num_tokens router_prob_per_expert = router_probs.mean(dim=0) aux_loss = self.num_experts * (tokens_per_expert * router_prob_per_expert).sum() * self.router_aux_loss_coef return final_hidden_states, aux_loss class Max2DecoderLayer(nn.Module): """Single transformer decoder layer with GQA attention and MoE FFN.""" def __init__(self, config: Max2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Max2Attention(config, layer_idx) if config.use_moe: self.mlp = Max2MoE(config) self.use_moe = True else: self.mlp = Max2MLP(config.hidden_size, config.intermediate_size) self.use_moe = False self.input_layernorm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, present_key_value = self.self_attn(hidden_states, attention_mask, past_key_value, use_cache) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) if self.use_moe: hidden_states, aux_loss = self.mlp(hidden_states) else: hidden_states = self.mlp(hidden_states) aux_loss = torch.tensor(0.0, device=hidden_states.device) hidden_states = residual + hidden_states return hidden_states, present_key_value, aux_loss # Backward compatibility aliases Mind2RMSNorm = Max2RMSNorm Mind2RotaryEmbedding = Max2RotaryEmbedding Mind2Attention = Max2Attention Mind2MLP = Max2MLP Mind2Expert = Max2Expert Mind2MoE = Max2MoE Mind2DecoderLayer = Max2DecoderLayer