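"""FastAPI service exposing the MobileLLM-Pro model.

Endpoints:
  * GET  /gen       - generation driven by query-string parameters
  * POST /generate  - generation driven by a JSON request body
"""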
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = FastAPI(title="MobileLLM-Pro API", description="Public API for MobileLLM-Pro")

# --- Load model & tokenizer ---
MODEL_PATH = "/app/model"
print("🧠 Loading tokenizer and model...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map=None,
)
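# device_map=None disables automatic device placement; the model is moved
# explicitly to the GPU (when available) below.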
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# Set pad_token if missing
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

SYSTEM_PROMPT = (
    "You are an expert AI assistant. Provide clear, accurate, and concise answers to the user's questions. "
    "Do not add extra commentary, disclaimers, or summaries unless asked. Answer directly."
)

@app.get("/")
def root():
    return {"message": "MobileLLM-Pro API is running!"}

@app.get("/gen")
def generate(prompt: str, max_tokens: int = 256):
    try:
        # Build messages with system instruction
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt}
        ]

        # Apply the chat template; with tokenize=True and return_tensors="pt"
        # this returns a tensor of input IDs ready for generation.
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
            tokenize=True  # explicit
        ).to(device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the new part
        input_len = inputs.shape[1]
        generated_tokens = outputs[0][input_len:]
        result = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        return {"input": prompt, "output": result.strip()}

    except Exception as e:
        return {"error": str(e)}

@app.post("/generate")
async def generate_post(request: Request):
    try:
        # Read JSON body from request
        data = await request.json()
        prompt = data.get("prompt", "")
        max_tokens = data.get("max_tokens", 256)

        # Build messages with system instruction
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt}
        ]

        # Apply chat template
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
            tokenize=True
        ).to(device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the new part
        input_len = inputs.shape[1]
        generated_tokens = outputs[0][input_len:]
        result = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        return {"input": prompt, "output": result.strip()}

    except Exception as e:
        return {"error": str(e)}