import base64
import json
import os
import re
import tempfile
import traceback
from io import BytesIO
from typing import List, Dict, Optional

import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
|
|
| |
| |
| |
| |
# Optional .env loading: look next to this file first, then one directory up.
try:
    from dotenv import load_dotenv
    _here = os.path.dirname(os.path.abspath(__file__))

    for _env_path in [
        os.path.join(_here, ".env"),
        os.path.join(_here, "..", ".env"),
    ]:
        if os.path.isfile(_env_path):
            load_dotenv(_env_path)
            print(f"✅ Loaded .env from: {os.path.abspath(_env_path)}")
            break
    else:
        # for-else: runs only when no candidate .env path existed.
        print("⚠️ No .env file found. Set GEMINI_API_KEY in your environment.")
except ImportError:
    # python-dotenv not installed; fall back to the process environment.
    pass
|
|
|
|
# Import inference helpers: relative import when running as a package,
# absolute import when executed as a flat script/module (e.g. `python app.py`).
try:
    from .inference import infer_aw_contour, analyze_frame, analyze_video_frame, infer_cervix_bbox
except ImportError:
    from inference import infer_aw_contour, analyze_frame, analyze_video_frame, infer_cervix_bbox
|
|
| |
# The Gemini SDK is optional; when absent, the LLM endpoints respond with 503.
try:
    import google.generativeai as genai
    GEMINI_AVAILABLE = True
except ImportError:
    GEMINI_AVAILABLE = False
    print("⚠️ google-generativeai not installed. LLM endpoints will be unavailable.")
|
|
app = FastAPI(title="Pathora Colposcopy API", version="1.0.0")

# Permissive CORS for browser clients.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed by the CORS spec; Starlette compensates by echoing the request
# origin — confirm this wide exposure is intentional for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure Gemini once at startup when both the SDK and an API key exist.
# VITE_GEMINI_API_KEY is accepted as a fallback name (frontend-style env var).
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("VITE_GEMINI_API_KEY")
if GEMINI_AVAILABLE and GEMINI_API_KEY:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        print("✅ Gemini AI configured successfully")
    except Exception as e:
        # Treat a failed configure the same as the SDK being unavailable.
        print(f"⚠️ Failed to configure Gemini: {e}")
        GEMINI_AVAILABLE = False
elif GEMINI_AVAILABLE:
    print("⚠️ GEMINI_API_KEY not found in environment variables")
|
|
|
def get_supported_gemini_models() -> List[str]:
    """Return model names that support generateContent for this API key."""
    if not GEMINI_AVAILABLE or not GEMINI_API_KEY:
        return []

    candidates: List[str] = []
    try:
        for entry in genai.list_models():
            methods = getattr(entry, "supported_generation_methods", []) or []
            if "generateContent" not in methods:
                continue

            full_name = getattr(entry, "name", "")
            if not full_name:
                continue

            candidates.append(full_name)
            # Also expose the short alias without the "models/" prefix.
            prefix = "models/"
            if full_name.startswith(prefix):
                candidates.append(full_name[len(prefix):])
    except Exception as exc:
        print(f"⚠️ Could not list Gemini models: {exc}")
        return []

    # Deduplicate while preserving first-seen order.
    return list(dict.fromkeys(candidates))
|
|
|
|
| |
# Model names that hit a 429/quota error during this process's lifetime;
# get_ordered_model_candidates skips them until the process restarts.
QUOTA_BLOCKED_MODELS: set[str] = set()
|
|
|
|
def get_ordered_model_candidates(
    available_models: List[str],
    blocked: Optional[set] = None,
) -> List[str]:
    """Order models by preference and exclude quota-blocked models.

    Args:
        available_models: Model names reported as supporting generateContent.
        blocked: Set of model names to exclude. Defaults to the module-level
            QUOTA_BLOCKED_MODELS registry (kept as a keyword parameter so the
            function is testable and callers can override).

    Returns:
        Models from the preferred list first (in the order given below), then
        any remaining available models in their original order, with blocked
        models and duplicates removed.
    """
    if blocked is None:
        blocked = QUOTA_BLOCKED_MODELS

    preferred_models = [
        # Newest flash-tier models first, then older generations.
        "models/gemini-2.5-flash",
        "gemini-2.5-flash",
        "models/gemini-flash-latest",
        "gemini-flash-latest",
        "models/gemini-2.5-flash-lite",
        "gemini-2.5-flash-lite",
        "models/gemini-flash-lite-latest",
        "gemini-flash-lite-latest",
        "models/gemini-2.0-flash",
        "gemini-2.0-flash",
        "models/gemini-2.0-flash-lite",
        "gemini-2.0-flash-lite",
        "models/gemini-1.5-flash",
        "gemini-1.5-flash",
        "models/gemini-1.5-pro",
        "gemini-1.5-pro",
        "models/gemini-pro-latest",
        "gemini-pro-latest",
        "models/gemini-pro",
        "gemini-pro",
    ]

    available = [m for m in available_models if m not in blocked]
    # Set lookups replace the original O(n^2) list-membership scans.
    available_set = set(available)
    ordered = [m for m in preferred_models if m in available_set]
    seen = set(ordered)
    for name in available:
        if name not in seen:
            ordered.append(name)
            seen.add(name)
    return ordered
|
|
| |
class ChatMessage(BaseModel):
    """A single turn in the chat transcript."""
    role: str  # "bot" is mapped to Gemini's "model" role; anything else to "user"
    text: str  # message content
|
|
class ChatRequest(BaseModel):
    """Payload for POST /api/chat."""
    message: str                         # new user message to answer
    history: List[ChatMessage] = []      # prior conversation turns
    system_prompt: Optional[str] = None  # overrides the default system prompt
|
|
class ReportGenerationRequest(BaseModel):
    """Payload for POST /api/generate-report."""
    patient_data: Dict                   # patient fields, serialized into the prompt as JSON
    exam_findings: Dict                  # examination findings, serialized into the prompt as JSON
    images: Optional[List[str]] = []     # image references; not read by the report endpoint
    system_prompt: Optional[str] = None  # overrides the default report prompt
|
|
|
|
class SPAStaticFiles(StaticFiles):
    """Static file handler with single-page-app fallback.

    Any path that would 404 is served index.html instead, so client-side
    routes in the frontend resolve on hard refresh / deep links.
    """

    async def get_response(self, path: str, scope):
        response = await super().get_response(path, scope)
        if response.status_code == 404:
            # Unknown path: hand the SPA entry point to the browser router.
            return await super().get_response("index.html", scope)
        return response
|
|
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    models = get_supported_gemini_models()

    llm_status = {
        "gemini_available": GEMINI_AVAILABLE,
        "api_key_configured": bool(GEMINI_API_KEY),
        "available_models": models,
    }

    return {
        "status": "healthy",
        "service": "Pathora Colposcopy API",
        "ai_models": {
            "acetowhite_model": "loaded",
            "cervix_model": "loaded",
        },
        "llm": llm_status,
    }
|
|
|
|
@app.get("/api/health")
async def api_health_check():
    """Health check endpoint under /api for HF Spaces compatibility."""
    # Delegate to the canonical /health handler so both routes stay in sync.
    return await health_check()
|
|
|
|
@app.post("/api/chat")
async def chat_endpoint(request: ChatRequest):
    """
    LLM Chat endpoint for conversational AI assistant

    Args:
        request: ChatRequest with message, history, and optional system_prompt

    Returns:
        JSON with AI response

    Raises:
        HTTPException: 503 when the Gemini SDK or API key is missing,
            500 when every model candidate fails.
    """
    if not GEMINI_AVAILABLE:
        raise HTTPException(
            status_code=503,
            detail="Gemini AI is not available. Install google-generativeai package."
        )

    if not GEMINI_API_KEY:
        raise HTTPException(
            status_code=503,
            detail="GEMINI_API_KEY not configured in environment variables"
        )

    try:
        # Default persona used when the client supplies no system prompt.
        system_prompt = request.system_prompt or """You are Pathora AI — a specialist colposcopy assistant. \
Provide expert guidance on examination techniques, findings interpretation, and management guidelines. \
Be professional, evidence-based, and concise."""

        # Discover models usable with this API key; fail fast when none exist.
        available_models = get_supported_gemini_models()
        if not available_models:
            raise Exception(
                "No Gemini models with generateContent are available for this API key. "
                "Check API key permissions and Gemini API enablement."
            )

        model_names = get_ordered_model_candidates(available_models)
        print(f"✅ Chat available models: {available_models}")
        print(f"✅ Chat candidate models: {model_names}")

        response_text = None
        used_model = None

        # Try candidates in preference order until one responds successfully.
        for model_name in model_names:
            try:
                print(f"🔄 Trying chat model: {model_name}")

                model = genai.GenerativeModel(
                    model_name=model_name,
                    system_instruction=system_prompt
                )

                # Translate client roles to Gemini's expected "user"/"model".
                chat_history = []
                for msg in request.history:
                    role = "model" if msg.role == "bot" else "user"
                    chat_history.append({
                        "role": role,
                        "parts": [msg.text]
                    })

                chat = model.start_chat(history=chat_history)

                response = chat.send_message(request.message)
                response_text = response.text
                used_model = model_name
                print(f"✅ Successfully used chat model: {model_name}")
                break
            except Exception as model_err:
                err_str = str(model_err)
                # Remember quota-exhausted models so later requests skip them.
                if "429" in err_str or "quota exceeded" in err_str.lower():
                    QUOTA_BLOCKED_MODELS.add(model_name)
                    print(f"⏭️ Skipping quota-blocked chat model: {model_name}")
                print(f"⚠️ Chat model {model_name} failed: {err_str}")
                continue

        if not response_text:
            raise Exception("All model attempts failed. Please check API key and model availability.")

        return JSONResponse({
            "status": "success",
            "response": response_text,
            "model": used_model
        })

    except Exception as e:
        error_msg = str(e)
        print(f"❌ Chat error: {error_msg}")
        traceback.print_exc()

        # Map common failure modes to clearer client-facing messages.
        if "API key" in error_msg or "authentication" in error_msg.lower():
            detail = "API key authentication failed. Please add GEMINI_API_KEY to HF Space secrets."
        elif "not found" in error_msg.lower() or "404" in error_msg:
            detail = f"Gemini model not available. Error: {error_msg}. Please verify API key."
        else:
            detail = f"Chat error: {error_msg}"

        raise HTTPException(status_code=500, detail=detail)
|
|
|
|
@app.post("/api/generate-report")
async def generate_report_endpoint(request: ReportGenerationRequest):
    """
    Generate colposcopy report using LLM based on patient data and exam findings

    Args:
        request: ReportGenerationRequest with patient data, exam findings, and images

    Returns:
        JSON with generated report (raw text plus parsed JSON object)

    Raises:
        HTTPException: 503 when the Gemini SDK or API key is missing,
            500 when all models fail or the model output is not valid JSON.
    """
    if not GEMINI_AVAILABLE:
        raise HTTPException(
            status_code=503,
            detail="Gemini AI is not available. Install google-generativeai package."
        )

    if not GEMINI_API_KEY:
        raise HTTPException(
            status_code=503,
            detail="GEMINI_API_KEY not configured in environment variables"
        )

    try:
        # Default prompt constrains the model to a fixed 10-key JSON schema.
        system_prompt = request.system_prompt or """You are an expert colposcopy AI assistant acting as a specialist gynaecologist.
Analyse ALL the clinical data provided and return ONLY a valid JSON object — no markdown, no extra text, no code fences.
The JSON must have EXACTLY these 10 keys and no others:
{
"examQuality": "<Adequate or Inadequate>",
"transformationZone": "<I, II, or III>",
"acetowL": "<Present or Absent>",
"nativeFindings": "<2-3 sentence summary of native view findings>",
"aceticFindings": "<2-3 sentence summary of acetic acid findings>",
"biopsySites": "<recommended biopsy sites by clock position, or None>",
"biopsyNotes": "<brief biopsy notes: lesion grade, type, number of samples>",
"colposcopicFindings": "<professional colposcopic findings: 3-4 sentences including Swede score if available>",
"treatmentPlan": "<evidence-based treatment plan: 2-3 sentences>",
"followUp": "<follow-up schedule with specific timeframes>"
}"""

        # Assemble the user prompt: patient data and findings as JSON blobs,
        # followed by a restatement of the required output schema.
        prompt_parts = []
        prompt_parts.append("PATIENT DATA:")
        prompt_parts.append(json.dumps(request.patient_data, indent=2))
        prompt_parts.append("\n\nEXAMINATION FINDINGS & OBSERVATIONS:")
        prompt_parts.append(json.dumps(request.exam_findings, indent=2))
        prompt_parts.append("""

Based on all the above clinical data, return ONLY the JSON object with exactly these 10 keys:
examQuality, transformationZone, acetowL, nativeFindings, aceticFindings,
biopsySites, biopsyNotes, colposcopicFindings, treatmentPlan, followUp

Do NOT include any other keys. Do NOT wrap in markdown. Return raw JSON only.""")

        full_prompt = "\n".join(prompt_parts)

        # Discover models usable with this API key; fail fast when none exist.
        available_models = get_supported_gemini_models()
        if not available_models:
            raise Exception(
                "No Gemini models with generateContent are available for this API key. "
                "Check API key permissions and Gemini API enablement."
            )

        model_names = get_ordered_model_candidates(available_models)
        print(f"✅ Report available models: {available_models}")
        print(f"✅ Report candidate models: {model_names}")

        response_text = None
        used_model = None

        # Try candidates in preference order until one responds successfully.
        for model_name in model_names:
            try:
                print(f"🔄 Trying model: {model_name}")
                model = genai.GenerativeModel(
                    model_name=model_name,
                    system_instruction=system_prompt
                )
                response = model.generate_content(full_prompt)
                response_text = response.text
                used_model = model_name
                print(f"✅ Successfully used model: {model_name}")
                break
            except Exception as model_err:
                err_str = str(model_err)
                # Remember quota-exhausted models so later requests skip them.
                if "429" in err_str or "quota exceeded" in err_str.lower():
                    QUOTA_BLOCKED_MODELS.add(model_name)
                    print(f"⏭️ Skipping quota-blocked report model: {model_name}")
                print(f"⚠️ Model {model_name} failed: {err_str}")
                continue

        if not response_text:
            raise Exception("All model attempts failed. Please check API key and model availability.")

        # Validate the model output: strip optional markdown code fences,
        # then require the remainder to parse as JSON.
        try:
            cleaned_text = response_text.strip()
            if cleaned_text.startswith('```'):
                cleaned_text = re.sub(r'^```[a-z]*\n?', '', cleaned_text, flags=re.IGNORECASE)
                cleaned_text = re.sub(r'\n?```\s*$', '', cleaned_text)
                cleaned_text = cleaned_text.strip()

            parsed_json = json.loads(cleaned_text)
            print(f"✅ Report is valid JSON with keys: {list(parsed_json.keys())}")

            return JSONResponse({
                "status": "success",
                "report": cleaned_text,
                "report_json": parsed_json,
                "model": used_model
            })
        except json.JSONDecodeError as je:
            print(f"⚠️ Response is not valid JSON: {je}")
            print(f"Response text: {response_text[:500]}")
            # Re-raised as a generic error so the outer handler maps it to 500.
            raise Exception(f"Gemini returned invalid JSON: {str(je)}")

    except Exception as e:
        error_msg = str(e)
        print(f"❌ Report generation error: {error_msg}")
        traceback.print_exc()

        # Map common failure modes to clearer client-facing messages.
        if "API key" in error_msg or "authentication" in error_msg.lower():
            detail = "API key authentication failed. Please check GEMINI_API_KEY in HF Space secrets."
        elif "not found" in error_msg.lower() or "404" in error_msg:
            detail = f"Gemini model not available. Error: {error_msg}. Please verify API key has access to Gemini models."
        else:
            detail = f"Report generation error: {error_msg}"

        raise HTTPException(status_code=500, detail=detail)
|
|
|
|
|
|
@app.post("/api/infer-aw-contour")
async def infer_aw_contour_endpoint(file: UploadFile = File(...), conf_threshold: float = 0.4):
    """
    Inference endpoint for Acetowhite contour detection

    Args:
        file: Image file (jpg, png, etc.)
        conf_threshold: Confidence threshold for YOLO model (0.0-1.0)

    Returns:
        JSON with base64 encoded result image, contours, and detection count

    Raises:
        HTTPException: 400 for an unreadable image, 500 for inference failures.
    """
    try:
        image_data = await file.read()
        print(f"✅ File received, size: {len(image_data)} bytes")

        try:
            image = Image.open(BytesIO(image_data))
            print(f"✅ Image opened, mode: {image.mode}, size: {image.size}")
        except Exception as e:
            print(f"❌ Image open error: {e}")
            traceback.print_exc()
            raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")

        # Normalize any mode (RGBA, P, L, ...) to RGB before handing to OpenCV.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        print(f"✅ Frame converted, shape: {frame.shape}")

        print(f"🔄 Running infer_aw_contour with conf_threshold={conf_threshold}")
        result = infer_aw_contour(frame, conf_threshold=conf_threshold)
        print(f"✅ Inference complete, detections: {result['detections']}")

        # Encode the overlay (BGR) as a base64 PNG for the JSON payload.
        if result["overlay"] is not None:
            overlay_rgb = cv2.cvtColor(result["overlay"], cv2.COLOR_BGR2RGB)
            buffer = BytesIO()
            Image.fromarray(overlay_rgb).save(buffer, format="PNG")
            image_base64 = base64.b64encode(buffer.getvalue()).decode()
            print(f"✅ Image encoded to base64, size: {len(image_base64)} chars")
        else:
            image_base64 = None
            print("⚠️ No overlay returned from inference")

        return JSONResponse({
            "status": "success",
            "message": "Inference completed successfully",
            "result_image": image_base64,
            "contours": result["contours"],
            "detections": result["detections"],
            "confidence_threshold": conf_threshold
        })

    except HTTPException:
        # Propagate intentional client errors (the 400 above) unchanged;
        # previously they were swallowed by the generic handler and re-raised
        # as 500s because HTTPException subclasses Exception.
        raise
    except Exception as e:
        print(f"❌ EXCEPTION in infer_aw_contour:")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error during inference: {str(e)}")
|
|
|
|
@app.post("/api/batch-infer")
async def batch_infer(files: list[UploadFile] = File(...), conf_threshold: float = 0.4):
    """
    Batch inference endpoint for multiple images

    Args:
        files: List of image files
        conf_threshold: Confidence threshold for YOLO model

    Returns:
        JSON with results for all images; per-file failures are reported
        inline with status "error" instead of failing the whole batch.
    """
    results = []

    for file in files:
        try:
            image_data = await file.read()
            image = Image.open(BytesIO(image_data))

            # Normalize any mode (RGBA, P, L, ...) to RGB before OpenCV.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

            result = infer_aw_contour(frame, conf_threshold=conf_threshold)

            # Encode the overlay (BGR) as a base64 PNG for the JSON payload.
            if result["overlay"] is not None:
                overlay_rgb = cv2.cvtColor(result["overlay"], cv2.COLOR_BGR2RGB)
                buffer = BytesIO()
                Image.fromarray(overlay_rgb).save(buffer, format="PNG")
                image_base64 = base64.b64encode(buffer.getvalue()).decode()
            else:
                image_base64 = None

            results.append({
                "filename": file.filename,
                "status": "success",
                "result_image": image_base64,
                "contours": result["contours"],
                "detections": result["detections"]
            })

        except Exception as e:
            # One bad file must not abort the rest of the batch.
            results.append({
                "filename": file.filename,
                "status": "error",
                "error": str(e)
            })

    return JSONResponse({
        "status": "completed",
        "total_files": len(results),
        "results": results
    })
|
|
|
|
@app.post("/infer/image")
async def infer_image(file: UploadFile = File(...)):
    """
    Single image inference endpoint for cervix detection/quality.

    Returns:
        JSON result from analyze_frame.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 for inference failures.
    """
    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        # cv2.imdecode returns None (it does not raise) for invalid data;
        # reject it here instead of passing None into analyze_frame.
        if frame is None:
            raise HTTPException(status_code=400, detail="Invalid or unsupported image file")

        result = analyze_frame(frame)

        return JSONResponse(content=result)

    except HTTPException:
        # Keep the 400 above from being re-wrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.post("/infer/video")
async def infer_video(file: UploadFile = File(...)):
    """
    Video inference endpoint for cervix detection/quality (frame-by-frame).

    Returns:
        JSON with the total frame count and per-frame status/quality results.

    Raises:
        HTTPException: 500 for any decoding or inference failure.
    """
    temp_path = None
    cap = None
    try:
        # cv2.VideoCapture needs a real path, so buffer the upload to disk.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp.write(await file.read())
            temp_path = tmp.name

        cap = cv2.VideoCapture(temp_path)

        responses = []
        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            result = analyze_video_frame(frame)
            responses.append({
                "frame": frame_count,
                "status": result["status"],
                "quality_percent": result["quality_percent"]
            })

            frame_count += 1

        return JSONResponse(content={
            "total_frames": frame_count,
            "results": responses
        })

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Release the capture and delete the temp file even when a frame
        # handler raises; the original leaked both on any mid-loop error.
        if cap is not None:
            cap.release()
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
|
|
|
|
@app.post("/api/infer-cervix-bbox")
async def infer_cervix_bbox_endpoint(file: UploadFile = File(...), conf_threshold: float = 0.4):
    """
    Cervix bounding box detection endpoint for annotation.
    Detects cervix location and returns bounding boxes.

    Args:
        file: Image file (jpg, png, etc.)
        conf_threshold: Confidence threshold for YOLO model (0.0-1.0)

    Returns:
        JSON with base64 encoded annotated image and bounding box coordinates

    Raises:
        HTTPException: 400 for an unreadable image, 500 for inference failures.
    """
    try:
        image_data = await file.read()

        try:
            image = Image.open(BytesIO(image_data))
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")

        # Normalize any mode (RGBA, P, L, ...) to RGB before handing to OpenCV.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        result = infer_cervix_bbox(frame, conf_threshold=conf_threshold)

        # Encode the annotated overlay (BGR) as a base64 PNG.
        if result["overlay"] is not None:
            overlay_rgb = cv2.cvtColor(result["overlay"], cv2.COLOR_BGR2RGB)
            buffer = BytesIO()
            Image.fromarray(overlay_rgb).save(buffer, format="PNG")
            image_base64 = base64.b64encode(buffer.getvalue()).decode()
        else:
            image_base64 = None

        return JSONResponse({
            "status": "success",
            "message": "Cervix bounding box detection completed",
            "result_image": image_base64,
            "bounding_boxes": result["bounding_boxes"],
            "detections": result["detections"],
            "frame_width": result["frame_width"],
            "frame_height": result["frame_height"],
            "confidence_threshold": conf_threshold
        })

    except HTTPException:
        # Propagate the intentional 400 above instead of re-wrapping it as a
        # 500 (HTTPException subclasses Exception and was caught below).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during cervix bbox inference: {str(e)}")
|
|
|
|
| |
# Serve the built frontend (if present) from "/" with SPA fallback routing.
# This mount is defined after all API routes above.
frontend_dist = os.path.join(os.path.dirname(__file__), "..", "dist")
if os.path.isdir(frontend_dist):
    app.mount("/", SPAStaticFiles(directory=frontend_dist, html=True), name="frontend")
|
|
|
|
if __name__ == "__main__":
    # Local development entry point; a deployment platform may instead launch
    # uvicorn externally pointing at `app`.
    uvicorn.run(app, host="0.0.0.0", port=8000)