# mbochniak01
# Add telemetry layer: in-memory counters + HF Dataset persistence
# c79d967
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import telemetry
from config import DOMAIN_CLIENTS, CLIENT_DOMAIN, DISPLAY_NAMES
from grader import get_embedder, get_nli_model
from pipeline import run, _build_index, clear_index_cache
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# Front-end assets directory: "ui" next to this file's parent directory.
# Used below for the /static mount and the index page.
UI_DIR = Path(__file__).parent.parent.joinpath("ui")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: create the HF inference client and pre-warm models and KB indexes.

    Everything before ``yield`` runs once at application startup; nothing is
    done after ``yield`` (no shutdown work).

    Raises:
        RuntimeError: if the ``HF_TOKEN`` environment variable is unset or empty.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN not set")
    app.state.hf_client = InferenceClient(token=token)

    # Load both models up front ("pre-warmed"); the embedder is kept because
    # building each domain's index needs it.
    embed_model = get_embedder()
    get_nli_model()
    for domain_name in DOMAIN_CLIENTS:
        _build_index(domain_name, embed_model)

    log.info("Models and KB indexes pre-warmed. Ready.")
    yield
app = FastAPI(title="AI Response Validator", lifespan=lifespan)

# Permissive CORS: any origin and any request header, restricted to GET/POST.
# allow_credentials is left at the middleware's default.
# NOTE(review): /refresh-cache is an unauthenticated POST reachable from any
# origin — confirm that is intended for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["GET", "POST"],
    allow_origins=["*"],
    allow_headers=["*"],
)
# Request body for POST /query.
class QueryRequest(BaseModel):
    # Free-text question; handle_query strips surrounding whitespace and
    # rejects it with 400 if nothing is left.
    query: str
    # Client id; must be a key of CLIENT_DOMAIN or handle_query returns 400.
    client: str
# Response schema for POST /query (used as its response_model; the payload
# itself comes from pipeline.run's result.response_payload).
class QueryResponse(BaseModel):
    query: str
    client: str
    # Presumably DISPLAY_NAMES[client] — confirm against pipeline.
    client_display: str
    answer: str
    # NOTE(review): presumably True when the answer failed evaluation — confirm in pipeline.
    flagged: bool
    # Presumably the retrieved KB passages/metadata; schema not visible here.
    sources: list[dict]
    # Evaluation/grading details; structure defined by the pipeline — TODO confirm.
    evaluation: dict
@app.get("/health")
def health():
    """Report that the service is up; always answers ``{"status": "ok"}``."""
    return dict(status="ok")
@app.get("/config")
def get_config():
    """Return the domain-to-clients layout used by the UI's client switcher.

    Each domain maps to a list of ``{"id": ..., "display": ...}`` entries, in
    DOMAIN_CLIENTS order, with display names taken from DISPLAY_NAMES.
    """
    domains = {}
    for domain, clients in DOMAIN_CLIENTS.items():
        domains[domain] = [
            {"id": client_id, "display": DISPLAY_NAMES[client_id]}
            for client_id in clients
        ]
    return {"domains": domains}
@app.post("/refresh-cache")
def refresh_cache():
    """Drop the cached KB indexes and rebuild one per domain from disk.

    Returns whatever ``clear_index_cache`` reports as evicted, plus the list
    of domains whose indexes were rebuilt.
    """
    evicted = clear_index_cache()
    embedder = get_embedder()
    rebuilt = list(DOMAIN_CLIENTS)
    for domain in rebuilt:
        _build_index(domain, embedder)
    log.info("Cache refreshed. Rebuilt indexes for: %s", rebuilt)
    return {"refreshed": evicted, "rebuilt": rebuilt}
@app.get("/metrics")
def get_metrics():
    """Return this process's live stats from in-memory counters (reset on restart)."""
    stats = telemetry.live_stats()
    return stats
@app.get("/report")
def get_report():
    """Return stats accumulated in HF Dataset shards, which persist across restarts."""
    report = telemetry.persistent_report()
    return report
@app.post("/query", response_model=QueryResponse)
def handle_query(req: QueryRequest):
    """Run a client's query through ``pipeline.run`` and return its response payload.

    The query is whitespace-stripped once, and that normalized text is used
    for validation, for the pipeline call and in the failure log. When the
    grade report does not pass overall, the failing metric names are logged
    as a warning. The response is still returned in that case.

    Raises:
        HTTPException: 400 if ``req.client`` is not a known client, or if the
            query is empty or whitespace-only.
    """
    if req.client not in CLIENT_DOMAIN:
        raise HTTPException(status_code=400, detail=f"Unknown client: {req.client!r}")

    # Strip once and reuse instead of recomputing at every use site.
    query = req.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    result = run(
        query=query,
        client=req.client,
        hf_client=app.state.hf_client,
    )

    if not result.grade_report.overall:
        failed = [r.metric for r in result.grade_report.results if not r.passed]
        # Only the first 80 characters of the query are logged.
        log.warning(
            "EVAL_FAIL client=%s failed_metrics=%s query=%r",
            req.client, failed, query[:80],
        )
    return result.response_payload
# Serve the front-end's static assets from UI_DIR under the /static path.
app.mount(path="/static", app=StaticFiles(directory=UI_DIR), name="static")
@app.get("/")
def root():
    """Serve the UI's ``index.html`` from UI_DIR."""
    index_page = UI_DIR.joinpath("index.html")
    return FileResponse(index_page)