Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List | |
| import json | |
| import os | |
| import logging | |
| # Import the existing symptom checker logic | |
| from api_symptom_checker import load_artifacts, predict_symptoms_json | |
| import numpy as np | |
def safe_predict_symptoms_json(symptoms, model, label_encoder, feature_names):
    """Predict the top-3 diseases for a list of display-format symptom names.

    Only class indices the label encoder knows about are considered, which
    guards against a model whose predict_proba output has more columns than
    the encoder has classes (a train/serialize mismatch).

    Args:
        symptoms: Symptom names in display format ("Title Case With Spaces").
        model: Classifier exposing ``predict_proba(X)``.
        label_encoder: Fitted encoder with ``classes_`` and ``inverse_transform``.
        feature_names: Ordered feature names in underscore format.

    Returns:
        dict with ``input_symptoms``, ``primary_diagnosis``, ``top_predictions``
        and ``model_confidence``, or ``{"error": ...}`` when nothing can be
        predicted.
    """
    if not symptoms:
        return {"error": "No symptoms provided"}

    # Map display names ("Anxiety And Nervousness") straight to their position
    # in the feature vector. One dict lookup replaces the original redundant
    # membership re-check plus O(n) list.index() call per symptom.
    display_to_idx = {
        name.replace("_", " ").title(): i for i, name in enumerate(feature_names)
    }

    x = np.zeros(len(feature_names))
    matched_symptoms = []
    for symptom in symptoms:
        idx = display_to_idx.get(symptom)
        if idx is not None:
            x[idx] = 1.0
            matched_symptoms.append(symptom)

    if not matched_symptoms:
        return {"error": "No valid symptoms found"}

    proba = model.predict_proba(x.reshape(1, -1))[0]
    # SAFETY: trim to the classes the encoder can decode; any index past
    # len(classes_) would crash inverse_transform.
    valid_proba = proba[: len(label_encoder.classes_)]

    # Top-3 (or fewer, if the model has <3 classes), highest probability first.
    top_idx = np.argsort(valid_proba)[-3:][::-1]
    predictions = []
    for rank, idx in enumerate(top_idx, 1):
        confidence = float(valid_proba[idx])
        predictions.append({
            "rank": rank,
            "disease": label_encoder.inverse_transform([idx])[0],
            "confidence": confidence,
            "confidence_percent": round(confidence * 100, 2),
        })

    top_conf = predictions[0]["confidence"]
    return {
        "input_symptoms": matched_symptoms,
        "primary_diagnosis": predictions[0],
        "top_predictions": predictions,
        "model_confidence": "high" if top_conf > 0.7 else "medium" if top_conf > 0.4 else "low",
    }
# Module-level logger for the whole service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance; routes and the startup handler attach to it.
app = FastAPI(
    title="Symptom Checker API",
    description="AI-powered symptom analysis service",
    version="1.0.0"
)

# Allow browser clients (e.g. the Flutter web app) from any origin.
# NOTE(review): browsers reject credentialed requests when allow_origins is
# "*" combined with allow_credentials=True — pin the origin list before
# deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this properly for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model artifacts; None until the startup handler populates them, and every
# endpoint answers 503 while they are None.
model = None
label_encoder = None
feature_names = None
| # Pydantic models for request/response | |
class SymptomRequest(BaseModel):
    """Request body for /api/check-symptoms: display-format symptom names."""
    symptoms: List[str]
class PredictionItem(BaseModel):
    """One ranked disease prediction."""
    rank: int  # 1 = most likely
    disease: str
    confidence: float  # raw probability in [0, 1]
    confidence_percent: float  # confidence * 100, rounded to 2 decimals
class SymptomResponse(BaseModel):
    """Structured prediction response.

    NOTE(review): not referenced by any endpoint visible in this file — the
    handlers return plain dicts shaped for the Flutter client; confirm whether
    this model is still needed.
    """
    input_symptoms: List[str]
    primary_diagnosis: PredictionItem
    top_predictions: List[PredictionItem]
    # NOTE(review): "model_" is a protected namespace in Pydantic v2 and will
    # emit a warning for this field name — confirm the installed pydantic version.
    model_confidence: str
class AvailableSymptomsResponse(BaseModel):
    """Response for /api/symptoms."""
    success: bool = True
    symptoms: List[str]  # sorted, display-format names
    total_symptoms: int
@app.on_event("startup")
async def startup_event():
    """Load model artifacts once when the server starts.

    BUG FIX: the handler was defined but never registered with FastAPI, so
    the globals stayed None and every endpoint answered 503. Registering it
    via @app.on_event("startup") makes the load actually run.

    Raises:
        Exception: re-raised load failure so startup aborts loudly instead
        of serving a permanently-broken API.
    """
    global model, label_encoder, feature_names
    try:
        logger.info("Loading symptom checker model artifacts...")
        model, label_encoder, feature_names = load_artifacts("symptom_model")
        logger.info(f"Model loaded successfully with {len(feature_names)} features")
    except Exception as e:
        logger.error(f"Failed to load model artifacts: {e}")
        raise e
@app.get("/")
async def root():
    """Root endpoint: identifies the service and lists its routes.

    BUG FIX: the handler was defined but never registered as a route; the
    @app.get("/") decorator wires it up.
    """
    return {
        "message": "Symptom Checker API",
        "version": "1.0.0",
        "endpoints": ["/health", "/api/symptoms", "/api/check-symptoms"]
    }
@app.get("/health")
async def health_check():
    """Liveness/readiness probe; answers 503 until the model is loaded.

    BUG FIX: the handler was defined but never registered as a route; the
    @app.get("/health") decorator wires it up.

    Raises:
        HTTPException: 503 when the model has not been loaded yet.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {
        "status": "healthy",
        "service": "symptom-checker",
        "model_loaded": model is not None,
        "features_count": len(feature_names) if feature_names else 0
    }
@app.get("/api/symptoms", response_model=AvailableSymptomsResponse)
async def get_available_symptoms():
    """Return every symptom name the model recognizes, in display format.

    BUG FIX: the handler was defined but never registered as a route; the
    @app.get(...) decorator wires it up.

    Raises:
        HTTPException: 503 when the model has not been loaded yet.
    """
    if feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Convert "anxiety_and_nervousness" -> "Anxiety And Nervousness".
    clean_symptoms = [symptom.replace('_', ' ').title() for symptom in feature_names]
    return AvailableSymptomsResponse(
        success=True,
        symptoms=sorted(clean_symptoms),
        total_symptoms=len(clean_symptoms)
    )
@app.post("/api/check-symptoms")
async def check_symptoms(request: SymptomRequest):
    """Analyze symptoms and return disease predictions.

    The response is a plain dict shaped to match the Flutter client's
    SymptomCheckResponse.

    BUG FIXES:
    - The handler was never registered as a route; @app.post(...) wires it up.
    - The broad ``except Exception`` used to catch the deliberate 400
      HTTPException raised inside the try block and re-wrap it as a 500;
      HTTPExceptions are now re-raised untouched.
    - Removed dead code: a ``feature_symptoms`` list and a ``predictions``
      list of PredictionItem were built and never used.

    Raises:
        HTTPException: 503 (model not loaded), 400 (no/invalid symptoms),
        500 (unexpected prediction failure).
    """
    if model is None or label_encoder is None or feature_names is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not request.symptoms:
        raise HTTPException(status_code=400, detail="No symptoms provided")
    try:
        # safe_predict_symptoms_json expects display-format names and only
        # uses classes the label encoder knows about.
        result = safe_predict_symptoms_json(request.symptoms, model, label_encoder, feature_names)
        if "error" in result:
            raise HTTPException(status_code=400, detail=result["error"])
        # Return format that matches Flutter's SymptomCheckResponse expectations.
        return {
            "success": True,
            "predictions": [
                {
                    "rank": pred["rank"],
                    "disease": pred["disease"],
                    "confidence": pred["confidence"],
                    "confidence_percent": f"{pred['confidence_percent']:.2f}%"
                }
                for pred in result["top_predictions"]
            ],
            "input_symptoms": request.symptoms,
            "primary_diagnosis": result["primary_diagnosis"]["disease"],
            "model_confidence": result["model_confidence"]
        }
    except HTTPException:
        # Let deliberate 4xx responses pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error during symptom prediction: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
if __name__ == "__main__":
    import uvicorn

    # 7860 is the port Hugging Face Spaces expects; PORT env var overrides it.
    # NOTE(review): the old comment claimed an 8002 local fallback, but the
    # code default has always been 7860 — confirm which was intended.
    # (Removed a duplicate `import os`; os is already imported at module top.)
    port = int(os.getenv("PORT", 7860))
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)