"""
MiniMind Max2 Main Model
Complete implementation of the Max2 language model.
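
Typical usage (illustrative sketch):
    model = create_model("max2-lite", device="cpu", dtype=torch.float32)
    output_ids = model.generate(input_ids, max_new_tokens=64)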
"""
import sys
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

# Make the package root importable when this file is executed directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from configs.model_config import Max2Config, get_config
from .components import Max2DecoderLayer, Max2RMSNorm


class Max2Model(nn.Module):
"""Max2 Transformer Model - outputs raw hidden states."""
def __init__(self, config: Max2Config):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
self.layers = nn.ModuleList([Max2DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
self.norm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self._init_weights()

    def _init_weights(self):
        # Initialize all Linear and Embedding weights from N(0, initializer_range).
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    def _make_causal_mask(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        # Additive causal mask: 0 on and below the diagonal, a large negative
        # value above it. torch.finfo(dtype).min is used instead of -inf so that
        # rows which also get fully padded out do not produce NaN in the softmax.
        mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device)
        mask = torch.triu(mask, diagonal=1)
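        # For seq_len = 3 (ignoring the two broadcast dims added below):
        #   [[0, m, m],
        #    [0, 0, m],
        #    [0, 0, 0]]   where m = torch.finfo(dtype).min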
return mask.unsqueeze(0).unsqueeze(0)

    def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]], torch.Tensor]:
batch_size, seq_len = input_ids.shape
hidden_states = self.embed_tokens(input_ids)
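        # hidden_states: (batch_size, seq_len, hidden_size)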
        # (1, 1, seq_len, seq_len) additive mask; during cached decoding seq_len
        # is 1, so this is a single zero that broadcasts across the cached keys.
        causal_mask = self._make_causal_mask(seq_len, hidden_states.dtype, hidden_states.device)
        if attention_mask is not None:
            # Mask out padded key positions. Build the bias with masked_fill rather
            # than multiplying by -inf, which yields NaN (0 * -inf) at kept positions.
            padding_mask = torch.zeros_like(attention_mask, dtype=hidden_states.dtype)
            padding_mask = padding_mask.masked_fill(attention_mask == 0, torch.finfo(hidden_states.dtype).min)
            causal_mask = causal_mask + padding_mask[:, None, None, :]
        next_cache = [] if use_cache else None
        # Accumulate each layer's auxiliary loss (e.g. an MoE load-balancing
        # term) so the causal-LM wrapper can add it to the training loss.
        total_aux_loss = torch.tensor(0.0, device=hidden_states.device)
for idx, layer in enumerate(self.layers):
past_kv = past_key_values[idx] if past_key_values else None
hidden_states, present_kv, aux_loss = layer(hidden_states, causal_mask, past_kv, use_cache)
if use_cache:
next_cache.append(present_kv)
total_aux_loss = total_aux_loss + aux_loss
hidden_states = self.norm(hidden_states)
return hidden_states, next_cache, total_aux_loss


class Max2ForCausalLM(nn.Module):
"""Max2 Model with Language Modeling head for text generation."""
def __init__(self, config: Max2Config):
super().__init__()
self.config = config
self.model = Max2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie the output projection to the input embedding matrix.
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[List], torch.Tensor]:
hidden_states, next_cache, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache)
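        # Project to vocabulary logits in float32 for a numerically stable loss.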
logits = self.lm_head(hidden_states).float()
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n, then flatten for cross-entropy.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = CrossEntropyLoss()(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
            # Fold the accumulated auxiliary loss into the training objective.
            loss = loss + aux_loss
return loss, logits, next_cache, aux_loss

    @torch.no_grad()
def generate(
self,
input_ids: torch.LongTensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.95,
do_sample: bool = True,
) -> torch.LongTensor:
"""Simple generation with top-k/top-p sampling."""
generated = input_ids
past_key_values = None
        for _ in range(max_new_tokens):
            if past_key_values is None:
                # First step: run the full prompt to prime the KV cache.
                _, logits, past_key_values, _ = self(generated, use_cache=True)
            else:
                # Later steps: feed only the newest token; attention reads the cache.
                _, logits, past_key_values, _ = self(generated[:, -1:], past_key_values=past_key_values, use_cache=True)
            # Guard against division by zero when temperature is 0.
            next_token_logits = logits[:, -1, :] / max(temperature, 1e-6)
            if do_sample:
                if top_k > 0:
                    # Top-k: drop everything below the k-th highest logit.
                    top_k = min(top_k, next_token_logits.size(-1))
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')
                if top_p < 1.0:
                    # Top-p (nucleus): keep the smallest prefix of sorted tokens
                    # whose cumulative probability reaches top_p.
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    # Map the removal mask from sorted order back to vocab order.
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
            # Stop early once every sequence in the batch has produced EOS.
            if (next_token == self.config.eos_token_id).all():
                break
return generated


# Backward-compatibility aliases for the earlier "Mind2" module naming.
Mind2Model = Max2Model
Mind2ForCausalLM = Max2ForCausalLM


def create_model(model_name: str = "max2-lite", device: str = "cuda", dtype: torch.dtype = torch.float16) -> Max2ForCausalLM:
    """Factory function to create a Max2 model."""
    config = get_config(model_name)
    model = Max2ForCausalLM(config)
    if device.startswith("cuda") and not torch.cuda.is_available():
        # Without CUDA, keep the model on CPU in full precision instead of
        # silently ignoring the requested device and dtype.
        return model
    return model.to(device=device, dtype=dtype)
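
# Example usage (illustrative; the config names match the smoke test below):
#   model = create_model("max2-nano", device="cpu", dtype=torch.float32)
#   ids = torch.randint(0, model.config.vocab_size, (1, 8))
#   out = model.generate(ids, max_new_tokens=16)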


if __name__ == "__main__":
    # Smoke test: build each config, report parameter count, run one forward pass.
    for model_name in ["max2-nano", "max2-lite", "max2-pro"]:
print(f"\n{'='*50}\nTesting {model_name}\n{'='*50}")
config = get_config(model_name)
model = Max2ForCausalLM(config)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params / 1e9:.3f}B")
input_ids = torch.randint(0, config.vocab_size, (2, 128))
model.eval()
with torch.no_grad():
loss, logits, _, aux_loss = model(input_ids, labels=input_ids)
print(f"Logits shape: {logits.shape}")
print(f"Loss: {loss:.4f}, Aux loss: {aux_loss:.6f}")
print("Forward pass successful!")