fariasultana committed (verified)
Commit bd21ba5 · 1 Parent(s): c1384b2

MiniMind Max2 API - Gradio Interface

README.md CHANGED
@@ -1,12 +1,65 @@
  ---
- title: MiniMind API
- emoji: 😻
- colorFrom: indigo
- colorTo: pink
  sdk: gradio
- sdk_version: 6.0.2
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: MiniMind Max2 API
+ emoji: 🧠
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
+ license: apache-2.0
+ tags:
+ - text-generation
+ - moe
+ - fastapi
+ - language-model
  ---

+ # 🧠 MiniMind Max2 API
+
+ **Tiny Model, Powerful Experience** - An efficient language model API with a FastAPI backend.
+
+ ## Features
+
+ - **Mixture of Experts (MoE)**: only 25% of parameters activated per token
+ - **Grouped Query Attention (GQA)**: 4:1 query-to-KV-head ratio for memory efficiency
+ - **FastAPI Backend**: RESTful API with automatic docs
+ - **Gradio Interface**: interactive UI for testing
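Both headline figures follow from the default `max2-lite` configuration in `model_files/configs/model_config.py` (top-2 routing over 8 experts, 12 query heads sharing 3 KV heads); a quick check, with the exact parameter ratio coming from `estimate_params`:

```python
# Defaults taken from Max2Config (max2-lite) in model_files/configs/model_config.py
num_experts, num_experts_per_tok = 8, 2
num_attention_heads, num_key_value_heads = 12, 3

print(f"{num_experts_per_tok / num_experts:.0%} of experts active per token")  # 25%
print(f"{num_attention_heads // num_key_value_heads}:1 GQA ratio")             # 4:1
```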
+
+ ## API Endpoints
+
+ | Endpoint | Method | Description |
+ |----------|--------|-------------|
+ | `/docs` | GET | Swagger UI documentation |
+ | `/generate` | POST | Generate text from prompt |
+ | `/model-info` | GET | Get model architecture info |
+ | `/health` | GET | Health check |
+ | `/gradio` | GET | Interactive Gradio interface |
+
+ ## Example Usage
+
+ ```python
+ import requests
+
+ response = requests.post(
+     "https://your-space.hf.space/generate",
+     json={
+         "prompt": "Once upon a time",
+         "max_new_tokens": 100,
+         "temperature": 0.8
+     }
+ )
+ print(response.json()["generated_text"])
+ ```
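The GET endpoints listed above can be exercised the same way; a minimal sketch, assuming the same placeholder Space URL and JSON responses like `/generate`:

```python
import requests

base = "https://your-space.hf.space"
print(requests.get(f"{base}/health").json())      # liveness probe
print(requests.get(f"{base}/model-info").json())  # architecture summary
```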
+
+ ## Model Variants
+
+ | Model | Total Params | Active Params | Target |
+ |-------|--------------|---------------|--------|
+ | max2-nano | 500M | 125M | IoT, Mobile |
+ | max2-lite | 1.5B | 375M | Mobile, Tablet |
+ | max2-pro | 3B | 750M | Desktop |
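The variant served by the Space is chosen via the `MODEL_NAME` environment variable (`app.py` defaults to `max2-nano`). To poke at a variant locally, the bundled `model_files` package can be used directly; a sketch, assuming it is run from the repository root with random (untrained) weights:

```python
import sys
sys.path.insert(0, "model_files")

import torch
from configs.model_config import get_config, estimate_params
from model import Max2ForCausalLM

config = get_config("max2-nano")
params = estimate_params(config)
print(f"{params['total_params_b']:.3f}B total / {params['active_params_b']:.3f}B active")

model = Max2ForCausalLM(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    loss, logits, _, aux_loss = model(input_ids, labels=input_ids)
print(logits.shape)  # (1, 16, vocab_size)
```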
+
+ ## License
+
+ Apache 2.0
app.py ADDED
@@ -0,0 +1,203 @@
+ """
+ MiniMind Max2 - Gradio Space
+ A lightweight, efficient language model with MoE architecture.
+ """
+
+ import os
+ import sys
+ from pathlib import Path
+
+ # Add model files to path
+ sys.path.insert(0, str(Path(__file__).parent / "model_files"))
+
+ import torch
+ import gradio as gr
+
+ # Configuration
+ MODEL_NAME = os.getenv("MODEL_NAME", "max2-nano")
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
+
+ # Global model
+ model = None
+ config = None
+
+ def load_model():
+     """Load the Max2 model."""
+     global model, config
+
+     from configs.model_config import get_config, estimate_params
+     from model import Max2ForCausalLM
+
+     print(f"🔄 Loading {MODEL_NAME} on {DEVICE}...")
+     config = get_config(MODEL_NAME)
+     model = Max2ForCausalLM(config)
+     model = model.to(device=DEVICE, dtype=DTYPE)
+     model.eval()
+
+     params = estimate_params(config)
+     print(f"✅ Model loaded: {params['total_params_b']:.3f}B total, {params['active_params_b']:.3f}B active")
+     return model, config
+
+ def generate_text(prompt, max_tokens, temperature, top_k, top_p):
+     """Generate text from prompt."""
+     global model, config
+
+     if model is None:
+         load_model()
+
+     if not prompt.strip():
+         return "Please enter a prompt."
+
+     try:
+         # Simple character-level tokenization (demo purposes)
+         # In production, use SentencePiece or similar tokenizer
+         prompt_ids = [ord(c) % config.vocab_size for c in prompt]
+         input_ids = torch.tensor([prompt_ids], device=DEVICE)
+
+         with torch.no_grad():
+             output_ids = model.generate(
+                 input_ids,
+                 max_new_tokens=int(max_tokens),
+                 temperature=temperature,
+                 top_k=int(top_k),
+                 top_p=top_p,
+                 do_sample=True,
+             )
+
+         # Decode generated tokens
+         generated_ids = output_ids[0, len(prompt_ids):].tolist()
+         generated_text = "".join([chr(min(max(i, 32), 126)) for i in generated_ids])
+
+         return prompt + generated_text
+
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+ def get_model_info():
+     """Get model information."""
+     global model, config
+
+     if model is None:
+         load_model()
+
+     from configs.model_config import estimate_params
+     params = estimate_params(config)
+
+     return f"""
+ ## Model: {config.model_name}
+
+ | Property | Value |
+ |----------|-------|
+ | Total Parameters | {params['total_params_b']:.3f}B |
+ | Active Parameters | {params['active_params_b']:.3f}B |
+ | Activation Ratio | {params['activation_ratio']:.1%} |
+ | Device | {DEVICE} |
+ | Num Experts | {config.num_experts} |
+ | Experts per Token | {config.num_experts_per_tok} |
+ | Max Context | {config.max_position_embeddings} |
+ """
+
+ # Create Gradio interface
+ with gr.Blocks(title="MiniMind Max2", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+ # 🧠 MiniMind Max2
+
+ **Tiny Model, Powerful Experience** - An efficient language model with Mixture of Experts (MoE) architecture.
+ Only 25% of parameters are activated per token for efficient inference.
+
+ > ⚠️ **Note**: This demo uses character-level tokenization for simplicity.
+ > For production use, integrate a proper tokenizer (SentencePiece, etc.).
+     """)
+
+     with gr.Tabs():
+         with gr.TabItem("🚀 Generate"):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     prompt_input = gr.Textbox(
+                         label="Prompt",
+                         placeholder="Enter your prompt here...",
+                         lines=4,
+                         value="Once upon a time"
+                     )
+
+                     with gr.Row():
+                         max_tokens = gr.Slider(
+                             minimum=10, maximum=256, value=100, step=10,
+                             label="Max New Tokens"
+                         )
+                         temperature = gr.Slider(
+                             minimum=0.1, maximum=2.0, value=0.8, step=0.1,
+                             label="Temperature"
+                         )
+
+                     with gr.Row():
+                         top_k = gr.Slider(
+                             minimum=1, maximum=100, value=50, step=1,
+                             label="Top-K"
+                         )
+                         top_p = gr.Slider(
+                             minimum=0.1, maximum=1.0, value=0.9, step=0.05,
+                             label="Top-P"
+                         )
+
+                     generate_btn = gr.Button("🎯 Generate", variant="primary")
+
+                 with gr.Column(scale=2):
+                     output_text = gr.Textbox(
+                         label="Generated Text",
+                         lines=12,
+                         show_copy_button=True
+                     )
+
+             generate_btn.click(
+                 fn=generate_text,
+                 inputs=[prompt_input, max_tokens, temperature, top_k, top_p],
+                 outputs=output_text
+             )
+
+             gr.Examples(
+                 examples=[
+                     ["Once upon a time", 100, 0.8, 50, 0.9],
+                     ["The quick brown fox", 50, 0.7, 40, 0.95],
+                     ["In a galaxy far away", 150, 1.0, 60, 0.85],
+                     ["def fibonacci(n):", 80, 0.6, 30, 0.9],
+                 ],
+                 inputs=[prompt_input, max_tokens, temperature, top_k, top_p],
+             )
+
+         with gr.TabItem("ℹ️ Model Info"):
+             info_btn = gr.Button("📊 Load Model Info")
+             info_output = gr.Markdown()
+             info_btn.click(fn=get_model_info, outputs=info_output)
+
+     gr.Markdown("""
+ ---
+
+ ### 🔧 Architecture
+ - **MoE**: 8 experts, top-2 routing (25% activation)
+ - **GQA**: Grouped Query Attention (4:1 ratio)
+ - **RoPE**: Rotary Position Embeddings
+ - **SwiGLU**: Improved activation function
+
+ ### 📦 Model Variants
+ | Model | Total | Active | Target |
+ |-------|-------|--------|--------|
+ | max2-nano | 500M | 125M | IoT/Mobile |
+ | max2-lite | 1.5B | 375M | Mobile/Tablet |
+ | max2-pro | 3B | 750M | Desktop |
+
+ ---
+
+ **[Model Repository](https://huggingface.co/fariasultana/MiniMind)** |
+ **License**: Apache 2.0
+     """)
+
+ # Load model on startup
+ try:
+     load_model()
+ except Exception as e:
+     print(f"Model will load on first request: {e}")
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
model_files/configs/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """MiniMind Max2 Configuration Module"""
+ from .model_config import Max2Config, get_config, estimate_params, MAX2_CONFIGS
+
+ # Backward compatibility
+ Mind2Config = Max2Config
+ MIND2_CONFIGS = MAX2_CONFIGS
+
+ __all__ = [
+     "Max2Config",
+     "Mind2Config",
+     "get_config",
+     "estimate_params",
+     "MAX2_CONFIGS",
+     "MIND2_CONFIGS",
+ ]
model_files/configs/model_config.py ADDED
@@ -0,0 +1,154 @@
+ """
+ MiniMind Max2 Model Configuration
+ Inspired by MiniMax M2's efficient activated parameters design
+ """
+
+ from dataclasses import dataclass
+ from typing import Optional, Dict, Any
+
+
+ @dataclass
+ class Max2Config:
+     """Configuration for MiniMind Max2 models."""
+
+     # Model identification
+     model_name: str = "max2-lite"
+     model_version: str = "1.0.0"
+
+     # Architecture dimensions
+     hidden_size: int = 1536
+     intermediate_size: int = 4096
+     num_hidden_layers: int = 24
+     num_attention_heads: int = 12
+     num_key_value_heads: int = 3  # GQA ratio 4:1
+
+     # Vocabulary and embeddings
+     vocab_size: int = 32000
+     max_position_embeddings: int = 8192
+     rope_theta: float = 10000.0
+
+     # MoE (Mixture of Experts) configuration
+     use_moe: bool = True
+     num_experts: int = 8
+     num_experts_per_tok: int = 2  # Only 25% activation
+     expert_hidden_size: int = 1024
+     router_aux_loss_coef: float = 0.01
+
+     # Normalization and activation
+     rms_norm_eps: float = 1e-6
+     hidden_act: str = "silu"
+
+     # Regularization
+     hidden_dropout: float = 0.0
+     attention_dropout: float = 0.0
+
+     # Special tokens
+     pad_token_id: int = 0
+     bos_token_id: int = 1
+     eos_token_id: int = 2
+
+     # Initialization
+     initializer_range: float = 0.02
+
+     # Memory optimization
+     use_cache: bool = True
+     use_flash_attention: bool = True
+     gradient_checkpointing: bool = False
+
+     def to_dict(self) -> Dict[str, Any]:
+         return {k: v for k, v in self.__dict__.items()}
+
+     @classmethod
+     def from_dict(cls, config_dict: Dict[str, Any]) -> "Max2Config":
+         return cls(**{k: v for k, v in config_dict.items() if k in cls.__dataclass_fields__})
+
+
+ # Predefined model configurations
+ MAX2_CONFIGS = {
+     "max2-nano": Max2Config(
+         model_name="max2-nano",
+         hidden_size=768,
+         intermediate_size=2048,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         num_key_value_heads=3,
+         num_experts=4,
+         num_experts_per_tok=1,
+         expert_hidden_size=512,
+         max_position_embeddings=4096,
+     ),
+     "max2-lite": Max2Config(
+         model_name="max2-lite",
+         hidden_size=1536,
+         intermediate_size=4096,
+         num_hidden_layers=24,
+         num_attention_heads=12,
+         num_key_value_heads=3,
+         num_experts=8,
+         num_experts_per_tok=2,
+         expert_hidden_size=1024,
+         max_position_embeddings=8192,
+     ),
+     "max2-pro": Max2Config(
+         model_name="max2-pro",
+         hidden_size=2560,
+         intermediate_size=6912,
+         num_hidden_layers=32,
+         num_attention_heads=20,
+         num_key_value_heads=4,
+         num_experts=8,
+         num_experts_per_tok=2,
+         expert_hidden_size=1728,
+         max_position_embeddings=16384,
+     ),
+ }
+
+ # Aliases for backward compatibility
+ Mind2Config = Max2Config
+ MIND2_CONFIGS = MAX2_CONFIGS
+
+
+ def get_config(model_name: str) -> Max2Config:
+     """Get predefined configuration by name."""
+     if model_name not in MAX2_CONFIGS:
+         raise ValueError(f"Unknown model: {model_name}. Available: {list(MAX2_CONFIGS.keys())}")
+     return MAX2_CONFIGS[model_name]
+
+
+ def estimate_params(config: Max2Config) -> dict:
+     """Estimate parameter counts for a configuration."""
+     embed_params = config.vocab_size * config.hidden_size
+     head_dim = config.hidden_size // config.num_attention_heads
+
+     # Attention parameters per layer (GQA)
+     q_params = config.hidden_size * config.hidden_size
+     kv_params = 2 * config.hidden_size * (config.num_key_value_heads * head_dim)
+     o_params = config.hidden_size * config.hidden_size
+     attn_params_per_layer = q_params + kv_params + o_params
+
+     # MoE FFN parameters per layer
+     if config.use_moe:
+         router_params = config.hidden_size * config.num_experts
+         expert_params = 3 * config.hidden_size * config.expert_hidden_size
+         ffn_params_per_layer = router_params + (config.num_experts * expert_params)
+         active_ffn_params = router_params + (config.num_experts_per_tok * expert_params)
+     else:
+         ffn_params_per_layer = 3 * config.hidden_size * config.intermediate_size
+         active_ffn_params = ffn_params_per_layer
+
+     norm_params_per_layer = 2 * config.hidden_size
+     layer_params = attn_params_per_layer + ffn_params_per_layer + norm_params_per_layer
+     active_layer_params = attn_params_per_layer + active_ffn_params + norm_params_per_layer
+
+     total_params = embed_params + (config.num_hidden_layers * layer_params) + embed_params
+     active_params = embed_params + (config.num_hidden_layers * active_layer_params) + embed_params
+
+     return {
+         "total_params": total_params,
+         "active_params": active_params,
+         "activation_ratio": active_params / total_params,
+         "total_params_b": total_params / 1e9,
+         "active_params_b": active_params / 1e9,
+         "estimated_size_fp16_gb": (total_params * 2) / (1024**3),
+         "estimated_size_int4_gb": (total_params * 0.5) / (1024**3),
+     }
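For reference, the three presets can be compared with a few lines; a sketch, assuming the repository's `model_files` directory is on `sys.path`:

```python
from configs.model_config import MAX2_CONFIGS, estimate_params

for name, cfg in MAX2_CONFIGS.items():
    p = estimate_params(cfg)
    print(f"{name}: {p['total_params_b']:.2f}B total, {p['active_params_b']:.2f}B active, "
          f"~{p['estimated_size_fp16_gb']:.1f} GB fp16")
```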
model_files/model/__init__.py ADDED
@@ -0,0 +1,52 @@
+ """
+ MiniMind Max2 Model Package
+ A lightweight, efficient language model designed for edge deployment.
+ """
+
+ from .mind2_model import (
+     Max2ForCausalLM,
+     Max2Model,
+     Mind2ForCausalLM,
+     Mind2Model,
+     create_model
+ )
+ from .components import (
+     Max2Attention,
+     Max2MoE,
+     Max2DecoderLayer,
+     Max2RMSNorm,
+     Max2RotaryEmbedding,
+     Max2MLP,
+     Max2Expert,
+     # Backward compatibility
+     Mind2Attention,
+     Mind2MoE,
+     Mind2DecoderLayer,
+     Mind2RMSNorm,
+     Mind2RotaryEmbedding,
+ )
+
+ __all__ = [
+     # Max2 (primary)
+     "Max2ForCausalLM",
+     "Max2Model",
+     "Max2Attention",
+     "Max2MoE",
+     "Max2DecoderLayer",
+     "Max2RMSNorm",
+     "Max2RotaryEmbedding",
+     "Max2MLP",
+     "Max2Expert",
+     # Mind2 (backward compatibility)
+     "Mind2ForCausalLM",
+     "Mind2Model",
+     "Mind2Attention",
+     "Mind2MoE",
+     "Mind2DecoderLayer",
+     "Mind2RMSNorm",
+     "Mind2RotaryEmbedding",
+     # Factory
+     "create_model",
+ ]
+
+ __version__ = "1.0.0"
model_files/model/components.py ADDED
@@ -0,0 +1,274 @@
+ """
+ MiniMind Max2 Model Components
+ Core building blocks: RMSNorm, RoPE, GQA Attention, MoE
+ """
+
+ import math
+ from typing import Optional, Tuple
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import sys
+ from pathlib import Path
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+ from configs.model_config import Max2Config
+
+
+ class Max2RMSNorm(nn.Module):
+     """Root Mean Square Layer Normalization (faster than LayerNorm)."""
+
+     def __init__(self, hidden_size: int, eps: float = 1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         input_dtype = x.dtype
+         x = x.to(torch.float32)
+         variance = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.eps)
+         return self.weight * x.to(input_dtype)
+
+
+ class Max2RotaryEmbedding(nn.Module):
+     """Rotary Position Embedding (RoPE) for efficient position encoding."""
+
+     def __init__(self, dim: int, max_position_embeddings: int = 8192, base: float = 10000.0):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self._set_cos_sin_cache(max_position_embeddings)
+
+     def _set_cos_sin_cache(self, seq_len: int):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(seq_len, dtype=torch.float32)
+         freqs = torch.outer(t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos(), persistent=False)
+         self.register_buffer("sin_cached", emb.sin(), persistent=False)
+
+     def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(seq_len)
+         return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     """Rotate half the hidden dims of the input."""
+     x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Apply rotary position embeddings to query and key tensors."""
+     cos = cos.unsqueeze(0).unsqueeze(0)
+     sin = sin.unsqueeze(0).unsqueeze(0)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class Max2Attention(nn.Module):
+     """Grouped Query Attention (GQA) - fewer KV heads than Q heads for memory efficiency."""
+
+     def __init__(self, config: Max2Config, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.num_kv_heads = config.num_key_value_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_groups = self.num_heads // self.num_kv_heads
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+         self.rotary_emb = Max2RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)
+         self.attention_dropout = config.attention_dropout
+
+     def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+         if n_rep == 1:
+             return hidden_states
+         bs, num_kv_heads, seq_len, head_dim = hidden_states.shape
+         hidden_states = hidden_states[:, :, None, :, :].expand(bs, num_kv_heads, n_rep, seq_len, head_dim)
+         return hidden_states.reshape(bs, num_kv_heads * n_rep, seq_len, head_dim)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         batch_size, seq_len, _ = hidden_states.shape
+
+         query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         cos, sin = self.rotary_emb(value_states, seq_len)
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_value is not None:
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         past_key_value = (key_states, value_states) if use_cache else None
+
+         key_states = self._repeat_kv(key_states, self.num_key_value_groups)
+         value_states = self._repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+
+         attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output, past_key_value
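A quick shape check illustrating why GQA saves KV-cache memory: only `num_key_value_heads` heads (here 3 of 12) are stored in the cache. A sketch, assuming `model_files` is on `sys.path`:

```python
import torch
from configs.model_config import Max2Config
from model.components import Max2Attention

cfg = Max2Config(hidden_size=768, num_attention_heads=12, num_key_value_heads=3)
attn = Max2Attention(cfg, layer_idx=0)

x = torch.randn(2, 16, 768)
out, kv_cache = attn(x, use_cache=True)
print(out.shape)          # torch.Size([2, 16, 768])
print(kv_cache[0].shape)  # torch.Size([2, 3, 16, 64]) -> 3 cached KV heads instead of 12
```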
+
+
+ class Max2MLP(nn.Module):
+     """SwiGLU Feed-Forward Network."""
+
+     def __init__(self, hidden_size: int, intermediate_size: int):
+         super().__init__()
+         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Max2Expert(nn.Module):
+     """Single expert in the Mixture of Experts layer."""
+
+     def __init__(self, hidden_size: int, expert_hidden_size: int):
+         super().__init__()
+         self.mlp = Max2MLP(hidden_size, expert_hidden_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.mlp(x)
+
+
+ class Max2MoE(nn.Module):
+     """
+     Mixture of Experts (MoE) layer.
+     Efficient parameter activation - only top-k experts are used per token.
+     Inspired by MiniMax M2's efficient activated parameters design.
+     """
+
+     def __init__(self, config: Max2Config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_experts = config.num_experts
+         self.num_experts_per_tok = config.num_experts_per_tok
+         self.expert_hidden_size = config.expert_hidden_size
+
+         self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
+         self.experts = nn.ModuleList([
+             Max2Expert(self.hidden_size, self.expert_hidden_size)
+             for _ in range(self.num_experts)
+         ])
+         self.router_aux_loss_coef = config.router_aux_loss_coef
+
+     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size, seq_len, hidden_dim = hidden_states.shape
+         hidden_states_flat = hidden_states.view(-1, hidden_dim)
+
+         router_logits = self.gate(hidden_states_flat)
+         router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
+
+         router_weights, selected_experts = torch.topk(router_probs, self.num_experts_per_tok, dim=-1)
+         router_weights = router_weights.to(hidden_states.dtype)
+         router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
+
+         final_hidden_states = torch.zeros_like(hidden_states_flat)
+         expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+         for expert_idx in range(self.num_experts):
+             expert = self.experts[expert_idx]
+             for top_k_idx in range(self.num_experts_per_tok):
+                 token_indices = expert_mask[expert_idx, top_k_idx].nonzero(as_tuple=True)[0]
+                 if token_indices.numel() > 0:
+                     expert_input = hidden_states_flat[token_indices]
+                     expert_output = expert(expert_input)
+                     weights = router_weights[token_indices, top_k_idx].unsqueeze(-1)
+                     final_hidden_states[token_indices] += weights * expert_output
+
+         final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
+
+         num_tokens = router_probs.shape[0]
+         expert_mask_float = F.one_hot(selected_experts, num_classes=self.num_experts).float()
+         tokens_per_expert = expert_mask_float.sum(dim=(0, 1)) / num_tokens
+         router_prob_per_expert = router_probs.mean(dim=0)
+         aux_loss = self.num_experts * (tokens_per_expert * router_prob_per_expert).sum() * self.router_aux_loss_coef
+
+         return final_hidden_states, aux_loss
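The routing above can be sanity-checked on random activations; a sketch with a deliberately small configuration (again assuming `model_files` is on `sys.path`):

```python
import torch
from configs.model_config import Max2Config
from model.components import Max2MoE

cfg = Max2Config(hidden_size=64, num_experts=4, num_experts_per_tok=2, expert_hidden_size=128)
moe = Max2MoE(cfg)

x = torch.randn(2, 10, 64)
out, aux_loss = moe(x)
print(out.shape)        # torch.Size([2, 10, 64]) -- same shape as the input
print(aux_loss.item())  # small load-balancing penalty added to the training loss
```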
+
+
+ class Max2DecoderLayer(nn.Module):
+     """Single transformer decoder layer with GQA attention and MoE FFN."""
+
+     def __init__(self, config: Max2Config, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = Max2Attention(config, layer_idx)
+
+         if config.use_moe:
+             self.mlp = Max2MoE(config)
+             self.use_moe = True
+         else:
+             self.mlp = Max2MLP(config.hidden_size, config.intermediate_size)
+             self.use_moe = False
+
+         self.input_layernorm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states, present_key_value = self.self_attn(hidden_states, attention_mask, past_key_value, use_cache)
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+
+         if self.use_moe:
+             hidden_states, aux_loss = self.mlp(hidden_states)
+         else:
+             hidden_states = self.mlp(hidden_states)
+             aux_loss = torch.tensor(0.0, device=hidden_states.device)
+
+         hidden_states = residual + hidden_states
+
+         return hidden_states, present_key_value, aux_loss
+
+
+ # Backward compatibility aliases
+ Mind2RMSNorm = Max2RMSNorm
+ Mind2RotaryEmbedding = Max2RotaryEmbedding
+ Mind2Attention = Max2Attention
+ Mind2MLP = Max2MLP
+ Mind2Expert = Max2Expert
+ Mind2MoE = Max2MoE
+ Mind2DecoderLayer = Max2DecoderLayer
model_files/model/mind2_model.py ADDED
@@ -0,0 +1,185 @@
+ """
+ MiniMind Max2 Main Model
+ Complete implementation of the Max2 language model.
+ """
+
+ from typing import List, Optional, Tuple
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import CrossEntropyLoss
+
+ import sys
+ from pathlib import Path
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+ from configs.model_config import Max2Config, get_config
+ from .components import Max2DecoderLayer, Max2RMSNorm
+
+
+ class Max2Model(nn.Module):
+     """Max2 Transformer Model - outputs raw hidden states."""
+
+     def __init__(self, config: Max2Config):
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
+         self.layers = nn.ModuleList([Max2DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
+         self.norm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.gradient_checkpointing = False
+         self._init_weights()
+
+     def _init_weights(self):
+         for module in self.modules():
+             if isinstance(module, nn.Linear):
+                 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+                 if module.bias is not None:
+                     module.bias.data.zero_()
+             elif isinstance(module, nn.Embedding):
+                 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+     def _make_causal_mask(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
+         mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+         mask = torch.triu(mask, diagonal=1)
+         return mask.unsqueeze(0).unsqueeze(0)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]], torch.Tensor]:
+         batch_size, seq_len = input_ids.shape
+         hidden_states = self.embed_tokens(input_ids)
+
+         causal_mask = self._make_causal_mask(seq_len, hidden_states.dtype, hidden_states.device)
+         if attention_mask is not None:
+             # Use the dtype minimum rather than -inf here: (1 - mask) * -inf yields NaN
+             # for positions that should remain visible (0 * inf is NaN).
+             padding_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * torch.finfo(hidden_states.dtype).min
+             causal_mask = causal_mask + padding_mask
+
+         next_cache = [] if use_cache else None
+         total_aux_loss = torch.tensor(0.0, device=hidden_states.device)
+
+         for idx, layer in enumerate(self.layers):
+             past_kv = past_key_values[idx] if past_key_values else None
+             hidden_states, present_kv, aux_loss = layer(hidden_states, causal_mask, past_kv, use_cache)
+
+             if use_cache:
+                 next_cache.append(present_kv)
+             total_aux_loss = total_aux_loss + aux_loss
+
+         hidden_states = self.norm(hidden_states)
+         return hidden_states, next_cache, total_aux_loss
+
+
+ class Max2ForCausalLM(nn.Module):
+     """Max2 Model with Language Modeling head for text generation."""
+
+     def __init__(self, config: Max2Config):
+         super().__init__()
+         self.config = config
+         self.model = Max2Model(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.lm_head.weight = self.model.embed_tokens.weight
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[List], torch.Tensor]:
+         hidden_states, next_cache, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache)
+         logits = self.lm_head(hidden_states).float()
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = CrossEntropyLoss()(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+             loss = loss + aux_loss
+
+         return loss, logits, next_cache, aux_loss
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.LongTensor,
+         max_new_tokens: int = 100,
+         temperature: float = 1.0,
+         top_k: int = 50,
+         top_p: float = 0.95,
+         do_sample: bool = True,
+     ) -> torch.LongTensor:
+         """Simple generation with top-k/top-p sampling."""
+         generated = input_ids
+         past_key_values = None
+
+         for _ in range(max_new_tokens):
+             if past_key_values is None:
+                 _, logits, past_key_values, _ = self(generated, use_cache=True)
+             else:
+                 _, logits, past_key_values, _ = self(generated[:, -1:], past_key_values=past_key_values, use_cache=True)
+
+             next_token_logits = logits[:, -1, :] / temperature
+
+             if do_sample:
+                 if top_k > 0:
+                     indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                     next_token_logits[indices_to_remove] = float('-inf')
+
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                     sorted_indices_to_remove[..., 0] = 0
+                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                     next_token_logits[indices_to_remove] = float('-inf')
+
+                 probs = F.softmax(next_token_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+             else:
+                 next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+             generated = torch.cat([generated, next_token], dim=1)
+
+             if (next_token == self.config.eos_token_id).all():
+                 break
+
+         return generated
+
+
+ # Backward compatibility aliases
+ Mind2Model = Max2Model
+ Mind2ForCausalLM = Max2ForCausalLM
+
+
+ def create_model(model_name: str = "max2-lite", device: str = "cuda", dtype: torch.dtype = torch.float16) -> Max2ForCausalLM:
+     """Factory function to create a Max2 model."""
+     config = get_config(model_name)
+     model = Max2ForCausalLM(config)
+     return model.to(device=device, dtype=dtype) if torch.cuda.is_available() else model
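End-to-end generation through the factory above; a sketch on CPU with random (untrained) weights, so the token ids are meaningless but the shapes and sampling path are exercised (assumes `model_files` is on `sys.path`):

```python
import torch
from model import create_model

model = create_model("max2-nano", device="cpu", dtype=torch.float32).eval()

prompt_ids = torch.randint(0, model.config.vocab_size, (1, 8))
out = model.generate(prompt_ids, max_new_tokens=16, temperature=0.8, top_k=50, top_p=0.9)
print(out.shape)  # prompt plus up to 16 sampled tokens, e.g. torch.Size([1, 24])
```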
+
+
+ if __name__ == "__main__":
+     for model_name in ["max2-nano", "max2-lite", "max2-pro"]:
+         print(f"\n{'='*50}\nTesting {model_name}\n{'='*50}")
+         config = get_config(model_name)
+         model = Max2ForCausalLM(config)
+
+         total_params = sum(p.numel() for p in model.parameters())
+         print(f"Total Parameters: {total_params / 1e9:.3f}B")
+
+         input_ids = torch.randint(0, config.vocab_size, (2, 128))
+         model.eval()
+         with torch.no_grad():
+             loss, logits, _, aux_loss = model(input_ids, labels=input_ids)
+         print(f"Logits shape: {logits.shape}")
+         print(f"Loss: {loss:.4f}, Aux loss: {aux_loss:.6f}")
+         print("Forward pass successful!")
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch>=2.0.0
+ gradio>=4.0.0