Spaces:
Sleeping
Sleeping
# api.py
import asyncio
import gc
import json
import logging
import re
import threading
import time
from typing import Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from .model_loader import load_model_and_tokenizer
from .prompt_template import PROMPT_TEMPLATE
# Configure logging: root logger at INFO, plus a module-level logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Metadata and documentation URLs for the FastAPI application object.
_APP_SETTINGS = {
    "title": "Medical Coding API",
    "description": "Extract ICD-10 and CPT codes from clinical notes using AI",
    "version": "1.0.0",
    "docs_url": "/docs",
    "redoc_url": "/redoc",
}
app = FastAPI(**_APP_SETTINGS)
class NoteRequest(BaseModel):
    # Request body for /predict: a single clinical note as free text.
    # The length bounds below are validated by pydantic when the body is parsed.
    note: str = Field(
        ...,
        min_length=10,
        max_length=50000,
        description="Clinical provider note (10-50,000 characters)"
    )

    # Pydantic v2 model configuration. This replaces the inner `class Config`,
    # which v2 deprecates. The example is surfaced in the generated JSON schema
    # shown by the interactive docs.
    model_config = {
        "json_schema_extra": {
            "example": {
                "note": "Patient presents with essential hypertension. BP 160/95. Prescribed lisinopril 10mg daily. Office visit for established patient."
            }
        }
    }
class CodingResponse(BaseModel):
    # Response body returned by /predict. Every field is required; a Field()
    # with no default is required, exactly like Field(...).
    result: dict = Field(description="Extracted ICD-10 and CPT codes")
    raw_output: str = Field(description="Raw model output")
    note_length: int = Field(description="Length of input note in characters")
    truncated: bool = Field(description="Whether note was truncated")
    processing_time: float = Field(description="Time taken to process in seconds")
# Global variables for lazy loading; populated by get_model() on first use.
_gen_pipeline = None
_tokenizer = None
_model_load_time = None  # seconds spent loading, or None until loaded


def get_model():
    """Lazy load model on first request with error handling.

    Returns:
        The ``(pipeline, tokenizer)`` pair produced by
        ``load_model_and_tokenizer()``, cached in module globals after the
        first successful load.

    Raises:
        HTTPException: 503 if loading fails. The globals stay unset, so the
            next call tries to load again.

    Note:
        There is no lock here; if calls can overlap, the caller must
        serialize them or the model may be loaded more than once.
    """
    global _gen_pipeline, _tokenizer, _model_load_time
    if _gen_pipeline is None:
        logger.info("Loading model for the first time...")
        # perf_counter is monotonic, so it is the right clock for a duration.
        start_time = time.perf_counter()
        try:
            _gen_pipeline, _tokenizer = load_model_and_tokenizer()
        except Exception as e:
            # logger.exception keeps the traceback of the underlying failure.
            logger.exception(f"Failed to load model: {str(e)}")
            raise HTTPException(
                status_code=503,
                detail=f"Model loading failed: {str(e)}. Please try again in a few moments."
            ) from e
        _model_load_time = time.perf_counter() - start_time
        logger.info(f"Model loaded in {_model_load_time:.2f} seconds")
    return _gen_pipeline, _tokenizer
def extract_json_from_text(text: str) -> Optional[str]:
    """Extract the first balanced ``{...}`` span from *text*.

    Scanning starts at the first ``{``. Braces are only counted outside
    double-quoted JSON strings, and backslash escapes inside strings are
    honoured. That way values such as ``"}"`` or ``"\\"{"`` neither close the
    object early nor keep it open forever. The returned span is not
    validated as JSON; callers still need to parse it.

    Args:
        text: Free-form model output that may embed a JSON object.

    Returns:
        The substring from the first ``{`` through its matching ``}``, or
        ``None`` if *text* has no ``{`` or the object never closes.
    """
    start_idx = text.find("{")
    if start_idx == -1:
        return None

    depth = 0
    in_string = False
    escaped = False
    for i, ch in enumerate(text[start_idx:], start=start_idx):
        if in_string:
            # Inside a string literal only escapes and the closing quote matter.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start_idx:i + 1]
    return None
def truncate_note(note: str, max_chars: int = 10000) -> str:
    """Return *note* cut down to at most *max_chars* characters.

    A warning with the original and capped lengths is logged whenever
    characters are dropped. Shorter notes come back unchanged.
    """
    if len(note) > max_chars:
        logger.warning(f"Note truncated from {len(note)} to {max_chars} characters")
        note = note[:max_chars]
    return note
# ===== ENDPOINTS =====
# Each handler must be registered on `app`; an undecorated function is never
# routed. Paths match the endpoint table advertised by root().


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "name": "Medical Coding API",
        "version": "1.0.0",
        "description": "Extract ICD-10 and CPT codes from clinical notes",
        "model": "RayyanAhmed9477/med-coding (Phi-3 based)",
        "endpoints": {
            "/predict": "POST - Extract medical codes from clinical note",
            "/health": "GET - Check API health status",
            "/docs": "GET - Interactive API documentation",
            "/metrics": "GET - API usage metrics"
        },
        "usage": {
            "endpoint": "/predict",
            "method": "POST",
            "body": {"note": "Your clinical note here (10-50,000 chars)"},
            "max_note_length": "50,000 characters (~10,000 words)"
        }
    }
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    # Reads the lazy-load globals only; this never triggers a model load.
    return {
        "status": "healthy",
        "model": "RayyanAhmed9477/med-coding",
        "model_loaded": _gen_pipeline is not None,
        # Compare against None so a (theoretical) 0.0s load time is still reported.
        "model_load_time": (
            f"{_model_load_time:.2f}s" if _model_load_time is not None else "not loaded yet"
        )
    }
@app.get("/metrics")
async def metrics():
    """Get API usage metrics."""
    # Reports lazy-load state only. `model_load_time_seconds` is None until the
    # model has been loaded by get_model().
    return {
        "model_loaded": _gen_pipeline is not None,
        "model_load_time_seconds": _model_load_time,
        "status": "operational"
    }
# Maximum number of note characters inserted into the prompt.
_PROMPT_NOTE_MAX_CHARS = 10000

# Serializes model loading and generation. Both run on executor threads (see
# predict) so the event loop stays responsive. get_model() has no lock of its
# own, and the original handler used the shared pipeline one request at a
# time on the event-loop thread; this lock preserves both properties.
_inference_lock = threading.Lock()


def _extract_generated_text(outputs) -> str:
    """Return the generated string from a text-generation pipeline result.

    Handles a list of result dicts (first element used), a single dict, or
    anything else (stringified).
    """
    if isinstance(outputs, list) and outputs:
        return outputs[0].get("generated_text", "")
    if isinstance(outputs, dict):
        return outputs.get("generated_text", "")
    return str(outputs)


def _generate_text(prompt: str) -> str:
    """Blocking: load the model if needed, generate for *prompt*, return text.

    Intended to run on a worker thread.

    Raises:
        HTTPException: 503 propagated from get_model() when loading fails.
    """
    with _inference_lock:
        try:
            gen_pipeline, tokenizer = get_model()
            logger.info(f"Generating prediction (prompt length: {len(prompt)} chars)")
            # Greedy decoding (do_sample=False). temperature/top_p only affect
            # sampling, so they are not passed; transformers warns when they are
            # combined with do_sample=False.
            outputs = gen_pipeline(
                prompt,
                max_new_tokens=600,
                do_sample=False,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                return_full_text=False
            )
        finally:
            # Clean up memory after each generation attempt.
            gc.collect()
    return _extract_generated_text(outputs)


def _parse_model_json(text: str) -> dict:
    """Parse the first JSON object embedded in *text*.

    Raises:
        HTTPException: 500 when no object is found, it is not valid JSON, or
            it is not a JSON object.
    """
    json_str = extract_json_from_text(text)
    if json_str is None:
        logger.error(f"No JSON found in output: {text[:500]}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "No valid JSON found in model output",
                "raw_output_preview": text[:300],
                "suggestion": "Model may need fine-tuning or prompt adjustment"
            }
        )
    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError as e:
        logger.error(f"JSON parse error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": f"Invalid JSON format: {str(e)}",
                "json_preview": json_str[:300]
            }
        ) from e
    if not isinstance(parsed, dict):
        raise HTTPException(
            status_code=500,
            detail="Model output is not a valid JSON object"
        )
    return parsed


@app.post("/predict", response_model=CodingResponse)
async def predict(request: NoteRequest):
    """
    Extract ICD-10 and CPT codes from clinical notes.

    **Input:** Clinical note (10-50,000 characters)

    **Output:** JSON with extracted codes:
    - icd10_codes: List of ICD-10 diagnosis codes
    - cpt_codes: List of CPT procedure codes

    **Note:** First request may take 30-60 seconds as model loads into memory.
    Subsequent requests will be faster (2-10 seconds).
    """
    # (The docstring above is rendered by FastAPI as this endpoint's OpenAPI
    # description, so implementation notes live in comments.)
    start_time = time.perf_counter()
    try:
        # Validate input: min_length is checked before strip(), so a
        # whitespace-only note can still reach this point.
        note = request.note.strip()
        if not note:
            raise HTTPException(status_code=400, detail="Empty note provided")

        logger.info(f"Processing note ({len(note)} characters)")
        original_length = len(note)
        note_truncated = truncate_note(note, max_chars=_PROMPT_NOTE_MAX_CHARS)
        prompt = PROMPT_TEMPLATE.format(note=note_truncated)

        # Model loading and generation block for seconds. Running them on the
        # default executor keeps this coroutine from stalling the event loop
        # (and with it /health and every other request).
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _generate_text, prompt)
        logger.info(f"Model output length: {len(text)} characters")

        # Defensive: return_full_text=False should already exclude the prompt.
        if prompt in text:
            text = text.replace(prompt, "").strip()

        parsed = _parse_model_json(text)

        processing_time = time.perf_counter() - start_time
        logger.info(f"Prediction completed in {processing_time:.2f} seconds")
        return CodingResponse(
            result=parsed,
            raw_output=text,
            note_length=original_length,
            truncated=len(note_truncated) < original_length,
            processing_time=round(processing_time, 2)
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Prediction failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler for unhandled errors.

    Logs the exception with its traceback and returns a JSON 500 response
    containing the error text and the request URL.
    """
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    # NOTE(review): str(exc) is sent to the client and may expose internal
    # details; consider returning only a generic message.
    return JSONResponse(
        status_code=500,
        content={
            "detail": "Internal server error",
            "error": str(exc),
            "path": str(request.url)
        }
    )
# Startup event. Newer FastAPI releases prefer a `lifespan` handler, but
# `on_event("startup")` is still supported.
@app.on_event("startup")
async def startup_event():
    """Log startup information; the model itself is loaded lazily by get_model()."""
    logger.info("=" * 60)
    logger.info("Medical Coding API Starting...")
    logger.info("=" * 60)
    logger.info("Model will be loaded on first /predict request")
    logger.info("API Documentation: /docs")
    logger.info("=" * 60)