import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# -------------------------------
# ENVIRONMENT SETTINGS
# -------------------------------
# Disable bitsandbytes if no GPU (CPU-only Spaces)
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.environ["DISABLE_BITSANDBYTES"] = "1"
# Hugging Face token login via environment variable (required for gated MedGemma weights)
from huggingface_hub import login
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# -------------------------------
# MODEL CONFIG
# -------------------------------
MODEL_NAME = "google/medgemma-4b-it" # or lighter if CPU only
# Auto-detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Load model safely
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,                      # full precision is safer on CPU
    device_map="auto" if device == "cuda" else None,
)
# -------------------------------
# SAFE GENERATION FUNCTION
# -------------------------------
def medgemma_generate(prompt):
    if not prompt.strip():
        return "Please enter a prompt."
    # Tokenize and move tensors to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    try:
        outputs = model.generate(
            **inputs,                                 # pass input_ids and attention_mask
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            # fall back to EOS if the tokenizer has no dedicated pad token
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return text
    except RuntimeError as e:
        return f"Generation failed: {e}"
# -------------------------------
# GRADIO INTERFACE
# -------------------------------
demo = gr.Interface(
    fn=medgemma_generate,
    inputs=gr.Textbox(
        lines=4,
        placeholder="Enter your medical prompt...",
        label="Prompt",
    ),
    outputs=gr.Textbox(
        lines=4,
        max_lines=100,
        interactive=False,
        label="Generated Answer",
        show_copy_button=True,
        elem_classes="scroll-textbox",
    ),
    title="MedGemma Q&A",
    description="Ask medical questions (English). Safe generation config prevents NaNs on CPU.",
    css="""
    .scroll-textbox textarea {
        overflow-y: auto !important;
        max-height: 600px !important;  /* force scroll after ~100 lines */
        resize: vertical !important;   /* allow manual resizing */
    }
    """,
)
# Launch the app
if __name__ == "__main__":
    demo.launch()
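# -------------------------------
# DEPENDENCY NOTES
# -------------------------------
# A sketch of the requirements.txt this Space appears to need, inferred
# from the imports above (exact pins are an assumption):
#   gradio
#   torch
#   transformers
#   huggingface_hub
#   accelerate  # needed by transformers when device_map="auto" is used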