Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import json | |
| import torch | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
| from pydantic import BaseModel | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| # Define the model ID | |
| MODEL_ID = "google/medgemma-1.5-4b-it" | |
| # Get huggingface token for gated models | |
| HF_TOKEN = os.environ.get("HF_TOKEN") | |
| app = FastAPI( | |
| title="MedGemma Radiology API", | |
| description="FastAPI service for analyzing multimodal radiology cases (Image + Text) using MedGemma.", | |
| version="1.0.0" | |
| ) | |
| # Enable CORS | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| processor = None | |
| model = None | |
| def load_model(): | |
| global processor, model | |
| print(f"Loading processor and model {MODEL_ID}...") | |
| try: | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN) | |
| model = AutoModelForImageTextToText.from_pretrained( | |
| MODEL_ID, | |
| torch_dtype=torch.bfloat16, | |
| device_map=device, | |
| low_cpu_mem_usage=True, | |
| token=HF_TOKEN | |
| ) | |
| model.eval() | |
| print(f"Model loaded successfully on {device}.") | |
| except Exception as e: | |
| print(f"Error loading model: {e}") | |
| print("Make sure you have set the HF_TOKEN environment variable correctly and accepted the model license.") | |
| class AnalysisResult(BaseModel): | |
| diagnosis: str | |
| recommendations: str | |
| urgency_level: str | |
| raw_response: str = None | |
| # The "dماغ" or System Prompt | |
| SYSTEM_PROMPT = """أنت الآن "مساعد تشخيص إشعاعي ذكي" متطور. مهمتك هي تحليل الصور والفحوصات الطبية المرفقة بالإضافة إلى النصوص الواردة والتي تصف حالة المريض. | |
| قواعد العمل: | |
| 1. التخصص: ركز فقط على المصطلحات الطبية الإشعاعية (مثل Opacity, Radiolucency, Fracture, Lesion) عند وصف الصورة. | |
| 2. الهيكلية: يجب أن يكون ردك منظماً (النتائج الأساسية للصورة، التشخيص المحتمل، التوصيات). | |
| 3. الدقة: إذا كانت الحالة طارئة بناءً على الصورة (مثل كسر مضاعف أو استرواح الصدر)، اجعل مستوى الحالة "حالة طارئة - Urgent". | |
| 4. التحذير: أضف دائماً في التوصيات أن هذا التحليل هو "رأي استشاري ذكي" ويجب مراجعته من قبل طبيب أشعة مختص. | |
| 5. اللغة: أجب باللغة العربية الطبية الرصينة. | |
| مهم جداً: قم بالرد باستخدام صيغة JSON صحيحة تحتوي على المفاتيح التالية فقط: | |
| { | |
| "diagnosis": "نتائج تحليل الصورة والتشخيص المحتمل", | |
| "recommendations": "التوصيات والتحذير", | |
| "urgency_level": "مستوى الحالة (مثلاً: حالة طارئة - Urgent أو عادية - Normal)" | |
| }""" | |
| async def analyze_report( | |
| case_description: str = Form(""), | |
| image: UploadFile = File(None) | |
| ): | |
| """ | |
| Analyzes a radiology case. Accepts an optional text description and an optional image (X-Ray, MRI, etc). | |
| At least one of them must be provided. | |
| """ | |
| if not model or not processor: | |
| raise HTTPException(status_code=503, detail="The AI model is currently loading or failed to load. Please try again later.") | |
| if not case_description and not image: | |
| raise HTTPException(status_code=400, detail="يجب إرفاق صورة أو كتابة وصف للحالة على الأقل.") | |
| try: | |
| content = [] | |
| # 1. Process Image if provided | |
| if image: | |
| image_data = await image.read() | |
| pil_image = Image.open(io.BytesIO(image_data)).convert("RGB") | |
| content.append({"type": "image", "image": pil_image}) | |
| # 2. Process Text | |
| user_text = SYSTEM_PROMPT + "\n\n" | |
| if case_description: | |
| user_text += f"وصف الحالة السريرية أو الأعراض:\n{case_description}\n\n" | |
| if image: | |
| user_text += "الرجاء تحليل الصورة الطبية المرفقة بناءً على القواعد أعلاه." | |
| else: | |
| user_text += "الرجاء تحليل الوصف الطبي أعلاه بناءً على القواعد أعلاه." | |
| content.append({"type": "text", "text": user_text}) | |
| # 3. Create messages format | |
| messages = [ | |
| { | |
| "role": "user", | |
| "content": content | |
| } | |
| ] | |
| # Format the prompt | |
| inputs = processor.apply_chat_template( | |
| messages, add_generation_prompt=True, tokenize=True, | |
| return_dict=True, return_tensors="pt" | |
| ).to(model.device, dtype=torch.bfloat16) | |
| input_len = inputs["input_ids"].shape[-1] | |
| # Generate | |
| with torch.inference_mode(): | |
| generation = model.generate( | |
| **inputs, | |
| max_new_tokens=1024, | |
| do_sample=True, | |
| temperature=0.2, | |
| top_p=0.9 | |
| ) | |
| generation_output = generation[0][input_len:] | |
| decoded = processor.decode(generation_output, skip_special_tokens=True) | |
| raw_output = decoded.strip() | |
| # Clean JSON markdown blocks | |
| clean_json = raw_output | |
| if clean_json.startswith("```json"): | |
| clean_json = clean_json.replace("```json", "", 1) | |
| if clean_json.endswith("```"): | |
| clean_json = clean_json[:-3] | |
| clean_json = clean_json.strip() | |
| # Parse JSON | |
| try: | |
| parsed_data = json.loads(clean_json) | |
| except json.JSONDecodeError: | |
| is_urgent = "Urgent" in raw_output or "طارئة" in raw_output | |
| parsed_data = { | |
| "diagnosis": raw_output[:500] + ("..." if len(raw_output)>500 else ""), | |
| "recommendations": "تنبيه: لم يقم الموديل بإرجاع هيكل JSON صحيح. هذا التحليل هو رأي استشاري ذكي ويجب مراجعته من قبل طبيب أشعة مختص.", | |
| "urgency_level": "حالة طارئة - Urgent" if is_urgent else "عادية - Normal" | |
| } | |
| return AnalysisResult( | |
| diagnosis=parsed_data.get("diagnosis", "غير محدد"), | |
| recommendations=parsed_data.get("recommendations", "غير محدد"), | |
| urgency_level=parsed_data.get("urgency_level", "غير محدد"), | |
| raw_response=raw_output | |
| ) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") | |
| def health_check(): | |
| return { | |
| "status": "Online", | |
| "model": MODEL_ID, | |
| "vision_enabled": True, | |
| "message": "Welcome to Multimodal MedGemma Radiology API" | |
| } | |