# Tantawi's picture
# Upload 13 files
# f2a4578 verified
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import json
import os
import logging
# Import the existing symptom checker logic
from api_symptom_checker import load_artifacts, predict_symptoms_json
import numpy as np
def _to_display_name(name):
    """Return the display form of a symptom/feature name, e.g. "skin_rash" -> "Skin Rash"."""
    return name.replace("_", " ").title()


def safe_predict_symptoms_json(symptoms, model, label_encoder, feature_names):
    """Safe prediction that only uses diseases the label encoder knows about.

    Each entry of ``symptoms`` is matched against ``feature_names``. The display
    form of a feature (``"skin_rash"`` -> ``"Skin Rash"``) is matched exactly
    first; otherwise the entry is normalised the same way (surrounding
    whitespace stripped, underscores turned into spaces, title-cased), so
    ``"skin_rash"`` or ``"skin rash"`` also match. Non-string entries are ignored.

    Args:
        symptoms: Symptom names supplied by the client.
        model: Classifier exposing ``predict_proba``.
        label_encoder: Encoder exposing ``classes_`` and ``inverse_transform``.
        feature_names: Ordered feature names the model's input vector uses.

    Returns:
        On success, a dict with ``input_symptoms`` (display names of the matched
        symptoms), ``primary_diagnosis`` (the rank-1 prediction),
        ``top_predictions`` (up to three dicts with ``rank``, ``disease``,
        ``confidence`` and ``confidence_percent``, best first) and
        ``model_confidence`` (``"high"`` if the top confidence > 0.7,
        ``"medium"`` if > 0.4, else ``"low"``). If no symptoms were given or
        none matched, a dict with a single ``"error"`` key.

    Raises:
        ValueError: if the label encoder has no classes, so nothing can be decoded.
    """
    if not symptoms:
        return {"error": "No symptoms provided"}

    # Display name -> column index, built once so each lookup is O(1) instead of
    # an O(n) ``feature_names.index`` scan per symptom.
    display_to_index = {
        _to_display_name(name): idx for idx, name in enumerate(feature_names)
    }

    x = np.zeros(len(feature_names))
    matched_symptoms = []
    for symptom in symptoms:
        if not isinstance(symptom, str):
            continue
        # Exact display-name match first (the form /api/symptoms serves);
        # otherwise normalise so raw or differently-cased names also match.
        if symptom in display_to_index:
            display = symptom
        else:
            display = _to_display_name(symptom.strip())
        idx = display_to_index.get(display)
        if idx is None:
            continue
        x[idx] = 1.0
        matched_symptoms.append(display)

    if not matched_symptoms:
        return {"error": "No valid symptoms found"}

    proba = model.predict_proba(x.reshape(1, -1))[0]
    # SAFETY: the model may emit more probability columns than the label encoder
    # has classes; only the first len(label_encoder.classes_) can be decoded.
    # NOTE(review): assumes column i is encoded label i — confirm against model.classes_.
    valid_proba = proba[:len(label_encoder.classes_)]
    if len(valid_proba) == 0:
        raise ValueError("Label encoder has no classes to decode predictions into")

    # Indices of the (up to) three most probable valid classes, best first.
    top3_idx = np.argsort(valid_proba)[-3:][::-1]
    predictions = []
    for rank, idx in enumerate(top3_idx, 1):
        confidence = float(valid_proba[idx])
        predictions.append({
            "rank": rank,
            "disease": label_encoder.inverse_transform([idx])[0],
            "confidence": confidence,
            "confidence_percent": round(confidence * 100, 2),
        })

    top_confidence = predictions[0]["confidence"]
    if top_confidence > 0.7:
        model_confidence = "high"
    elif top_confidence > 0.4:
        model_confidence = "medium"
    else:
        model_confidence = "low"

    return {
        "input_symptoms": matched_symptoms,
        "primary_diagnosis": predictions[0],
        "top_predictions": predictions,
        "model_confidence": model_confidence,
    }
# Module-level logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
# The FastAPI application, with its title/description/version metadata.
app = FastAPI(
    version="1.0.0",
    title="Symptom Checker API",
    description="AI-powered symptom analysis service",
)
# CORS: allowed origins come from the comma-separated CORS_ALLOW_ORIGINS
# environment variable. When it is unset, every origin is allowed ("*"), which
# matches the previous hard-coded setting; set it explicitly in production.
# An explicitly empty value allows no cross-origin callers.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # NOTE(review): credentials combined with a wildcard origin is permissive;
    # this API shows no cookie/auth use, so consider False once clients are confirmed.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model artifacts; None until the startup hook assigns them.
model = label_encoder = feature_names = None
# Pydantic models for request/response
class SymptomRequest(BaseModel):
    # Request body for POST /api/check-symptoms: the symptom names sent by the
    # client (expected in the display form served by /api/symptoms).
    symptoms: List[str]
class PredictionItem(BaseModel):
    # One ranked disease prediction.
    rank: int  # 1 = most likely
    disease: str  # label decoded by the label encoder
    confidence: float  # predicted probability for this disease
    confidence_percent: float  # confidence * 100, rounded to 2 decimals
class SymptomResponse(BaseModel):
    # Structured prediction response.
    # NOTE(review): not used as a response_model by the endpoints visible here;
    # /api/check-symptoms returns a plain dict of a different shape. Confirm
    # whether this model is still needed.
    input_symptoms: List[str]
    primary_diagnosis: PredictionItem
    top_predictions: List[PredictionItem]
    model_confidence: str  # "high", "medium" or "low"
class AvailableSymptomsResponse(BaseModel):
    # Response body for GET /api/symptoms.
    success: bool = True
    symptoms: List[str]  # display-form symptom names (e.g. "Skin Rash"), sorted
    total_symptoms: int  # number of entries in `symptoms`
@app.on_event("startup")
async def startup_event():
    """Load model artifacts on startup.

    Populates the module-level ``model``, ``label_encoder`` and ``feature_names``
    from ``load_artifacts("symptom_model")``. On failure, the error is logged
    with its traceback and the original exception is re-raised, so the startup
    hook fails rather than leaving the service running without a model.
    """
    global model, label_encoder, feature_names
    try:
        logger.info("Loading symptom checker model artifacts...")
        model, label_encoder, feature_names = load_artifacts("symptom_model")
        logger.info("Model loaded successfully with %d features", len(feature_names))
    except Exception as e:
        # logger.exception records the traceback; a bare ``raise`` re-raises the
        # original exception without adding this frame to it.
        logger.exception("Failed to load model artifacts: %s", e)
        raise
@app.get("/")
async def root():
    """Describe the service: its name, version and the endpoints it exposes."""
    endpoints = ["/health", "/api/symptoms", "/api/check-symptoms"]
    info = {"message": "Symptom Checker API", "version": "1.0.0"}
    info["endpoints"] = endpoints
    return info
@app.get("/health")
async def health_check():
    """Report service health; responds 503 until the model has been loaded."""
    model_ready = model is not None
    if not model_ready:
        raise HTTPException(status_code=503, detail="Model not loaded")
    feature_count = len(feature_names) if feature_names else 0
    return {
        "status": "healthy",
        "service": "symptom-checker",
        "model_loaded": model_ready,
        "features_count": feature_count,
    }
@app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
async def get_available_symptoms():
    """List every symptom the model recognises, in sorted display form.

    Responds 503 while the feature names are not loaded.
    """
    if feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Feature format -> display format, e.g. "skin_rash" -> "Skin Rash".
    display_names = sorted(name.replace('_', ' ').title() for name in feature_names)
    return AvailableSymptomsResponse(
        success=True,
        symptoms=display_names,
        total_symptoms=len(display_names),
    )
@app.post("/api/check-symptoms")
async def check_symptoms(request: SymptomRequest):
    """Analyze symptoms and return disease predictions.

    Responds with:
      * 503 if the model artifacts are not loaded,
      * 400 if no symptoms were sent or none of them were recognised,
      * 500 if prediction fails unexpectedly.

    The success payload follows the Flutter client's SymptomCheckResponse
    format: ``success``, ``predictions`` (with ``confidence_percent`` formatted
    like ``"61.20%"``), ``input_symptoms`` (echoed exactly as sent),
    ``primary_diagnosis`` (the top disease name) and ``model_confidence``.
    """
    # Module globals are only read here, so no ``global`` statement is needed.
    if model is None or label_encoder is None or feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not request.symptoms:
        raise HTTPException(status_code=400, detail="No symptoms provided")
    try:
        # safe_predict_symptoms_json maps display names to features itself and
        # guards against model/label-encoder class-count mismatches.
        result = safe_predict_symptoms_json(request.symptoms, model, label_encoder, feature_names)
        if "error" in result:
            raise HTTPException(status_code=400, detail=result["error"])
        return {
            "success": True,
            "predictions": [
                {
                    "rank": pred["rank"],
                    "disease": pred["disease"],
                    "confidence": pred["confidence"],
                    "confidence_percent": f"{pred['confidence_percent']:.2f}%",
                }
                for pred in result["top_predictions"]
            ],
            "input_symptoms": request.symptoms,
            "primary_diagnosis": result["primary_diagnosis"]["disease"],
            "model_confidence": result["model_confidence"],
        }
    except HTTPException:
        # Deliberate client errors (the 400 above) must pass through unchanged,
        # not be rewrapped by the generic handler below as a 500.
        raise
    except Exception as e:
        logger.exception("Error during symptom prediction: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
if __name__ == "__main__":
    import uvicorn

    # Port 7860 is what Hugging Face Spaces expects; override it with the PORT
    # environment variable (e.g. PORT=8002 for local development).
    port = int(os.getenv("PORT", 7860))
    # Pass the app object rather than the "main:app" import string. The server
    # then works whatever this file is named, and the module is not imported a
    # second time under another name (reload is off, so no string is required).
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)