# medfusion-space / app.py
# Author: fokan
# Last commit: "Update app.py" (bf46f16, verified)
import gradio as gr
from transformers import AutoProcessor, AutoModel, AutoTokenizer, pipeline
from PIL import Image
import torch
# ---------------------------
# Load encoder (MedSigLIP)
# ---------------------------
# Hub checkpoint for the image encoder; both the processor (pre-processing)
# and the model weights are fetched from this repository.
ENCODER_ID = "fokan/medsiglip-448-int8"
encoder_processor = AutoProcessor.from_pretrained(ENCODER_ID)
encoder_model = AutoModel.from_pretrained(ENCODER_ID)
# Inference-only usage: switch modules such as dropout to evaluation mode.
encoder_model.eval()
# ---------------------------
# Load decoder (MedGemma)
# ---------------------------
# Hub checkpoint for the text decoder, wrapped in a text-generation pipeline.
DECODER_ID = "fokan/medgemma-4b-it-int8"
# Half precision only when a CUDA device exists; otherwise stay in fp32.
# NOTE(review): no `device` is passed to `pipeline()`, so unless this int8
# checkpoint is loaded with an accelerate device map (e.g. bitsandbytes), the
# model presumably stays on CPU even when CUDA is available -- i.e. fp16
# weights on CPU. Consider `device=0` on CUDA once the checkpoint's
# quantization format is confirmed (dispatched models cannot take `device`).
_decoder_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
decoder = pipeline(
    "text-generation",
    model=DECODER_ID,
    torch_dtype=_decoder_dtype,
)
# ---------------------------
# Core Function
# ---------------------------
@torch.no_grad()
def analyze_xray(image):
    """Embed an X-ray with the encoder, then ask the decoder for a short report.

    Args:
        image: The uploaded image (a PIL image when called from the
            ``gr.Image(type="pil")`` input), or ``None`` when the form is
            submitted without an upload.

    Returns:
        A Markdown string containing a 5-value preview of the L2-normalised
        image embedding followed by the generated report text, or a short
        warning message when no image was provided.
    """
    if image is None:
        # Gradio passes None when "Submit" is pressed without an upload; the
        # processor would otherwise fail with an opaque error.
        return "⚠️ Please upload an X-ray image first."

    # Step 1: Encode the image into an embedding.
    # Image processors typically normalise with 3-channel (RGB) mean/std, so
    # coerce grayscale / RGBA uploads to RGB to keep the channel count right.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    inputs = encoder_processor(images=image, text=["chest x-ray"], return_tensors="pt", padding=True)
    outputs = encoder_model(**inputs)
    if hasattr(outputs, "image_embeds"):
        # Dedicated image-embedding output: take the first (only) batch item.
        embedding = outputs.image_embeds[0]
    elif hasattr(outputs, "last_hidden_state"):
        # Mean-pool over dim 1 (assumed to be the sequence axis), first item.
        embedding = outputs.last_hidden_state.mean(dim=1)[0]
    else:
        # Fallback: pool the first tensor the model returned in the same way.
        embedding = list(outputs.values())[0].mean(dim=1)[0]
    # L2-normalise; the clamp avoids dividing by zero for an all-zero vector.
    embedding = embedding / embedding.norm().clamp_min(1e-12)

    # Step 2: Generate a short diagnostic report from a text dump of the
    # first 256 embedding values.
    prompt = (
        "You are a radiologist. Analyze this chest X-ray embedding vector and describe any possible findings, "
        "anomalies, or impressions as a short professional report.\n"
        f"<embedding>{embedding[:256].tolist()}</embedding>"
    )
    generation = decoder(
        prompt,
        max_new_tokens=180,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        # Without this the pipeline prepends the whole prompt (including the
        # 256-value embedding dump) to the completion shown to the user.
        return_full_text=False,
    )
    report = generation[0]["generated_text"].strip()

    # Return both an embedding preview and the generated text.
    preview = embedding[:5].tolist()
    return f"✅ Embedding (preview): {preview}\n\n🩺 **AI Radiology Report:**\n{report}"
# ---------------------------
# Gradio UI
# ---------------------------
title = "🩻 MedSigLIP → MedGemma Fusion"
# Interpolate the checkpoint IDs that are actually loaded so the UI text can
# never advertise a different model than the one running.
desc = f"""
Upload an **X-ray image**, and this demo will:
1. Extract its visual embedding using `{ENCODER_ID}`.
2. Generate a **radiology-style report** using `{DECODER_ID}`.

⚠️ *Research demo only — the generated text is not a medical diagnosis.*
"""
# Single-input / single-output interface: a PIL image in, Markdown text out.
demo = gr.Interface(
    fn=analyze_xray,
    inputs=gr.Image(type="pil", label="Upload X-ray Image"),
    outputs=gr.Markdown(label="AI Report"),
    title=title,
    description=desc,
    theme="gradio/soft",
)
def _main() -> None:
    """Start the Gradio server for the demo defined above."""
    demo.launch()


if __name__ == "__main__":
    _main()