Spaces:
Build error
Build error
| """ | |
| Telemetry: in-memory counters + HF Dataset persistence. | |
| In-memory β rolling deque(500), per-metric pass counters, avg latency. | |
| Resets on restart. Powers /metrics (live session stats). | |
| Persistent β events flushed as JSONL shards to TELEMETRY_REPO HF dataset | |
| every FLUSH_EVERY queries. Powers /report (accumulated history). | |
| Falls back to in-memory if TELEMETRY_REPO is unset or write fails. | |
| TELEMETRY_REPO is inferred from HF Spaces env vars: | |
| SPACE_AUTHOR_NAME + SPACE_REPO_NAME β {author}/{repo}-telemetry | |
| Override with explicit TELEMETRY_REPO env var. | |
| """ | |
| import json | |
| import logging | |
| import os | |
| import threading | |
| from collections import defaultdict, deque | |
| from datetime import UTC, datetime | |
| from typing import Any | |
| from grader import GradeReport | |
log = logging.getLogger(__name__)

# Max events retained in the in-memory rolling buffer (session-local stats).
BUFFER_SIZE = 500
# Flush a JSONL shard to the HF dataset once this many events are unflushed.
FLUSH_EVERY = 20

# Guards all mutable module state below (record/live_stats/_flush touch it
# from multiple threads).
_lock = threading.Lock()
# Rolling window of recent events; powers live_stats(). Oldest entries drop
# automatically once BUFFER_SIZE is exceeded.
_events: deque[dict[str, Any]] = deque(maxlen=BUFFER_SIZE)
# Events recorded but not yet persisted to the HF dataset.
_unflushed: list[dict[str, Any]] = []
# Session counters: totals, per-metric pass counts, latency sums/sample counts.
_counters: dict[str, float] = defaultdict(float)

# Infer the telemetry dataset repo from HF Spaces env vars; an explicit
# TELEMETRY_REPO env var overrides. Empty string disables persistence.
_space_author = os.environ.get("SPACE_AUTHOR_NAME", "")
_space_repo = os.environ.get("SPACE_REPO_NAME", "ai-response-validator")
TELEMETRY_REPO = os.environ.get(
    "TELEMETRY_REPO",
    f"{_space_author}/{_space_repo}-telemetry" if _space_author else "",
)

# Metric names expected in GradeReport results (reporting order).
_METRICS = ["pii_leakage", "token_budget", "answer_relevancy", "faithfulness", "chain_terminology"]
def record(
    client: str,
    domain: str,
    query_len: int,
    latency_ms: dict[str, float],
    report: GradeReport,
    docs_retrieved: int,
    min_retrieval_score: float,
) -> None:
    """Record one query event. Thread-safe; flushes to HF dataset in background.

    Args:
        client: Caller identifier (used for per-client breakdowns in /report).
        domain: Query domain label.
        query_len: Length of the user query.
        latency_ms: Per-stage latency in milliseconds (rounded on store).
        report: Grading results — one entry per metric plus an overall verdict.
        docs_retrieved: Number of documents returned by retrieval.
        min_retrieval_score: Lowest retrieval score among returned docs.
    """
    event = {
        "ts": datetime.now(UTC).isoformat(),
        "client": client,
        "domain": domain,
        "query_len": query_len,
        "latency_ms": {k: round(v) for k, v in latency_ms.items()},
        "metrics": {r.metric: round(r.score, 4) for r in report.results},
        "metric_passed": {r.metric: r.passed for r in report.results},
        "overall_pass": report.overall,
        "docs_retrieved": docs_retrieved,
        "min_retrieval_score": round(min_retrieval_score, 4),
    }
    with _lock:
        _events.append(event)
        # Only queue for persistence when a repo is configured: previously
        # _unflushed grew without bound (never flushed, never cleared) when
        # TELEMETRY_REPO was unset.
        if TELEMETRY_REPO:
            _unflushed.append(event)
        _counters["total"] += 1
        if report.overall:
            _counters["overall_pass"] += 1
        for r in report.results:
            _counters[f"{r.metric}_total"] += 1
            if r.passed:
                _counters[f"{r.metric}_pass"] += 1
        for stage, ms in latency_ms.items():
            _counters[f"lat_{stage}_sum"] += ms
            _counters[f"lat_{stage}_n"] += 1
        should_flush = len(_unflushed) >= FLUSH_EVERY
    # Start the flush outside the lock; _flush re-acquires it to pop the batch.
    if should_flush and TELEMETRY_REPO:
        threading.Thread(target=_flush, daemon=True).start()
def live_stats() -> dict[str, Any]:
    """Aggregate the current session's in-memory counters (/metrics endpoint)."""
    with _lock:
        total = int(_counters.get("total", 0))
        if not total:
            return {"total_queries": 0, "message": "No queries recorded this session."}

        # Per-metric pass rates from the session counters.
        metric_stats: dict[str, Any] = {}
        for name in _METRICS:
            seen = int(_counters.get(f"{name}_total", 0))
            ok = int(_counters.get(f"{name}_pass", 0))
            metric_stats[name] = {
                "pass_rate": round(ok / seen, 3) if seen else None,
                "pass_count": ok,
                "total": seen,
            }

        # Mean latency per pipeline stage; stages with no samples are omitted.
        avg_latency = {
            stage: round(_counters[f"lat_{stage}_sum"] / samples)
            for stage in ("retrieve", "generate", "grade")
            if (samples := _counters.get(f"lat_{stage}_n", 0))
        }

        return {
            "source": "in_memory",
            "total_queries": total,
            "overall_pass_rate": round(_counters.get("overall_pass", 0) / total, 3),
            "metrics": metric_stats,
            "avg_latency_ms": avg_latency,
            "events_in_buffer": len(_events),
            "telemetry_repo": TELEMETRY_REPO or None,
        }
def persistent_report() -> dict[str, Any]:
    """Aggregate history from HF Dataset shards (/report endpoint).

    Downloads every events/*.jsonl shard from TELEMETRY_REPO and aggregates
    them. Falls back to live_stats() when no repo is configured or any HF
    operation fails.
    """
    if not TELEMETRY_REPO:
        log.info("TELEMETRY_REPO not set → report from in-memory only")
        return {"source": "in_memory", **live_stats()}
    try:
        from huggingface_hub import HfApi

        hf_token = os.environ.get("HF_TOKEN")
        api = HfApi(token=hf_token)
        files = api.list_repo_files(TELEMETRY_REPO, repo_type="dataset")
        shard_paths = [f for f in files if f.startswith("events/") and f.endswith(".jsonl")]
        if not shard_paths:
            return {"source": "hf_dataset", "repo": TELEMETRY_REPO,
                    "message": "No shards yet — data accumulates after first flush."}

        # Load every shard into one flat list of event dicts.
        events: list[dict[str, Any]] = []
        for path in shard_paths:
            local_path = api.hf_hub_download(
                TELEMETRY_REPO, path, repo_type="dataset", token=hf_token,
            )
            with open(local_path) as f:
                for line in f:
                    if line.strip():
                        events.append(json.loads(line))
        if not events:
            return {"source": "hf_dataset", "repo": TELEMETRY_REPO, "total_events": 0}

        total = len(events)
        overall_pass = sum(1 for e in events if e.get("overall_pass"))
        metric_stats = {}
        for m in _METRICS:
            # Only count events that actually graded this metric; previously the
            # pass rate divided by ALL events, so queries recorded before a
            # metric existed were counted as failures.
            graded = [e for e in events if m in e.get("metric_passed", {})]
            passed = sum(1 for e in graded if e["metric_passed"][m])
            scores = [e["metrics"][m] for e in events if m in e.get("metrics", {})]
            metric_stats[m] = {
                "pass_rate": round(passed / len(graded), 3) if graded else None,
                "avg_score": round(sum(scores) / len(scores), 3) if scores else None,
            }

        # Overall pass rate per client.
        client_breakdown: dict[str, dict[str, int]] = defaultdict(lambda: {"total": 0, "pass": 0})
        for e in events:
            c = e.get("client", "unknown")
            client_breakdown[c]["total"] += 1
            if e.get("overall_pass"):
                client_breakdown[c]["pass"] += 1

        return {
            "source": "hf_dataset",
            "repo": TELEMETRY_REPO,
            "total_queries": total,
            "overall_pass_rate": round(overall_pass / total, 3),
            "first_event": min(e["ts"] for e in events),
            "last_event": max(e["ts"] for e in events),
            "metrics": metric_stats,
            "by_client": {
                c: {"total": v["total"], "pass_rate": round(v["pass"] / v["total"], 3)}
                for c, v in client_breakdown.items()
            },
            "shards_read": len(shard_paths),
        }
    except Exception as e:
        log.warning("HF Dataset report failed (%s) → falling back to in-memory", e)
        return {"source": "in_memory_fallback", **live_stats()}
def _flush() -> None:
    """Upload buffered events to the HF Dataset as one JSONL shard.

    Runs in a daemon thread started by record(). On upload failure the batch
    is returned to the buffer (capped at BUFFER_SIZE so repeated failures
    cannot grow memory without bound).
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        # Check the token BEFORE popping the batch: previously a missing token
        # was detected after _unflushed was cleared, silently discarding events.
        log.warning("HF_TOKEN not set → telemetry flush skipped")
        with _lock:
            # Keep only the most recent BUFFER_SIZE events while unflushable.
            if len(_unflushed) > BUFFER_SIZE:
                del _unflushed[:-BUFFER_SIZE]
        return
    with _lock:
        if not _unflushed:
            return
        batch = list(_unflushed)
        _unflushed.clear()
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=hf_token)
        try:
            # Idempotent (exist_ok=True); any real problem surfaces on upload.
            api.create_repo(TELEMETRY_REPO, repo_type="dataset", exist_ok=True, private=False)
        except Exception:
            pass
        ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S%f")
        content = "\n".join(json.dumps(e) for e in batch).encode()
        api.upload_file(
            path_or_fileobj=content,
            path_in_repo=f"events/shard_{ts}.jsonl",
            repo_id=TELEMETRY_REPO,
            repo_type="dataset",
        )
        log.info("Flushed %d telemetry events to %s", len(batch), TELEMETRY_REPO)
    except Exception as e:
        log.warning("Telemetry flush failed: %s → events returned to buffer", e)
        with _lock:
            _unflushed.extend(batch)
            # Cap the retained backlog so repeated failures stay bounded.
            if len(_unflushed) > BUFFER_SIZE:
                del _unflushed[:-BUFFER_SIZE]