Triple-R-Checker / api /main.py
=Apyhtml20
Add RAG module: PostgreSQL + pgvector in same container
45fa780
Raw
History Blame Contribute Delete
2.44 kB
import asyncio
import logging
import os
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from api.inference import predict_claim, load_models, get_available_models
from api.schema import ClaimRequest, PredictionResponse, EvidenceItem
from api.rag import load_rag_data, get_rag_result
# Load variables from a .env file into os.environ.
# NOTE(review): this runs after the api.* imports above, so any module-level
# os.getenv() in those modules will not see .env values — confirm that is
# intended.
load_dotenv()

# Directory containing the built frontend single-page app (index.html plus
# assets). It can be overridden with FRONTEND_STATIC_DIR in the environment or
# .env, because a relative default only works when the server is launched from
# the project root. An empty override falls back to the default.
STATIC_DIR = os.getenv("FRONTEND_STATIC_DIR") or "frontend/dist/frontend/browser"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: run the startup loaders, then hand off.

    Loads the inference models and then builds the RAG index before the app
    begins handling requests. The loaders are invoked synchronously, so
    startup waits for both. Nothing runs after the ``yield``, so there is no
    shutdown work.
    """
    startup_steps = (
        ("Loading models...", load_models),
        ("Building RAG index...", load_rag_data),
    )
    for banner, loader in startup_steps:
        print(banner)
        loader()
    print("Startup complete.")
    yield
app = FastAPI(lifespan=lifespan)

# Comma-separated list of allowed CORS origins. The default "*" (any origin)
# preserves the previous behaviour; set CORS_ALLOW_ORIGINS (environment or
# .env) to restrict it, e.g. "https://app.example.com,http://localhost:4200".
_cors_origins = [
    origin.strip()
    for origin in (os.getenv("CORS_ALLOW_ORIGINS") or "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # With allow_origins=["*"], allow_credentials=True makes Starlette echo the
    # caller's Origin back with Access-Control-Allow-Credentials: true on
    # preflights and cookie-bearing requests. That grants credentialed access
    # to every site, so credentials are enabled only for an explicit allow-list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
def health():
    """Health-check endpoint.

    Returns ``{"status": "ok"}`` together with whatever
    ``get_available_models()`` reports under ``"available_models"``.
    """
    payload = {"status": "ok"}
    payload["available_models"] = get_available_models()
    return payload
@app.get("/models")
def list_models():
    """Return the result of ``get_available_models()`` under ``"available_models"``."""
    models = get_available_models()
    return {"available_models": models}
@app.post("/predict")
async def predict_claim_endpoint(claim_request: ClaimRequest):
    """Classify a claim and attach retrieved evidence and a justification.

    Steps:
      1. Strip the claim text and reject it if nothing is left.
      2. Run ``predict_claim(claim, claim_request.model)`` to get a label.
      3. Run ``get_rag_result(claim, label)`` to get evidence dicts and a
         justification.

    Both calls run via ``asyncio.to_thread`` so their (presumably blocking)
    work does not stall the event loop.

    Returns:
        PredictionResponse with the predicted label, the probabilities from
        the prediction result (if present), the evidence items (``None`` when
        RAG returned none), and the justification.

    Raises:
        HTTPException: 400 if the claim is empty or whitespace-only; any
            HTTPException raised downstream is propagated unchanged; 500 for
            any other error.
    """
    try:
        claim = claim_request.claim.strip()
        if not claim:
            raise HTTPException(status_code=400, detail="Claim text cannot be empty")

        result = await asyncio.to_thread(predict_claim, claim, claim_request.model)
        predicted_label = result["prediction"]

        evidence_raw, justification = await asyncio.to_thread(
            get_rag_result, claim, predicted_label
        )
        evidence = [EvidenceItem(**e) for e in evidence_raw] if evidence_raw else None

        return PredictionResponse(
            prediction=predicted_label,
            probabilities=result.get("probabilities"),
            evidence=evidence,
            justification=justification,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) keep their status code.
        raise
    except Exception as e:
        # FastAPI does not log HTTPExceptions, so record the traceback here.
        # Otherwise failures in inference or RAG leave no server-side trace.
        logging.getLogger(__name__).exception("Error processing claim")
        # NOTE(review): str(e) is returned to the client and may expose
        # internal details (e.g. database errors); consider a generic message.
        raise HTTPException(
            status_code=500, detail=f"Error processing claim: {e}"
        ) from e
if os.path.exists(STATIC_DIR):
    # Canonical absolute path of the static root, computed once. Served files
    # must resolve inside it.
    _STATIC_ROOT = os.path.realpath(STATIC_DIR)

    def _resolve_static_file(full_path):
        """Return the real path of *full_path* under the static root, or None.

        Returns None for an empty path, for anything that is not a regular
        file, and for any path that escapes the static root. Escapes include
        ``..`` segments (literal or percent-encoded as ``%2e%2e``) and absolute
        paths such as the ``/etc/passwd`` produced by a request for
        ``//etc/passwd``. ``os.path.join`` would otherwise let those replace
        or climb out of the root. Symlinks are resolved, so links pointing
        outside the root are rejected as well.
        """
        if not full_path:
            return None
        candidate = os.path.realpath(os.path.join(_STATIC_ROOT, full_path))
        try:
            inside = os.path.commonpath([_STATIC_ROOT, candidate]) == _STATIC_ROOT
        except ValueError:
            # Raised e.g. for paths on different drives (Windows): not inside.
            return None
        if inside and os.path.isfile(candidate):
            return candidate
        return None

    @app.get("/{full_path:path}")
    async def serve_spa(full_path: str):
        """Catch-all GET route for the built single-page app.

        Existing files under STATIC_DIR are served directly. Every other path,
        including ones rejected by the traversal check, gets the SPA's
        index.html so client-side routing can handle it.
        """
        file_path = _resolve_static_file(full_path)
        if file_path is not None:
            return FileResponse(file_path)
        return FileResponse(os.path.join(STATIC_DIR, "index.html"))