iraqigold commited on
Commit
8a35df3
·
verified ·
1 Parent(s): 50e19d2

Upload 3 files

Browse files
Files changed (2) hide show
  1. app.py +49 -26
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,13 +1,14 @@
1
  import os
 
2
  import json
3
  import torch
4
- from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  from transformers import AutoProcessor, AutoModelForImageTextToText
7
  from fastapi.middleware.cors import CORSMiddleware
 
8
 
9
  # Define the model ID
10
- # MedGemma 1.5 4B fits in ~8GB RAM using bfloat16, perfect for HF CPU Spaces
11
  MODEL_ID = "google/medgemma-1.5-4b-it"
12
 
13
  # Get huggingface token for gated models
@@ -15,7 +16,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
15
 
16
  app = FastAPI(
17
  title="MedGemma Radiology API",
18
- description="FastAPI service for analyzing radiology reports using MedGemma.",
19
  version="1.0.0"
20
  )
21
 
@@ -35,13 +36,12 @@ def load_model():
35
  global processor, model
36
  print(f"Loading processor and model {MODEL_ID}...")
37
  try:
38
- # Check deployment environment device
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
 
41
  processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
42
  model = AutoModelForImageTextToText.from_pretrained(
43
  MODEL_ID,
44
- torch_dtype=torch.bfloat16, # Optimized for reasonable RAM usage
45
  device_map=device,
46
  low_cpu_mem_usage=True,
47
  token=HF_TOKEN
@@ -52,46 +52,70 @@ def load_model():
52
  print(f"Error loading model: {e}")
53
  print("Make sure you have set the HF_TOKEN environment variable correctly and accepted the model license.")
54
 
55
- class RadiologyCase(BaseModel):
56
- case_description: str
57
-
58
  class AnalysisResult(BaseModel):
59
  diagnosis: str
60
  recommendations: str
61
  urgency_level: str
62
- raw_response: str = None # Included internally for debugging
63
 
64
  # The "dماغ" or System Prompt
65
- SYSTEM_PROMPT = """أنت الآن "مساعد تشخيص إشعاعي ذكي" متطور. مهمتك هي تحليل النصوص الواردة إليك والتي تصف نتائج صور الأشعة (X-ray, CT, MRI).
66
 
67
  قواعد العمل:
68
- 1. التخصص: ركز فقط على المصطلحات الطبية الإشعاعية (مثل Opacity, Radiolucency, Fracture, Lesion).
69
- 2. الهيكلية: يجب أن يكون ردك منظماً (النتائج الأساسية، التشخيص المحتمل، التوصيات).
70
- 3. الدقة: إذا كانت الحالة طارئة (مثل نزيف أو كسر مضاعف ابدأ بردك واجعل مستوى الحالة "حالة طارئة - Urgent".
71
  4. التحذير: أضف دائماً في التوصيات أن هذا التحليل هو "رأي استشاري ذكي" ويجب مراجعته من قبل طبيب أشعة مختص.
72
  5. اللغة: أجب باللغة العربية الطبية الرصينة.
73
 
74
  مهم جداً: قم بالرد باستخدام صيغة JSON صحيحة تحتوي على المفاتيح التالية فقط:
75
  {
76
- "diagnosis": "التشخيص المحتمل والنتائج الأساسية",
77
  "recommendations": "التوصيات والتحذير",
78
  "urgency_level": "مستوى الحالة (مثلاً: حالة طارئة - Urgent أو عادية - Normal)"
79
  }"""
80
 
81
  @app.post("/analyze-radiology", response_model=AnalysisResult)
82
- async def analyze_report(case: RadiologyCase):
 
 
 
 
 
 
 
83
  if not model or not processor:
84
  raise HTTPException(status_code=503, detail="The AI model is currently loading or failed to load. Please try again later.")
85
 
 
 
 
86
  try:
87
- # Combine System prompt with user case
88
- user_text = f"{SYSTEM_PROMPT}\n\nنص التقرير أو الحالة:\n{case.case_description}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  messages = [
90
  {
91
  "role": "user",
92
- "content": [
93
- {"type": "text", "text": user_text}
94
- ]
95
  }
96
  ]
97
 
@@ -103,22 +127,21 @@ async def analyze_report(case: RadiologyCase):
103
 
104
  input_len = inputs["input_ids"].shape[-1]
105
 
106
- # Generate with optimized settings
107
  with torch.inference_mode():
108
  generation = model.generate(
109
  **inputs,
110
  max_new_tokens=1024,
111
  do_sample=True,
112
- temperature=0.2, # Conservative temp for medical accuracy
113
  top_p=0.9
114
  )
115
- # Exclude the input prompt from generation output
116
  generation_output = generation[0][input_len:]
117
 
118
  decoded = processor.decode(generation_output, skip_special_tokens=True)
119
  raw_output = decoded.strip()
120
 
121
- # Helper: Clean out markdown block delimiters if model generated them
122
  clean_json = raw_output
123
  if clean_json.startswith("```json"):
124
  clean_json = clean_json.replace("```json", "", 1)
@@ -130,7 +153,6 @@ async def analyze_report(case: RadiologyCase):
130
  try:
131
  parsed_data = json.loads(clean_json)
132
  except json.JSONDecodeError:
133
- # Fallback if model doesn't strictly adhere to JSON outline
134
  is_urgent = "Urgent" in raw_output or "طارئة" in raw_output
135
  parsed_data = {
136
  "diagnosis": raw_output[:500] + ("..." if len(raw_output)>500 else ""),
@@ -153,5 +175,6 @@ def health_check():
153
  return {
154
  "status": "Online",
155
  "model": MODEL_ID,
156
- "message": "Welcome to MedGemma Radiology API"
 
157
  }
 
1
  import os
2
+ import io
3
  import json
4
  import torch
5
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
6
  from pydantic import BaseModel
7
  from transformers import AutoProcessor, AutoModelForImageTextToText
8
  from fastapi.middleware.cors import CORSMiddleware
9
+ from PIL import Image
10
 
11
  # Define the model ID
 
12
  MODEL_ID = "google/medgemma-1.5-4b-it"
13
 
14
  # Get huggingface token for gated models
 
16
 
17
  app = FastAPI(
18
  title="MedGemma Radiology API",
19
+ description="FastAPI service for analyzing multimodal radiology cases (Image + Text) using MedGemma.",
20
  version="1.0.0"
21
  )
22
 
 
36
  global processor, model
37
  print(f"Loading processor and model {MODEL_ID}...")
38
  try:
 
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
 
41
  processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
42
  model = AutoModelForImageTextToText.from_pretrained(
43
  MODEL_ID,
44
+ torch_dtype=torch.bfloat16,
45
  device_map=device,
46
  low_cpu_mem_usage=True,
47
  token=HF_TOKEN
 
52
  print(f"Error loading model: {e}")
53
  print("Make sure you have set the HF_TOKEN environment variable correctly and accepted the model license.")
54
 
 
 
 
55
  class AnalysisResult(BaseModel):
56
  diagnosis: str
57
  recommendations: str
58
  urgency_level: str
59
+ raw_response: str = None
60
 
61
  # The "dماغ" or System Prompt
62
+ SYSTEM_PROMPT = """أنت الآن "مساعد تشخيص إشعاعي ذكي" متطور. مهمتك هي تحليل الصور والفحوصات الطبية المرفقة بالإضافة إلى النصوص الواردة والتي تصف حالة المريض.
63
 
64
  قواعد العمل:
65
+ 1. التخصص: ركز فقط على المصطلحات الطبية الإشعاعية (مثل Opacity, Radiolucency, Fracture, Lesion) عند وصف الصورة.
66
+ 2. الهيكلية: يجب أن يكون ردك منظماً (النتائج الأساسية للصورة، التشخيص المحتمل، التوصيات).
67
+ 3. الدقة: إذا كانت الحالة طارئة بناءً على الصورة (مثل كسر مضاعف أو استرواح الصدر)، اجعل مستوى الحالة "حالة طارئة - Urgent".
68
  4. التحذير: أضف دائماً في التوصيات أن هذا التحليل هو "رأي استشاري ذكي" ويجب مراجعته من قبل طبيب أشعة مختص.
69
  5. اللغة: أجب باللغة العربية الطبية الرصينة.
70
 
71
  مهم جداً: قم بالرد باستخدام صيغة JSON صحيحة تحتوي على المفاتيح التالية فقط:
72
  {
73
+ "diagnosis": "نتائج تحليل الصورة والتشخيص المحتمل",
74
  "recommendations": "التوصيات والتحذير",
75
  "urgency_level": "مستوى الحالة (مثلاً: حالة طارئة - Urgent أو عادية - Normal)"
76
  }"""
77
 
78
  @app.post("/analyze-radiology", response_model=AnalysisResult)
79
+ async def analyze_report(
80
+ case_description: str = Form(""),
81
+ image: UploadFile = File(None)
82
+ ):
83
+ """
84
+ Analyzes a radiology case. Accepts an optional text description and an optional image (X-Ray, MRI, etc).
85
+ At least one of them must be provided.
86
+ """
87
  if not model or not processor:
88
  raise HTTPException(status_code=503, detail="The AI model is currently loading or failed to load. Please try again later.")
89
 
90
+ if not case_description and not image:
91
+ raise HTTPException(status_code=400, detail="يجب إرفاق صورة أو كتابة وصف للحالة على الأقل.")
92
+
93
  try:
94
+ content = []
95
+
96
+ # 1. Process Image if provided
97
+ if image:
98
+ image_data = await image.read()
99
+ pil_image = Image.open(io.BytesIO(image_data)).convert("RGB")
100
+ content.append({"type": "image", "image": pil_image})
101
+
102
+ # 2. Process Text
103
+ user_text = SYSTEM_PROMPT + "\n\n"
104
+ if case_description:
105
+ user_text += f"وصف الحالة السريرية أو الأعراض:\n{case_description}\n\n"
106
+
107
+ if image:
108
+ user_text += "الرجاء تحليل الصورة الطبية المرفقة بناءً على القواعد أعلاه."
109
+ else:
110
+ user_text += "الرجاء تحليل الوصف الطبي أعلاه بناءً على القواعد أعلاه."
111
+
112
+ content.append({"type": "text", "text": user_text})
113
+
114
+ # 3. Create messages format
115
  messages = [
116
  {
117
  "role": "user",
118
+ "content": content
 
 
119
  }
120
  ]
121
 
 
127
 
128
  input_len = inputs["input_ids"].shape[-1]
129
 
130
+ # Generate
131
  with torch.inference_mode():
132
  generation = model.generate(
133
  **inputs,
134
  max_new_tokens=1024,
135
  do_sample=True,
136
+ temperature=0.2,
137
  top_p=0.9
138
  )
 
139
  generation_output = generation[0][input_len:]
140
 
141
  decoded = processor.decode(generation_output, skip_special_tokens=True)
142
  raw_output = decoded.strip()
143
 
144
+ # Clean JSON markdown blocks
145
  clean_json = raw_output
146
  if clean_json.startswith("```json"):
147
  clean_json = clean_json.replace("```json", "", 1)
 
153
  try:
154
  parsed_data = json.loads(clean_json)
155
  except json.JSONDecodeError:
 
156
  is_urgent = "Urgent" in raw_output or "طارئة" in raw_output
157
  parsed_data = {
158
  "diagnosis": raw_output[:500] + ("..." if len(raw_output)>500 else ""),
 
175
  return {
176
  "status": "Online",
177
  "model": MODEL_ID,
178
+ "vision_enabled": True,
179
+ "message": "Welcome to Multimodal MedGemma Radiology API"
180
  }
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  fastapi>=0.104.1
2
  uvicorn>=0.23.2
3
  pydantic>=2.4.2
 
4
  torch>=2.1.0
5
  transformers>=4.40.0
6
  accelerate>=0.29.3
 
1
  fastapi>=0.104.1
2
  uvicorn>=0.23.2
3
  pydantic>=2.4.2
4
+ python-multipart>=0.0.9
5
  torch>=2.1.0
6
  transformers>=4.40.0
7
  accelerate>=0.29.3