"""Gradio demo that serves free-text prompts to a MedGemma causal LM."""

import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM

# -------------------------------
# ENVIRONMENT SETTINGS
# -------------------------------
# Disable bitsandbytes if no GPU (CPU-only Spaces)
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.environ["DISABLE_BITSANDBYTES"] = "1"

# Hugging Face token login via environment variable.
# Only log in when a token is actually configured: passing an empty string
# to `login` is not a valid credential and can abort startup.
hf_token = os.environ.get("HF_TOKEN", "")
if hf_token:
    login(hf_token)

# -------------------------------
# MODEL CONFIG
# -------------------------------
MODEL_NAME = "google/medgemma-4b-it"  # or lighter if CPU only

# Auto-detect device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load model safely
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # safer on CPU
    device_map="auto" if device == "cuda" else None,
)


# -------------------------------
# SAFE GENERATION FUNCTION
# -------------------------------
def medgemma_generate(prompt):
    """Generate a sampled completion for *prompt* with the loaded model.

    Args:
        prompt: Text entered by the user. ``None`` or whitespace-only input
            is treated as empty.

    Returns:
        The decoded output sequence (special tokens stripped), a fixed hint
        string when the prompt is empty, or a ``"Generation failed: ..."``
        message when generation raises ``RuntimeError``.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    # Put the inputs on the device the model actually lives on; with
    # device_map="auto" this is more reliable than the bare "cuda" string.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Some tokenizers define no pad token; fall back to EOS so that
    # generate() always receives a valid padding id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    try:
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            # Pass the mask explicitly instead of letting generate() guess it.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )
    except RuntimeError as e:
        return f"Generation failed: {str(e)}"

    # outputs[0] is the full first sequence returned by generate().
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# -------------------------------
# GRADIO INTERFACE
# -------------------------------
demo = gr.Interface(
    fn=medgemma_generate,
    inputs=gr.Textbox(
        lines=4,
        placeholder="Enter your medical prompt...",
        label="Prompt",
    ),
    outputs=gr.Textbox(
        lines=4,
        max_lines=100,
        interactive=False,
        label="Generated Answer",
        show_copy_button=True,
        elem_classes="scroll-textbox",
    ),
    title="MedGemma Q&A",
    description=(
        "Ask medical questions (English). "
        "Safe generation config prevents NaNs on CPU."
    ),
    css="""
    .scroll-textbox textarea {
        overflow-y: auto !important;
        max-height: 600px !important; /* force scroll after ~100 lines */
        resize: vertical !important; /* allow manual resizing */
    }
    """,
)

# Launch the app
if __name__ == "__main__":
    demo.launch()