# NOTE: removed non-Python page-scrape residue ("Spaces: / Sleeping / Sleeping")
# that preceded the module docstring and would have broken parsing.
"""
main.py — PsyPredict FastAPI Application (Production)

Replaces Flask. Key features:
- Async request handling (FastAPI + Uvicorn)
- CORS middleware
- Rate limiting (SlowAPI)
- Structured logging (Python logging)
- Startup model pre-warming
- Graceful shutdown (Ollama client cleanup)
- FastAPI auto docs at /docs (Swagger) and /redoc
"""
from __future__ import annotations

import logging
import sys
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from app.config import get_settings
from app.api.endpoints.facial import router as facial_router
from app.api.endpoints.remedies import router as remedies_router
from app.api.endpoints.therapist import router as therapist_router
from app.api.endpoints.analysis import router as analysis_router
# ---------------------------------------------------------------------------
# Module-level singletons: settings, logging, rate limiter
# ---------------------------------------------------------------------------
settings = get_settings()

# Log to stdout so container/PaaS platforms capture output; unknown
# LOG_LEVEL values fall back to INFO via getattr's default.
logging.basicConfig(
    level=getattr(logging, settings.LOG_LEVEL, logging.INFO),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# SlowAPI limiter keyed on client IP; the default limit string (e.g.
# "60/minute") comes from configuration.
limiter = Limiter(key_func=get_remote_address, default_limits=[settings.RATE_LIMIT])
| # --------------------------------------------------------------------------- | |
| # Lifespan (startup / shutdown events) | |
| # --------------------------------------------------------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan (startup / shutdown hooks).

    Startup: pre-warm ML models (DistilBERT text-emotion model and the
    crisis zero-shot classifier) in background threads, then probe Ollama
    reachability and log a warning if it is down (non-fatal).
    Shutdown: close the Ollama async client.

    Bug fix: the @asynccontextmanager decorator was missing. FastAPI's
    ``lifespan=`` parameter requires an async context manager; a bare async
    generator function raises at startup.
    """
    logger.info("βββββββββββββββββββββββββββββββββββββββ")
    logger.info("π PsyPredict v2.0 β Production Backend")
    logger.info("βββββββββββββββββββββββββββββββββββββββ")
    logger.info("Config: Ollama=%s model=%s", settings.OLLAMA_BASE_URL, settings.OLLAMA_MODEL)
    import asyncio as _asyncio

    # Bug fix: keep strong references to the warm-up tasks. The event loop
    # holds only weak references to tasks, so unreferenced fire-and-forget
    # tasks may be garbage-collected mid-flight. This list lives in the
    # generator frame, which stays alive until shutdown.
    _warmup_tasks: list = []

    # Pre-warm DistilBERT text emotion model (in background).
    logger.info("Initializing DistilBERT text emotion model (background)...")
    from app.services.text_emotion_engine import initialize as init_text
    _warmup_tasks.append(
        _asyncio.create_task(_asyncio.to_thread(init_text, settings.DISTILBERT_MODEL))
    )

    # Pre-warm Crisis zero-shot classifier (in background).
    logger.info("Initializing crisis detection classifier (background)...")
    from app.services.crisis_engine import initialize_crisis_classifier
    _warmup_tasks.append(
        _asyncio.create_task(_asyncio.to_thread(initialize_crisis_classifier))
    )

    # Check Ollama availability (warn only — chat degrades to fallbacks).
    from app.services.ollama_engine import ollama_engine
    reachable = await ollama_engine.is_reachable()
    if reachable:
        logger.info("β Ollama reachable at %s (model: %s)", settings.OLLAMA_BASE_URL, settings.OLLAMA_MODEL)
    else:
        logger.warning(
            "β οΈ Ollama NOT reachable at %s β chat will return fallback responses. "
            "Run: ollama serve && ollama pull %s",
            settings.OLLAMA_BASE_URL,
            settings.OLLAMA_MODEL,
        )
    # NOTE(review): port 7860 is hard-coded in these messages; keep in sync
    # with the uvicorn.run() call at the bottom of this module.
    logger.info("β Startup complete. Listening on port 7860.")
    logger.info(" Docs: http://localhost:7860/docs")
    logger.info("βββββββββββββββββββββββββββββββββββββββ")

    yield  # ββ Application Running ββ

    logger.info("Shutting down PsyPredict backend...")
    await ollama_engine.close()
    logger.info("Goodbye.")
| # --------------------------------------------------------------------------- | |
| # FastAPI App | |
| # --------------------------------------------------------------------------- | |
def create_app() -> FastAPI:
    """
    Build and configure the FastAPI application.

    Wires up: SlowAPI rate limiting, CORS middleware, a catch-all
    exception handler, and the four API routers under the ``/api`` prefix.

    Returns:
        A fully configured FastAPI instance.
    """
    app = FastAPI(
        title="PsyPredict API",
        description=(
            "Production-grade multimodal mental health AI system. "
            "Powered by Llama3 (Ollama) + DistilBERT + Keras CNN facial emotion model."
        ),
        version="2.0.0",
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
    )

    # -- Rate Limiter --------------------------------------------------------
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    # -- CORS ----------------------------------------------------------------
    # NOTE(review): the CORS spec forbids a wildcard origin together with
    # credentials; Starlette will not emit Access-Control-Allow-Credentials
    # for "*". Pin explicit origins in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Tighten to specific origin in production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # -- Global Exception Handler --------------------------------------------
    # Bug fix: this handler was defined but never registered, so unhandled
    # exceptions bypassed it and hit Starlette's default 500 response.
    # Registering via the decorator makes it the catch-all for Exception.
    @app.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        logger.error("Unhandled exception: %s | path=%s", exc, request.url.path)
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error. Please try again."},
        )

    # -- Routers --------------------------------------------------------------
    app.include_router(facial_router, prefix="/api", tags=["Facial Emotion"])
    app.include_router(remedies_router, prefix="/api", tags=["Remedies"])
    app.include_router(therapist_router, prefix="/api", tags=["AI Therapist"])
    app.include_router(analysis_router, prefix="/api", tags=["Text Analysis & Health"])

    return app
# Module-level app instance for ASGI servers ("app.main:app").
app = create_app()


# ---------------------------------------------------------------------------
# Entry point (direct execution)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    # Single worker on purpose: the ML models are in-process singletons
    # and must not be duplicated across worker processes.
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
        log_level=settings.LOG_LEVEL.lower(),
        workers=1,
    )