| """ | |
| MiniMind Max2 Main Model | |
| Complete implementation of the Max2 language model. | |
| """ | |
| from typing import List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import CrossEntropyLoss | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from configs.model_config import Max2Config, get_config | |
| from .components import Max2DecoderLayer, Max2RMSNorm | |


class Max2Model(nn.Module):
    """Max2 Transformer Model - outputs raw hidden states."""

    def __init__(self, config: Max2Config):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        self.layers = nn.ModuleList([Max2DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self._init_weights()

    def _init_weights(self):
        # Plain normal init for Linear and Embedding weights; biases are zeroed.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    def _make_causal_mask(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        # Upper-triangular -inf mask: position i may only attend to positions <= i.
        mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
        mask = torch.triu(mask, diagonal=1)
        return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
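
    # Example (sketch): for seq_len=3 the returned mask, broadcast over batch
    # and heads, is
    #     [[0, -inf, -inf],
    #      [0,    0, -inf],
    #      [0,    0,    0]]
    # so adding it to the attention scores blocks attention to future tokens.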

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]], torch.Tensor]:
        batch_size, seq_len = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)
        # Note: the mask only covers the current input; with past_key_values,
        # callers are expected to pass one token at a time (as generate does).
        causal_mask = self._make_causal_mask(seq_len, hidden_states.dtype, hidden_states.device)
        if attention_mask is not None:
            # Scale the (batch, seq_len) padding mask by the dtype minimum rather
            # than -inf: (1 - mask) * -inf evaluates to NaN (0 * -inf) at every
            # unmasked position.
            min_value = torch.finfo(hidden_states.dtype).min
            padding_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * min_value
            causal_mask = causal_mask + padding_mask
        next_cache = [] if use_cache else None
        total_aux_loss = torch.tensor(0.0, device=hidden_states.device)
        for idx, layer in enumerate(self.layers):
            past_kv = past_key_values[idx] if past_key_values else None
            hidden_states, present_kv, aux_loss = layer(hidden_states, causal_mask, past_kv, use_cache)
            if use_cache:
                next_cache.append(present_kv)
            total_aux_loss = total_aux_loss + aux_loss
        hidden_states = self.norm(hidden_states)
        return hidden_states, next_cache, total_aux_loss
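
# Backbone usage (sketch; names are illustrative): Max2Model returns raw hidden
# states, an optional per-layer (key, value) cache, and the auxiliary loss
# accumulated across the decoder layers:
#     backbone = Max2Model(get_config("max2-nano"))
#     ids = torch.randint(0, backbone.vocab_size, (1, 16))
#     h, cache, aux = backbone(ids, use_cache=True)
#     # h: (1, 16, hidden_size); cache: one (k, v) pair per layer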


class Max2ForCausalLM(nn.Module):
    """Max2 Model with Language Modeling head for text generation."""

    def __init__(self, config: Max2Config):
        super().__init__()
        self.config = config
        self.model = Max2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight tying: the LM head shares its matrix with the input embeddings.
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[List], torch.Tensor]:
        hidden_states, next_cache, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache)
        logits = self.lm_head(hidden_states).float()
        loss = None
        if labels is not None:
            # Shift so that the logits at position i are scored against the
            # label at position i + 1 (next-token prediction).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = CrossEntropyLoss()(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
            loss = loss + aux_loss
        return loss, logits, next_cache, aux_loss
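
    # Shift example (sketch): with labels [t0, t1, t2, t3], the loss pairs are
    # (logits@t0 -> t1), (logits@t1 -> t2), (logits@t2 -> t3); the last
    # position has no next-token target, so the [..., :-1, :] slice drops it.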

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
    ) -> torch.LongTensor:
        """Simple generation with top-k/top-p sampling."""
        generated = input_ids
        past_key_values = None
        for _ in range(max_new_tokens):
            if past_key_values is None:
                # Prefill: run the full prompt once and build the KV cache.
                _, logits, past_key_values, _ = self(generated, use_cache=True)
            else:
                # Decode: feed only the newest token; the cache covers the rest.
                _, logits, past_key_values, _ = self(generated[:, -1:], past_key_values=past_key_values, use_cache=True)
            next_token_logits = logits[:, -1, :] / temperature
            if do_sample:
                if top_k > 0:
                    # Top-k: mask everything below the k-th largest logit.
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float("-inf")
                if top_p < 1.0:
                    # Top-p (nucleus): keep the smallest prefix of the sorted
                    # distribution whose cumulative probability exceeds top_p.
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float("-inf")
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            # Stop once every sequence in the batch has emitted EOS.
            if (next_token == self.config.eos_token_id).all():
                break
        return generated
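
    # Usage (sketch, assuming a compatible tokenizer; names are illustrative):
    #     prompt_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
    #     out = model.generate(prompt_ids, max_new_tokens=32, temperature=0.8, top_p=0.9)
    #     print(tokenizer.decode(out[0]))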


# Backward compatibility aliases
Mind2Model = Max2Model
Mind2ForCausalLM = Max2ForCausalLM


def create_model(model_name: str = "max2-lite", device: str = "cuda", dtype: torch.dtype = torch.float16) -> Max2ForCausalLM:
    """Factory function to create a Max2 model."""
    config = get_config(model_name)
    model = Max2ForCausalLM(config)
    if device.startswith("cuda") and not torch.cuda.is_available():
        # Requested CUDA but none is available: fall back to the fp32 CPU model.
        return model
    return model.to(device=device, dtype=dtype)


if __name__ == "__main__":
    for model_name in ["max2-nano", "max2-lite", "max2-pro"]:
        print(f"\n{'=' * 50}\nTesting {model_name}\n{'=' * 50}")
        config = get_config(model_name)
        model = Max2ForCausalLM(config)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Total Parameters: {total_params / 1e9:.3f}B")
        input_ids = torch.randint(0, config.vocab_size, (2, 128))
        model.eval()
        with torch.no_grad():
            loss, logits, _, aux_loss = model(input_ids, labels=input_ids)
        print(f"Logits shape: {logits.shape}")
        print(f"Loss: {loss:.4f}, Aux loss: {aux_loss:.6f}")
        print("Forward pass successful!")