import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# -------------------------------
# ENVIRONMENT SETTINGS
# -------------------------------
# Suppress bitsandbytes warnings when no GPU is available (CPU-only Spaces)
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.environ["DISABLE_BITSANDBYTES"] = "1"

# Hugging Face token login via environment variable (required for gated models)
from huggingface_hub import login
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(hf_token)

# -------------------------------
# MODEL CONFIG
# -------------------------------
MODEL_NAME = "google/medgemma-4b-it"  # swap in a lighter model if running CPU-only

# Auto-detect device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load model safely
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # safer on CPU
    device_map="auto" if device == "cuda" else None
)

# -------------------------------
# SAFE GENERATION FUNCTION
# -------------------------------
def medgemma_generate(prompt):
    if not prompt.strip():
        return "Please enter a prompt."
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    try:
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],  # pass the mask explicitly to avoid attention warnings
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )
        # Decode only the newly generated tokens so the prompt is not echoed back
        prompt_length = inputs["input_ids"].shape[-1]
        text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return text
    except RuntimeError as e:
        return f"Generation failed: {str(e)}"

# -------------------------------
# GRADIO INTERFACE
# -------------------------------
demo = gr.Interface(
    fn=medgemma_generate,
    inputs=gr.Textbox(
        lines=4,
        placeholder="Enter your medical prompt...",
        label="Prompt"
    ),
    outputs=gr.Textbox(
        lines=4,
        max_lines=100,
        interactive=False,
        label="Generated Answer",
        show_copy_button=True,
        elem_classes="scroll-textbox"
    ),
    title="MedGemma Q&A",
    description="Ask medical questions (English). Safe generation config prevents NaNs on CPU.",
    css="""
    .scroll-textbox textarea {
        overflow-y: auto !important;
        max-height: 600px !important;  /* force scroll after ~100 lines */
        resize: vertical !important;   /* allow manual resizing */
    }
    """
)


# Launch the app
if __name__ == "__main__":
    demo.launch()