""" MiniMind Max2 Main Model Complete implementation of the Max2 language model. """ from typing import List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from configs.model_config import Max2Config, get_config from .components import Max2DecoderLayer, Max2RMSNorm class Max2Model(nn.Module): """Max2 Transformer Model - outputs raw hidden states.""" def __init__(self, config: Max2Config): super().__init__() self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) self.layers = nn.ModuleList([Max2DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) self.norm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False self._init_weights() def _init_weights(self): for module in self.modules(): if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) def _make_causal_mask(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) mask = torch.triu(mask, diagonal=1) return mask.unsqueeze(0).unsqueeze(0) def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]], torch.Tensor]: batch_size, seq_len = input_ids.shape hidden_states = self.embed_tokens(input_ids) causal_mask = self._make_causal_mask(seq_len, hidden_states.dtype, hidden_states.device) if attention_mask is not None: padding_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * float("-inf") causal_mask = causal_mask + padding_mask next_cache = [] if use_cache else None total_aux_loss = torch.tensor(0.0, device=hidden_states.device) for idx, layer in enumerate(self.layers): past_kv = past_key_values[idx] if past_key_values else None hidden_states, present_kv, aux_loss = layer(hidden_states, causal_mask, past_kv, use_cache) if use_cache: next_cache.append(present_kv) total_aux_loss = total_aux_loss + aux_loss hidden_states = self.norm(hidden_states) return hidden_states, next_cache, total_aux_loss class Max2ForCausalLM(nn.Module): """Max2 Model with Language Modeling head for text generation.""" def __init__(self, config: Max2Config): super().__init__() self.config = config self.model = Max2Model(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.lm_head.weight = self.model.embed_tokens.weight def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: bool = False, ) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[List], torch.Tensor]: hidden_states, next_cache, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache) logits = self.lm_head(hidden_states).float() loss = None if labels is not None: shift_logits = logits[..., :-1, 
:].contiguous() shift_labels = labels[..., 1:].contiguous() loss = CrossEntropyLoss()(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) loss = loss + aux_loss return loss, logits, next_cache, aux_loss @torch.no_grad() def generate( self, input_ids: torch.LongTensor, max_new_tokens: int = 100, temperature: float = 1.0, top_k: int = 50, top_p: float = 0.95, do_sample: bool = True, ) -> torch.LongTensor: """Simple generation with top-k/top-p sampling.""" generated = input_ids past_key_values = None for _ in range(max_new_tokens): if past_key_values is None: _, logits, past_key_values, _ = self(generated, use_cache=True) else: _, logits, past_key_values, _ = self(generated[:, -1:], past_key_values=past_key_values, use_cache=True) next_token_logits = logits[:, -1, :] / temperature if do_sample: if top_k > 0: indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] next_token_logits[indices_to_remove] = float('-inf') if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) next_token_logits[indices_to_remove] = float('-inf') probs = F.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) generated = torch.cat([generated, next_token], dim=1) if (next_token == self.config.eos_token_id).all(): break return generated # Backward compatibility aliases Mind2Model = Max2Model Mind2ForCausalLM = Max2ForCausalLM def create_model(model_name: str = "max2-lite", device: str = "cuda", dtype: torch.dtype = torch.float16) -> Max2ForCausalLM: """Factory function to create a Max2 model.""" config = get_config(model_name) model = Max2ForCausalLM(config) return model.to(device=device, dtype=dtype) if torch.cuda.is_available() else model if __name__ == "__main__": for model_name in ["max2-nano", "max2-lite", "max2-pro"]: print(f"\n{'='*50}\nTesting {model_name}\n{'='*50}") config = get_config(model_name) model = Max2ForCausalLM(config) total_params = sum(p.numel() for p in model.parameters()) print(f"Total Parameters: {total_params / 1e9:.3f}B") input_ids = torch.randint(0, config.vocab_size, (2, 128)) model.eval() with torch.no_grad(): loss, logits, _, aux_loss = model(input_ids, labels=input_ids) print(f"Logits shape: {logits.shape}") print(f"Loss: {loss:.4f}, Aux loss: {aux_loss:.6f}") print("Forward pass successful!")
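    # Generation smoke test -- a minimal sketch, not part of the original test
    # loop. It assumes "max2-nano" (the smallest config above) fits in CPU
    # memory and samples from random prompt ids, so no tokenizer is required;
    # all sampling parameters are illustrative.
    config = get_config("max2-nano")
    model = Max2ForCausalLM(config)
    model.eval()
    prompt = torch.randint(0, config.vocab_size, (1, 8))
    output = model.generate(prompt, max_new_tokens=16, temperature=0.8, top_k=40, top_p=0.9)
    print(f"\nGeneration smoke test: {prompt.shape[1]} prompt tokens -> {output.shape[1]} total tokens")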