Rohit-Katkar2003 committed
Commit 0c3b68e · verified · 1 Parent(s): 65d484a

Update app.py

Files changed (1)
  1. app.py +45 -1
app.py CHANGED
@@ -34,7 +34,7 @@ SYSTEM_PROMPT = (
 def root():
     return {"message": "MobileLLM-Pro API is running!"}
 
-@app.get("/generate")
+@app.get("/gen")
 def generate(prompt: str, max_tokens: int = 256):
     try:
         # Build messages with system instruction
@@ -70,5 +70,49 @@ def generate(prompt: str, max_tokens: int = 256):
 
         return {"input": prompt, "output": result.strip()}
 
+    except Exception as e:
+        return {"error": str(e)}
+
+@app.post("/generate")
+async def generate(request: Request):
+    try:
+        # Read JSON body from request
+        data = await request.json()
+        prompt = data.get("prompt", "")
+        max_tokens = data.get("max_tokens", 256)
+
+        # Build messages with system instruction
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": prompt}
+        ]
+
+        # Apply chat template
+        inputs = tokenizer.apply_chat_template(
+            messages,
+            return_tensors="pt",
+            add_generation_prompt=True,
+            tokenize=True
+        ).to(device)
+
+        # Generate
+        with torch.no_grad():
+            outputs = model.generate(
+                input_ids=inputs,
+                max_new_tokens=max_tokens,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.95,
+                pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+
+        # Decode only the new part
+        input_len = inputs.shape[1]
+        generated_tokens = outputs[0][input_len:]
+        result = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+        return {"input": prompt, "output": result.strip()}
+
     except Exception as e:
         return {"error": str(e)}