victorsconcious committed on
Commit b8032fe · verified · 1 Parent(s): 8fc3218

Update app.py

Files changed (1)
  1. app.py +34 -35
app.py CHANGED
@@ -1,50 +1,49 @@
- import os
- import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
- from huggingface_hub import login
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
+ import gradio as gr

  import os
- os.environ["BITSANDBYTES_NOWELCOME"] = "1"
- os.environ["DISABLE_BITSANDBYTES"] = "1"
-
+ from huggingface_hub import login

- # --- Authenticate with HF token (from Spaces Secrets) ---
- login(os.environ["HF_TOKEN"])
+ login(os.environ["HF_TOKEN"])  # use the token with gated repo access

- # --- Model setup ---
- MODEL_ID = "google/medgemma-4b-it"
+ MODEL_ID = "google/med-gemma-2b"

- # 4-bit quantization config
- bnb_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_use_double_quant=True,
-     bnb_4bit_quant_type="nf4",
-     bnb_4bit_compute_dtype=torch.bfloat16
- )
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

- # Load model + tokenizer with quantization
+ # Load model with 4-bit quantization (works on CPU)
  model = AutoModelForCausalLM.from_pretrained(
      MODEL_ID,
-     quantization_config=bnb_config,
-     device_map="auto"
+     device_map="cpu",
+     torch_dtype=torch.float32,  # stay safe from NaN in CPU mode
+     load_in_4bit=True  # quantize
  )
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+ # Wrap in a pipeline
+ pipe = TextGenerationPipeline(model=model, tokenizer=tokenizer, device=-1)

- # --- Gradio app ---
+ # Safe generation function
  def medgemma_chat(prompt):
-     outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
-     return outputs[0]["generated_text"]
-
- demo = gr.Interface(
-     fn=medgemma_chat,
-     inputs=gr.Textbox(label="Enter medical question", lines=4, placeholder="e.g. What are symptoms of malaria?"),
-     outputs=gr.Textbox(label="MedGemma Response"),
-     title="🧠 MedGemma (4-bit Quantized)",
-     description="Ask medical questions (research/demo use only). Running in 4-bit quantized mode for efficiency."
- )
+     try:
+         output = pipe(
+             prompt,
+             max_new_tokens=200,
+             temperature=1.0,  # stable
+             top_p=0.9,
+             do_sample=True
+         )
+         return output[0]["generated_text"]
+     except Exception as e:
+         return f"⚠️ Error: {str(e)}"
+
+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🩺 MedGemma (Quantized, CPU-safe)")
+     inp = gr.Textbox(label="Enter patient info", placeholder="Example: Patient has fever and cough...")
+     out = gr.Textbox(label="Model Output")
+     btn = gr.Button("Generate")
+     btn.click(medgemma_chat, inp, out)

  if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860)
+     demo.launch()
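
For reference, a minimal way to smoke-test the generation path the updated app.py builds, outside of Gradio. This is a sketch, not part of the commit: it assumes HF_TOKEN is set in the environment with access to the model repo named by MODEL_ID, and it omits the load_in_4bit flag, so the model loads in full precision on CPU.

import os

from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline

# Log in with the same gated-repo token the Space expects (assumes HF_TOKEN is set).
login(os.environ["HF_TOKEN"])

MODEL_ID = "google/med-gemma-2b"  # model id as written in the commit

# Build the same tokenizer -> model -> pipeline chain as app.py, but without the
# 4-bit flag, so this is a plain full-precision CPU load.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
pipe = TextGenerationPipeline(model=model, tokenizer=tokenizer, device=-1)  # device=-1 -> CPU

# Same sampling settings the Space's medgemma_chat() uses.
result = pipe(
    "Patient has fever and cough...",
    max_new_tokens=200,
    do_sample=True,
    temperature=1.0,
    top_p=0.9,
)
print(result[0]["generated_text"])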