Distopia22's picture
Production-ready Medical Coding API with Phi-3 support
d03f587
# api.py
import re
import json
import gc
import time
from typing import Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from .model_loader import load_model_and_tokenizer
from .prompt_template import PROMPT_TEMPLATE
import logging
# Configure logging: set up the root logger at INFO level when this module is
# imported, so the lifecycle messages logged below are emitted.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
# FastAPI application object; the route, exception-handler and startup
# decorators further down all register on this instance.
# Swagger UI is served at /docs and ReDoc at /redoc.
app = FastAPI(
    title="Medical Coding API",
    description="Extract ICD-10 and CPT codes from clinical notes using AI",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
class NoteRequest(BaseModel):
    # Request body accepted by POST /predict.
    #
    # NOTE: documented with comments rather than a class docstring on purpose:
    # Pydantic copies a model docstring into the generated JSON schema as its
    # "description", which would change the published OpenAPI document.

    # The clinical note to code. Required; length is validated by Pydantic
    # (10 to 50,000 characters) before the endpoint body runs.
    note: str = Field(
        ...,
        min_length=10,
        max_length=50000,
        description="Clinical provider note (10-50,000 characters)"
    )

    # `json_schema_extra` is a Pydantic v2 setting. Under v2 the nested
    # `class Config` form is deprecated and emits a deprecation warning when the
    # class is created; `model_config` is the supported spelling. A plain dict
    # is used (ConfigDict is a TypedDict) so no extra import is needed.
    model_config = {
        "json_schema_extra": {
            "example": {
                "note": "Patient presents with essential hypertension. BP 160/95. Prescribed lisinopril 10mg daily. Office visit for established patient."
            }
        }
    }
class CodingResponse(BaseModel):
    # Response body returned by POST /predict. All five fields are required.
    # (Comments instead of a docstring: a model docstring would be published
    # as the schema description.)

    # JSON object parsed out of the model's generated text.
    result: dict = Field(
        ...,
        description="Extracted ICD-10 and CPT codes",
    )
    # Generated text the JSON object was extracted from.
    raw_output: str = Field(
        ...,
        description="Raw model output",
    )
    # Character count of the stripped note, measured before truncation.
    note_length: int = Field(
        ...,
        description="Length of input note in characters",
    )
    # True when the note exceeded the prompt cap and was cut.
    truncated: bool = Field(
        ...,
        description="Whether note was truncated",
    )
    # Wall-clock seconds spent handling the request.
    processing_time: float = Field(
        ...,
        description="Time taken to process in seconds",
    )
# Global variables for lazy loading: all three stay None until get_model()
# completes its first successful load, and are then reused by every request.
_gen_pipeline = None  # first value returned by load_model_and_tokenizer(); called with the prompt in predict()
_tokenizer = None  # second value returned by load_model_and_tokenizer(); predict() reads its eos_token_id
_model_load_time = None  # seconds the first successful load took, as measured in get_model()
def get_model():
    """Return the cached ``(pipeline, tokenizer)`` pair, loading it on first use.

    The first call invokes ``load_model_and_tokenizer()``, stores both results
    in the module-level cache and records how long the load took in
    ``_model_load_time``. Later calls return the cached objects immediately.
    A failed load leaves the cache empty, so the next call retries.

    Returns:
        tuple: ``(_gen_pipeline, _tokenizer)`` as produced by the loader.

    Raises:
        HTTPException: status 503 when the loader raises. The original
            exception is attached as ``__cause__`` so the root cause survives.
    """
    global _gen_pipeline, _tokenizer, _model_load_time
    if _gen_pipeline is None:
        logger.info("๐Ÿ”„ Loading model for the first time...")
        start_time = time.time()
        try:
            # Only the loader call can realistically fail; keep the try narrow.
            _gen_pipeline, _tokenizer = load_model_and_tokenizer()
        except Exception as e:
            # exc_info=True records the traceback, matching how predict()
            # logs its own unexpected failures.
            logger.error(f"โŒ Failed to load model: {str(e)}", exc_info=True)
            raise HTTPException(
                status_code=503,
                detail=f"Model loading failed: {str(e)}. Please try again in a few moments."
            ) from e
        _model_load_time = time.time() - start_time
        logger.info(f"โœ… Model loaded in {_model_load_time:.2f} seconds")
    return _gen_pipeline, _tokenizer
def extract_json_from_text(text: str) -> Optional[str]:
    """Return the first balanced ``{...}`` object found in *text*.

    Scanning starts at the first ``{`` and counts structural braces until the
    matching ``}``. Braces inside double-quoted JSON strings are ignored, and
    backslash escapes are honoured so an escaped quote does not end a string.
    Without this, a value such as ``"a } b"`` would cut the object short and a
    value such as ``"{"`` would make it impossible to find.

    Args:
        text: Arbitrary text that may contain a JSON object.

    Returns:
        The substring from the first ``{`` through its matching ``}``, or
        ``None`` when there is no ``{`` or the object is never closed. The
        result is not validated as JSON; the caller parses it.
    """
    start_idx = text.find('{')
    if start_idx == -1:
        return None

    depth = 0
    in_string = False  # currently inside a double-quoted string literal
    escaped = False    # previous character was a backslash inside a string

    for idx in range(start_idx, len(text)):
        ch = text[idx]

        if in_string:
            if escaped:
                # Whatever follows a backslash is literal, including '"'.
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
            continue

        if ch == '"':
            in_string = True
        elif ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return text[start_idx:idx + 1]

    return None
def truncate_note(note: str, max_chars: int = 10000) -> str:
    """Cap *note* at *max_chars* characters.

    A note that already fits is returned untouched. A longer one is cut to
    its first *max_chars* characters, and a warning with both lengths is
    logged.
    """
    if len(note) > max_chars:
        logger.warning(f"Note truncated from {len(note)} to {max_chars} characters")
        return note[:max_chars]
    return note
# ===== ENDPOINTS =====
@app.get("/")
async def root():
    """Root endpoint with API information."""
    # The docstring above is published in the OpenAPI docs, so it is kept
    # verbatim. The payload is static: a route catalogue plus a short usage hint.
    endpoint_catalogue = {
        "/predict": "POST - Extract medical codes from clinical note",
        "/health": "GET - Check API health status",
        "/docs": "GET - Interactive API documentation",
        "/metrics": "GET - API usage metrics"
    }
    usage_hint = {
        "endpoint": "/predict",
        "method": "POST",
        "body": {"note": "Your clinical note here (10-50,000 chars)"},
        "max_note_length": "50,000 characters (~10,000 words)"
    }
    api_info = {
        "name": "Medical Coding API",
        "version": "1.0.0",
        "description": "Extract ICD-10 and CPT codes from clinical notes",
        "model": "RayyanAhmed9477/med-coding (Phi-3 based)",
        "endpoints": endpoint_catalogue,
        "usage": usage_hint
    }
    return api_info
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    # The docstring above is published in the OpenAPI docs, so it is kept
    # verbatim.
    #
    # Reports whether the lazily loaded model is in memory and how long the
    # load took. Compare against None explicitly: a measured load time of 0.0
    # is falsy, and a truthiness test would report "not loaded yet" even
    # though "model_loaded" is True.
    if _model_load_time is not None:
        load_time_text = f"{_model_load_time:.2f}s"
    else:
        load_time_text = "not loaded yet"
    return {
        "status": "healthy",
        "model": "RayyanAhmed9477/med-coding",
        "model_loaded": _gen_pipeline is not None,
        "model_load_time": load_time_text
    }
@app.get("/metrics")
async def metrics():
    """Get API usage metrics."""
    # The docstring above is published in the OpenAPI docs, so it is kept
    # verbatim. The snapshot reads the lazy-load cache: whether a pipeline is
    # held, and the recorded load duration (None until the first load).
    is_loaded = _gen_pipeline is not None
    snapshot = {
        "model_loaded": is_loaded,
        "model_load_time_seconds": _model_load_time,
        "status": "operational"
    }
    return snapshot
@app.post("/predict", response_model=CodingResponse)
async def predict(request: NoteRequest):
    """
    Extract ICD-10 and CPT codes from clinical notes.
    **Input:** Clinical note (10-50,000 characters)
    **Output:** JSON with extracted codes:
    - icd10_codes: List of ICD-10 diagnosis codes
    - cpt_codes: List of CPT procedure codes
    **Note:** First request may take 30-60 seconds as model loads into memory.
    Subsequent requests will be faster (2-10 seconds).
    """
    # The docstring above is published in the OpenAPI docs, so it is kept
    # verbatim; maintainer notes live in comments.
    #
    # Flow: strip the note -> get the lazily loaded model -> cap the note ->
    # format the prompt -> generate -> extract the first JSON object -> parse.
    #
    # Returns: CodingResponse with the parsed object, the raw generated text,
    #   the pre-truncation note length, the truncation flag and elapsed seconds.
    # Raises (HTTPException):
    #   400 - note is empty after stripping whitespace
    #   503 - propagated from get_model() when the model cannot be loaded
    #   500 - no JSON object in the output, the JSON does not parse, the parsed
    #         value is not an object, or any other unexpected error
    #
    # NOTE(review): gen_pipeline(...) is called synchronously inside an
    # `async def`, so the event loop is presumably blocked for the whole
    # generation. Confirm whether that is acceptable for this deployment.
    max_note_chars = 10000  # cap on note characters placed into the prompt
    start_time = time.time()
    try:
        # Validate input: Pydantic enforces min_length before stripping, so a
        # whitespace-only note can still reach this point.
        note = request.note.strip()
        if not note:
            raise HTTPException(status_code=400, detail="Empty note provided")

        # Load model (lazy loading)
        logger.info(f"๐Ÿ“ Processing note ({len(note)} characters)")
        gen_pipeline, tokenizer = get_model()

        # Truncate if needed. The flag is derived from what truncate_note
        # actually returned, so it cannot drift from the cap used above.
        original_length = len(note)
        note_truncated = truncate_note(note, max_chars=max_note_chars)
        was_truncated = len(note_truncated) < original_length

        # Build prompt
        prompt = PROMPT_TEMPLATE.format(note=note_truncated)
        logger.info(f"๐Ÿ”ฎ Generating prediction (prompt length: {len(prompt)} chars)")

        # Generate prediction.
        # NOTE(review): do_sample=False together with temperature/top_p looks
        # contradictory; if this is a Hugging Face pipeline the sampling knobs
        # are presumably ignored under greedy decoding. Left unchanged.
        outputs = gen_pipeline(
            prompt,
            max_new_tokens=600,
            do_sample=False,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.1,
            top_p=0.95,
            return_full_text=False
        )

        # Extract generated text; accept a list of dicts, a single dict, or
        # fall back to the string form of whatever was returned.
        if isinstance(outputs, list) and len(outputs) > 0:
            text = outputs[0].get("generated_text", "")
        elif isinstance(outputs, dict):
            text = outputs.get("generated_text", "")
        else:
            text = str(outputs)
        logger.info(f"๐Ÿ“ค Model output length: {len(text)} characters")

        # Remove prompt if present (defensive: return_full_text=False was
        # requested, but the echo is stripped if it shows up anyway).
        if prompt in text:
            text = text.replace(prompt, "").strip()

        # Extract JSON
        json_str = extract_json_from_text(text)
        if json_str is None:
            logger.error(f"No JSON found in output: {text[:500]}")
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No valid JSON found in model output",
                    "raw_output_preview": text[:300],
                    "suggestion": "Model may need fine-tuning or prompt adjustment"
                }
            )

        # Parse JSON
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error(f"JSON parse error: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail={
                    "error": f"Invalid JSON format: {str(e)}",
                    "json_preview": json_str[:300]
                }
            ) from e

        # Validate response structure
        if not isinstance(parsed, dict):
            raise HTTPException(
                status_code=500,
                detail="Model output is not a valid JSON object"
            )

        processing_time = time.time() - start_time
        logger.info(f"โœ… Prediction completed in {processing_time:.2f} seconds")
        return CodingResponse(
            result=parsed,
            raw_output=text,
            note_length=original_length,
            truncated=was_truncated,
            processing_time=round(processing_time, 2)
        )
    except HTTPException:
        # Already carries the intended status code and detail.
        raise
    except Exception as e:
        logger.error(f"โŒ Prediction failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e
    finally:
        # Clean up memory on every exit path, including the HTTPException
        # paths that previously skipped collection.
        gc.collect()
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the error with traceback and answer with HTTP 500.

    The JSON body carries a generic ``detail``, the exception text under
    ``error`` and the requested URL under ``path``.
    """
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    error_body = {
        "detail": "Internal server error",
        "error": str(exc),
        "path": str(request.url)
    }
    return JSONResponse(status_code=500, content=error_body)
# Startup hook: prints a banner only; the model itself is loaded lazily.
@app.on_event("startup")
async def startup_event():
    """Emit the startup banner to the log at INFO level."""
    separator = "=" * 60
    banner_lines = (
        separator,
        "๐Ÿš€ Medical Coding API Starting...",
        separator,
        "โณ Model will be loaded on first /predict request",
        "๐Ÿ“š API Documentation: /docs",
        separator,
    )
    for banner_line in banner_lines:
        logger.info(banner_line)