import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch
model_id = "google/medgemma-4b-it"
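# Note: medgemma-4b-it is a gated checkpoint on the Hugging Face Hub; the
# license must be accepted and the runtime authenticated (e.g. via an HF_TOKEN
# secret) before from_pretrained() can download the weights.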
# Load model and processor
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # requires the 'accelerate' package
)
processor = AutoProcessor.from_pretrained(model_id)
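
# Assumption: bfloat16 with device_map="auto" presumes an accelerator with
# bf16 support (e.g. an Ampere-or-newer GPU); on CPU-only hardware, loading
# in float32 is the safer fallback.
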
def generate_report(image, clinical_info):
    if image is None:
        return "Please upload a medical image."

    # Build the chat-style message list: optional clinical context first,
    # then the instruction, then the image itself
    user_content = []
    if clinical_info:
        user_content.append({"type": "text", "text": f"Patient info: {clinical_info}"})
    user_content.append({"type": "text", "text": "Please describe the medical image in a radiology report style."})
    user_content.append({"type": "image", "image": image})

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are an expert radiologist."}]},
        {"role": "user", "content": user_content},
    ]

    # Apply the chat template and tokenize; .to() casts only floating-point
    # tensors to bfloat16, leaving the integer input_ids untouched
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=512, do_sample=True, top_p=0.9, top_k=50)

    # Decode only the newly generated tokens, skipping the prompt
    generated_ids = output[0]
    decoded = processor.decode(generated_ids[input_len:], skip_special_tokens=True)
    return decoded.strip()
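
# Quick smoke test without the UI (a sketch; "example_xray.png" is a
# hypothetical local file, not shipped with this Space):
#
#   img = Image.open("example_xray.png")
#   print(generate_report(img, "65-year-old male with persistent cough"))
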
# Gradio interface
gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Image(type="pil", label="Upload Medical Image (X-ray, etc.)"),
        gr.Textbox(lines=2, placeholder="e.g. Prior diagnosis: pneumonia. 65-year-old male with cough...", label="Optional Clinical Info"),
    ],
    outputs=gr.Textbox(label="Generated Radiology Report"),
    title="🧠 MedGemma Radiology Report Generator",
    description="Upload a medical image and optionally include clinical info (like prior findings or diagnosis). Powered by Google's MedGemma-4B model.",
    allow_flagging="never",  # renamed to flagging_mode in Gradio 5+
).launch()
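
# Note: on Hugging Face Spaces, launch() picks up the expected host and port
# automatically; when running locally, launch(share=True) is one way to get a
# temporary public URL.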