# Hugging Face Spaces page residue (scrape header) — the Space was showing
# "Runtime error" at capture time; see NOTE on the MedGemma model id below.
# --- Imports & Hugging Face authentication ---------------------------------
import base64
import io
import json
import os
import sys

import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from huggingface_hub import login, snapshot_download
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

# Authenticate up front so the gated model downloads below can succeed.
# Fix: the warning string contained a mojibake character ("β", originally a dash).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Authenticated with HF token.", flush=True)
else:
    print("WARNING: HF_TOKEN not set - gated models will fail.", flush=True)
# ---------------------------------------------------------------------------
# MedImageInsight - CLIP-style encoder used for zero-shot label scoring
# ---------------------------------------------------------------------------
# Fetch the whole repo (it ships its own Python module plus weights), then
# make that module importable before loading the checkpoint.
print("Downloading MedImageInsights repo...", flush=True)
repo_path = snapshot_download("lion-ai/MedImageInsights")
print(f"Downloaded to: {repo_path}", flush=True)
sys.path.insert(0, repo_path)

from medimageinsightmodel import MedImageInsight  # noqa: E402

# Weights live in a dated snapshot subdirectory of the downloaded repo.
model_dir = os.path.join(repo_path, "2024.09.27")
print("Loading MedImageInsight...", flush=True)
classifier = MedImageInsight(
    model_dir=model_dir,
    # (sic) "medimageinsigt" is the filename as shipped upstream - do not "fix" it.
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth",
)
classifier.load_model()
print("MedImageInsight ready.", flush=True)
# ---------------------------------------------------------------------------
# MedGemma - generative VLM for free-text image description
# ---------------------------------------------------------------------------
# NOTE(review): "google/medgemma-1.5-4b-it" does not match the published
# MedGemma repo ids (e.g. "google/medgemma-4b-it"); if the repo does not
# exist, from_pretrained raises at startup - a likely cause of the Space's
# "Runtime error". Confirm the id against the Hugging Face model card.
MEDGEMMA_ID = "google/medgemma-1.5-4b-it"

print("Loading MedGemma processor...", flush=True)
gemma_processor = AutoProcessor.from_pretrained(MEDGEMMA_ID)

print("Loading MedGemma model (bfloat16)...", flush=True)
gemma_model = AutoModelForImageTextToText.from_pretrained(
    MEDGEMMA_ID,
    torch_dtype=torch.bfloat16,  # halve memory vs float32
    device_map="auto",           # let accelerate place weights on available devices
)
gemma_model.eval()  # inference only
print("MedGemma ready.", flush=True)
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Medical Image Analysis API")

# Wide-open CORS so browser-based clients on any origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def _encode_image(data: bytes) -> str:
    """Convert raw upload bytes into a base64-encoded PNG string.

    MedImageInsight consumes base64 image strings, so every upload is
    normalized to an RGB PNG first. ``encodebytes`` line-wraps its output;
    standard base64 decoders tolerate the embedded newlines.
    """
    rgb = Image.open(io.BytesIO(data)).convert("RGB")
    buffer = io.BytesIO()
    rgb.save(buffer, format="PNG")
    encoded = base64.encodebytes(buffer.getvalue())
    return encoded.decode("utf-8")
| def _scores_to_list(scores: dict) -> list: | |
| return [{"label": k, "score": round(float(v), 6)} for k, v in scores.items()] | |
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
def root():
    """Redirect the bare root URL to the health check.

    Fix: the route decorator was missing, so the function was never
    registered with the FastAPI app and ``/`` returned 404.
    """
    return RedirectResponse(url="/health")
@app.get("/health")
def health():
    """Liveness probe; also the target of the root redirect.

    Fix: the route decorator was missing, so ``/health`` was never registered.
    """
    return {"status": "ok"}
@app.post("/classify")
async def classify(
    image: UploadFile = File(...),
    labels: str = Form(...),
):
    """Zero-shot classification via MedImageInsight. Scores sum to ~1 (softmax).

    Args:
        image: uploaded image file (any format PIL can open).
        labels: JSON-encoded list of candidate label strings.

    Fix: the route decorator was missing, so the endpoint was never registered.
    """
    labels_list = json.loads(labels)
    img_b64 = _encode_image(await image.read())
    results = classifier.predict([img_b64], labels_list, multilabel=False)
    return {"results": _scores_to_list(results[0])}
@app.post("/multilabel")
async def multilabel(
    image: UploadFile = File(...),
    labels: str = Form(...),
):
    """Multi-label classification via MedImageInsight. Each score is independent (sigmoid).

    Args:
        image: uploaded image file (any format PIL can open).
        labels: JSON-encoded list of candidate label strings.

    Fix: the route decorator was missing, so the endpoint was never registered.
    """
    labels_list = json.loads(labels)
    img_b64 = _encode_image(await image.read())
    results = classifier.predict([img_b64], labels_list, multilabel=True)
    return {"results": _scores_to_list(results[0])}
@app.post("/describe")
async def describe(
    image: UploadFile = File(...),
    prompt: str = Form(default="Describe the medical findings visible in this image."),
):
    """Free-text image description via MedGemma.

    Args:
        image: uploaded image file (any format PIL can open).
        prompt: instruction sent alongside the image.

    Fix: the route decorator was missing, so the endpoint was never registered.
    """
    img = Image.open(io.BytesIO(await image.read())).convert("RGB")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # Processor tokenizes text and preprocesses the image in one call; cast
    # to bfloat16 to match the model weights (integer tensors are unaffected
    # by BatchFeature.to's dtype argument).
    inputs = gemma_processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(gemma_model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = gemma_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding for reproducible output
        )
    # Drop the prompt tokens; decode only the newly generated tail.
    new_tokens = generation[0][input_len:]
    description = gemma_processor.decode(new_tokens, skip_special_tokens=True)
    return {"description": description}