"""FastAPI application exposing claim prediction and optional SPA hosting."""

import asyncio
import os
from contextlib import asynccontextmanager
from typing import Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from api.inference import predict_claim, load_models, get_available_models
from api.schema import ClaimRequest, PredictionResponse, EvidenceItem
from api.rag import load_rag_data, get_rag_result

load_dotenv()

# Directory holding the built frontend; the SPA route is only registered
# when this directory exists at import time.
STATIC_DIR = "frontend/dist/frontend/browser"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work (model loading, RAG data loading) before serving."""
    print("Loading models...")
    load_models()
    print("Building RAG index...")
    load_rag_data()
    print("Startup complete.")
    yield


app = FastAPI(lifespan=lifespan)

# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed requests to this API. Left unchanged here because
# clients may depend on it; consider restricting allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health():
    """Return a liveness marker together with the available models."""
    return {"status": "ok", "available_models": get_available_models()}


@app.get("/models")
def list_models():
    """Return the available models."""
    return {"available_models": get_available_models()}


@app.post("/predict")
async def predict_claim_endpoint(claim_request: ClaimRequest):
    """Predict a label for a claim and attach evidence and a justification.

    The claim text is stripped of surrounding whitespace first. Prediction and
    the RAG lookup are both run via ``asyncio.to_thread`` so they do not block
    the event loop.

    Raises:
        HTTPException: 400 when the stripped claim is empty; 500 when any
            other exception occurs while processing the claim.
    """
    try:
        claim = claim_request.claim.strip()
        if not claim:
            raise HTTPException(
                status_code=400, detail="Claim text cannot be empty"
            )

        result = await asyncio.to_thread(
            predict_claim, claim, claim_request.model
        )
        predicted_label = result["prediction"]

        evidence_raw, justification = await asyncio.to_thread(
            get_rag_result, claim, predicted_label
        )
        # An empty/None evidence list is reported as None rather than [].
        evidence = (
            [EvidenceItem(**e) for e in evidence_raw] if evidence_raw else None
        )

        return PredictionResponse(
            prediction=predicted_label,
            probabilities=result.get("probabilities"),
            evidence=evidence,
            justification=justification,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) pass through untouched.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing claim: {str(e)}"
        ) from e


def _resolve_static_file(static_dir: str, requested: str) -> Optional[str]:
    """Return the path of ``requested`` inside ``static_dir``, or ``None``.

    ``None`` is returned when ``requested`` is empty, when the normalised path
    would fall outside ``static_dir`` (``..`` segments or an absolute path,
    which ``os.path.join`` would otherwise honour), or when it is not an
    existing regular file.
    """
    if not requested:
        return None
    try:
        root = os.path.abspath(static_dir)
        # abspath normalises ".." lexically, so the containment check below
        # sees the path that will actually be opened.
        candidate = os.path.abspath(os.path.join(root, requested))
        if os.path.commonpath([root, candidate]) != root:
            return None
    except ValueError:
        # Raised for embedded NUL bytes or paths on different drives.
        return None
    return candidate if os.path.isfile(candidate) else None


if os.path.exists(STATIC_DIR):

    @app.get("/{full_path:path}")
    async def serve_spa(full_path: str):
        """Serve a built frontend file, falling back to ``index.html``.

        Any path that does not name an existing file inside ``STATIC_DIR``
        gets ``index.html``.
        """
        file_path = _resolve_static_file(STATIC_DIR, full_path)
        if file_path is not None:
            return FileResponse(file_path)
        return FileResponse(os.path.join(STATIC_DIR, "index.html"))