# Doctor / app.py
# victorsconcious's picture
# Update app.py
# 5d30d79 verified
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# -------------------------------
# ENVIRONMENT SETTINGS
# -------------------------------
# Disable bitsandbytes if no GPU (CPU-only Spaces).
# NOTE(review): confirm which library actually reads DISABLE_BITSANDBYTES.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.environ["DISABLE_BITSANDBYTES"] = "1"

# Hugging Face token login via environment variable
from huggingface_hub import login

# Only log in when a token is actually configured. An empty string is not a
# usable token, so handing it to login() can abort startup with a confusing
# authentication error instead of a clear "token missing" message.
_hf_token = os.environ.get("HF_TOKEN", "").strip()
if _hf_token:
    login(_hf_token)
else:
    print(
        "HF_TOKEN is not set; skipping Hugging Face login. "
        "Gated or private models will fail to download."
    )
# -------------------------------
# MODEL CONFIG
# -------------------------------
# Checkpoint to serve; pick a lighter one if running on CPU only.
MODEL_NAME = "google/medgemma-4b-it"

# Use CUDA when torch reports it as available, otherwise stay on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer matching the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Weights are loaded in float32 (the safer choice on CPU). Placement is
# delegated to device_map="auto" only on CUDA; on CPU device_map stays None.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=None if device != "cuda" else "auto",
    torch_dtype=torch.float32,
)
# -------------------------------
# SAFE GENERATION FUNCTION
# -------------------------------
def medgemma_generate(prompt):
    """Generate a sampled completion for ``prompt`` with the loaded model.

    Args:
        prompt: Text entered by the user. ``None``, empty, or
            whitespace-only input is rejected without touching the model.

    Returns:
        The decoded output sequence with special tokens removed, the notice
        "Please enter a prompt." for blank input, or a
        "Generation failed: ..." message when generation raises
        ``RuntimeError``.
    """
    # Reject None as well as blank text so a cleared input never triggers an
    # AttributeError from calling .strip() on a missing value.
    if prompt is None or not prompt.strip():
        return "Please enter a prompt."

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # A tokenizer may define no pad token; fall back to EOS so generate()
    # always receives a concrete pad_token_id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    try:
        outputs = model.generate(
            inputs["input_ids"],
            # Forward the tokenizer's mask instead of letting generate()
            # guess it from the pad token.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )
        # NOTE(review): the whole first sequence is decoded; for decoder-only
        # models this presumably echoes the prompt before the answer —
        # confirm whether that is the desired output.
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except RuntimeError as e:
        return f"Generation failed: {e}"
# -------------------------------
# GRADIO INTERFACE
# -------------------------------
# Multi-line box where the user types the question.
_prompt_box = gr.Textbox(
    lines=4,
    placeholder="Enter your medical prompt...",
    label="Prompt",
)

# Read-only box for the generated text, with a copy button and a CSS class
# used by the stylesheet below.
_answer_box = gr.Textbox(
    lines=4,
    max_lines=100,
    interactive=False,
    label="Generated Answer",
    show_copy_button=True,
    elem_classes="scroll-textbox",
)

# Stylesheet for the answer box: vertical scrolling, capped height, and
# manual vertical resizing.
_SCROLL_CSS = """
.scroll-textbox textarea {
overflow-y: auto !important;
max-height: 600px !important; /* force scroll after ~100 lines */
resize: vertical !important; /* allow manual resizing */
}
"""

# Wire the generation function to the two text boxes.
demo = gr.Interface(
    fn=medgemma_generate,
    inputs=_prompt_box,
    outputs=_answer_box,
    title="MedGemma Q&A",
    description="Ask medical questions (English). Safe generation config prevents NaNs on CPU.",
    css=_SCROLL_CSS,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()