victorsconcious committed on
Commit 92199ae · verified · 1 Parent(s): 863e2d7

Update app.py

Files changed (1): app.py +16 -6
app.py CHANGED
@@ -10,7 +10,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 os.environ["BITSANDBYTES_NOWELCOME"] = "1"
 os.environ["DISABLE_BITSANDBYTES"] = "1"
 
-# Hugging Face token login via env variable
+# Hugging Face token login via environment variable
 from huggingface_hub import login
 login(os.environ.get("HF_TOKEN", ""))
 
@@ -19,7 +19,7 @@ login(os.environ.get("HF_TOKEN", ""))
 # -------------------------------
 MODEL_NAME = "google/medgemma-4b-it"  # or lighter if CPU only
 
-# auto-detect device
+# Auto-detect device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load tokenizer
@@ -28,8 +28,8 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 # Load model safely
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
-    torch_dtype=torch.float32,  # safer than 4-bit on CPU
-    device_map="auto" if device=="cuda" else None
+    torch_dtype=torch.float32,  # safer on CPU
+    device_map="auto" if device == "cuda" else None
 )
 
 # -------------------------------
@@ -59,11 +59,21 @@ def medgemma_generate(prompt):
 # -------------------------------
 demo = gr.Interface(
     fn=medgemma_generate,
-    inputs=gr.Textbox(lines=4, placeholder="Enter your medical prompt..."),
-    outputs="text",
+    inputs=gr.Textbox(
+        lines=4,
+        placeholder="Enter your medical prompt...",
+        label="Prompt"
+    ),
+    outputs=gr.Textbox(
+        lines=15,        # start with 15 lines
+        max_lines=100,   # auto-expand up to 100 lines
+        interactive=False,
+        label="Generated Answer"
+    ),
     title="MedGemma Q&A",
     description="Ask medical questions (English). Safe generation config prevents NaNs on CPU."
 )
 
+# Launch the app
 if __name__ == "__main__":
     demo.launch()
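
Note: the hunks above wire fn=medgemma_generate into the interface, but the function body sits outside the changed ranges, as does the "safe generation config" the description credits with preventing NaNs on CPU. The following is a minimal sketch of how that function could look against the committed load code; the generate() settings (the max_new_tokens cap and greedy do_sample=False) are assumptions for illustration, not part of this commit.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "google/medgemma-4b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load exactly as in the committed app.py
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # full precision, matching the diff's CPU-safety choice
    device_map="auto" if device == "cuda" else None
)

def medgemma_generate(prompt):
    # Hypothetical body; the commit only references this function by name.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,  # assumed cap to keep CPU latency bounded
            do_sample=False,     # assumed greedy decoding for deterministic output
        )
    # Return only the newly generated tokens, not the echoed prompt
    answer_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(answer_ids, skip_special_tokens=True)

Reduced-precision weights are one common source of NaN logits during CPU generation; deterministic greedy decoding over float32 weights sidesteps that, which is consistent with the torch_dtype=torch.float32 choice in this diff.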