# image-analysis / app.py
# Author: Matt Grannell
# Last change: Fix HF auth — use login() to authenticate globally from the
# HF_TOKEN secret (commit e9f8bf5).
import sys
import os
import json
import base64
import io
import torch
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from PIL import Image
from huggingface_hub import snapshot_download, login
from transformers import AutoProcessor, AutoModelForImageTextToText
# --- Hugging Face authentication -------------------------------------------
# Authenticate once at startup so the gated model downloads below succeed.
# login() stores the token globally for huggingface_hub and transformers.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Authenticated with HF token.", flush=True)
else:
    # Best-effort warning only: public models still work without a token.
    print("WARNING: HF_TOKEN not set — gated models will fail.", flush=True)
# ---------------------------------------------------------------------------
# MedImageInsight — CLIP-style encoder for zero-shot label scoring
# ---------------------------------------------------------------------------
# The lion-ai/MedImageInsights repo ships both the checkpoints and the Python
# module that wraps them, so we download the whole snapshot and import the
# model class directly out of it.
print("Downloading MedImageInsights repo...", flush=True)
repo_path = snapshot_download("lion-ai/MedImageInsights")
print(f"Downloaded to: {repo_path}", flush=True)
# Make the snapshot importable; the import below only resolves after this.
sys.path.insert(0, repo_path)
from medimageinsightmodel import MedImageInsight  # noqa: E402  (deliberate mid-file import)
# "2024.09.27" is the versioned weights directory inside the snapshot.
model_dir = os.path.join(repo_path, "2024.09.27")
print("Loading MedImageInsight...", flush=True)
classifier = MedImageInsight(
model_dir=model_dir,
# NOTE: "medimageinsigt" (missing 'h') is the actual filename in the
# upstream repo — do not "fix" the spelling.
vision_model_name="medimageinsigt-v1.0.0.pt",
language_model_name="language_model.pth",
)
classifier.load_model()
print("MedImageInsight ready.", flush=True)
# ---------------------------------------------------------------------------
# MedGemma — generative VLM for free-text image description
# ---------------------------------------------------------------------------
# Gated model: relies on the login() call above having succeeded.
MEDGEMMA_ID = "google/medgemma-1.5-4b-it"
print("Loading MedGemma processor...", flush=True)
gemma_processor = AutoProcessor.from_pretrained(MEDGEMMA_ID)
print("Loading MedGemma model (bfloat16)...", flush=True)
gemma_model = AutoModelForImageTextToText.from_pretrained(
MEDGEMMA_ID,
# bfloat16 halves memory vs fp32; device_map="auto" places layers on the
# available accelerator(s), falling back to CPU.
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Inference only — disables dropout etc.
gemma_model.eval()
print("MedGemma ready.", flush=True)
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Medical Image Analysis API")
# NOTE(review): CORS is wide open (any origin/method/header). Fine for a demo
# Space; tighten allow_origins before exposing anything sensitive.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
def _encode_image(data: bytes) -> str:
    """Normalize uploaded image bytes into a base64-encoded PNG string.

    MedImageInsight's predict() consumes base64 PNG payloads, so whatever
    format the client uploaded (JPEG, BMP, ...) is re-encoded through PIL.
    """
    picture = Image.open(io.BytesIO(data)).convert("RGB")
    buffer = io.BytesIO()
    picture.save(buffer, format="PNG")
    # encodebytes (not b64encode) to match the original wire format, which
    # inserts newlines every 76 characters.
    return base64.encodebytes(buffer.getvalue()).decode("utf-8")
def _scores_to_list(scores: dict) -> list:
return [{"label": k, "score": round(float(v), 6)} for k, v in scores.items()]
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
def root():
    """Send visitors hitting the bare root to the health-check endpoint."""
    return RedirectResponse(url="/health")
@app.get("/health")
def health():
    """Liveness probe: report that the process is up and serving."""
    return {"status": "ok"}
@app.post("/classify")
async def classify(
    image: UploadFile = File(...),
    labels: str = Form(...),
):
    """Zero-shot classification via MedImageInsight. Scores sum to ~1 (softmax)."""
    # `labels` arrives as a JSON-encoded array in the form field.
    candidate_labels = json.loads(labels)
    encoded = _encode_image(await image.read())
    predictions = classifier.predict([encoded], candidate_labels, multilabel=False)
    # Single image in, single result out.
    return {"results": _scores_to_list(predictions[0])}
@app.post("/multilabel")
async def multilabel(
    image: UploadFile = File(...),
    labels: str = Form(...),
):
    """Multi-label classification via MedImageInsight. Each score is independent (sigmoid)."""
    # `labels` arrives as a JSON-encoded array in the form field.
    candidate_labels = json.loads(labels)
    encoded = _encode_image(await image.read())
    predictions = classifier.predict([encoded], candidate_labels, multilabel=True)
    # Single image in, single result out.
    return {"results": _scores_to_list(predictions[0])}
@app.post("/describe")
async def describe(
    image: UploadFile = File(...),
    prompt: str = Form(default="Describe the medical findings visible in this image."),
):
    """Free-text image description via MedGemma 1.5-4B."""
    pil_image = Image.open(io.BytesIO(await image.read())).convert("RGB")

    # Single-turn chat: one user message carrying the image plus the prompt.
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # apply_chat_template handles both tokenization and image preprocessing;
    # BatchFeature.to() with a dtype casts only floating-point tensors, so
    # the integer input_ids are left intact.
    model_inputs = gemma_processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(gemma_model.device, dtype=torch.bfloat16)

    prompt_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        output_ids = gemma_model.generate(
            **model_inputs,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding for reproducible descriptions
        )

    # Drop the echoed prompt tokens; decode only the generated continuation.
    new_tokens = output_ids[0][prompt_len:]
    text = gemma_processor.decode(new_tokens, skip_special_tokens=True)
    return {"description": text}