"""Zero-dependency TF-IDF retriever over the company-memory corpus.

Design choices
--------------
* Pure Python (re + math) so it runs inside OpenEnv/HF Space containers
  without pulling sentence-transformers / faiss / langchain.
* Each markdown file is split into paragraph-level chunks, indexed with
  token-frequency + inverse-document-frequency, and queried with cosine
  over sparse TF-IDF vectors.
* The retriever is the *single source of grounding truth* consumed by
  both the specialists (at rollout time) and the grader (at scoring
  time), so the reward becomes verifiable instead of fuzzy.
"""
| from __future__ import annotations | |
| import math | |
| import re | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, Iterable, List, Sequence, Tuple | |
| _WORD_RE = re.compile(r"[a-zA-Z][a-zA-Z0-9_-]{1,}") | |
| def _tokenize(text: str) -> List[str]: | |
| return [tok.lower() for tok in _WORD_RE.findall(text or "")] | |
@dataclass
class MemoryHit:
    """One retrieved chunk plus its provenance and similarity score.

    The ``@dataclass`` decorator is required: ``Retriever.query`` constructs
    hits with keyword arguments (``MemoryHit(source=..., snippet=...,
    score=...)``), which needs the generated ``__init__``.

    Attributes:
        source: chunk identifier of the form ``"<relative-path>#chunk<i>"``.
        snippet: flattened (newline-free), possibly truncated chunk text.
        score: cosine similarity against the query, rounded to 4 decimals.
    """

    source: str
    snippet: str
    score: float

    def as_citation(self) -> str:
        """Stable string form used in ExpertReport.citations and briefs."""
        return f"memory:{self.source}"
class Retriever:
    """Lightweight TF-IDF retriever over a directory of markdown files.

    Every ``*.md`` file under ``corpus_dir`` is split into paragraph-level
    chunks (blank-line separated); each chunk becomes one sparse TF-IDF
    vector.  Queries are ranked by cosine similarity over those vectors.
    """

    def __init__(self, corpus_dir: Path) -> None:
        """Load the corpus under *corpus_dir* and build the index eagerly."""
        self.corpus_dir = Path(corpus_dir)
        # (source-id, raw chunk text) pairs, in deterministic path order.
        self._docs: List[Tuple[str, str]] = []
        # token -> column index of the sparse vectors.
        self._vocab: Dict[str, int] = {}
        # per-document raw term counts, keyed by vocab index.
        self._tf: List[Dict[int, float]] = []
        # vocab index -> number of documents containing that term.
        self._df: Dict[int, int] = {}
        # cached L2 norm of each document's tf-idf vector.
        self._norms: List[float] = []
        self._num_docs = 0
        self._load()
        self._build_index()

    # -- indexing ----------------------------------------------------------
    def _load(self) -> None:
        """Split each markdown file into blank-line-separated chunks."""
        for path in sorted(self.corpus_dir.rglob("*.md")):
            # Forward slashes keep chunk ids stable across platforms.
            rel = str(path.relative_to(self.corpus_dir)).replace("\\", "/")
            text = path.read_text(encoding="utf-8")
            chunks = [chunk.strip() for chunk in re.split(r"\n\s*\n", text) if chunk.strip()]
            for i, chunk in enumerate(chunks):
                self._docs.append((f"{rel}#chunk{i}", chunk))

    def _build_index(self) -> None:
        """Compute term/document frequencies and cache per-doc norms."""
        for _, text in self._docs:
            tokens = _tokenize(text)
            tf_doc: Dict[int, float] = {}
            for tok in tokens:
                if tok not in self._vocab:
                    self._vocab[tok] = len(self._vocab)
                idx = self._vocab[tok]
                tf_doc[idx] = tf_doc.get(idx, 0.0) + 1.0
            self._tf.append(tf_doc)
            for idx in tf_doc:
                self._df[idx] = self._df.get(idx, 0) + 1
        self._num_docs = len(self._docs)
        # Cache L2 norms of tf-idf vectors for cosine similarity.
        self._norms = []
        for tf_doc in self._tf:
            sq = 0.0
            for idx, tf in tf_doc.items():
                sq += (tf * self._idf(idx)) ** 2
            self._norms.append(math.sqrt(sq) or 1.0)

    def _idf(self, idx: int) -> float:
        """Smoothed inverse document frequency for vocab index *idx*.

        Single definition so indexing and querying can never drift apart
        (previously this formula was inlined in three separate places).
        """
        return math.log((self._num_docs + 1) / (self._df.get(idx, 1) + 1)) + 1.0

    # -- public API --------------------------------------------------------
    def query(self, text: str, k: int = 3) -> List[MemoryHit]:
        """Return the top-*k* chunks most similar to *text* (possibly fewer).

        Only chunks with strictly positive cosine score are returned; query
        tokens absent from the corpus vocabulary are ignored.
        """
        if not text or not self._docs:
            return []
        tokens = _tokenize(text)
        if not tokens:
            return []
        q_tf: Dict[int, float] = {}
        for tok in tokens:
            if tok in self._vocab:
                idx = self._vocab[tok]
                q_tf[idx] = q_tf.get(idx, 0.0) + 1.0
        if not q_tf:
            return []
        # Hoist loop invariants: each query term's idf and tf-idf weight are
        # computed once, not once per document.
        idf = {idx: self._idf(idx) for idx in q_tf}
        q_weights = {idx: tf * idf[idx] for idx, tf in q_tf.items()}
        q_sq = sum(w ** 2 for w in q_weights.values())
        q_norm = math.sqrt(q_sq) or 1.0
        scores: List[Tuple[float, int]] = []
        for doc_idx, tf_doc in enumerate(self._tf):
            dot = 0.0
            for idx, qw in q_weights.items():
                tfd = tf_doc.get(idx)
                if tfd is None:
                    continue
                dot += qw * (tfd * idf[idx])
            if dot <= 0.0:
                continue
            cosine = dot / (q_norm * self._norms[doc_idx])
            scores.append((cosine, doc_idx))
        # Highest cosine first; score ties fall back to the larger doc index
        # because the (score, doc_idx) tuples are sorted in reverse.
        scores.sort(reverse=True)
        hits: List[MemoryHit] = []
        for score, doc_idx in scores[: max(k, 0)]:
            source, chunk = self._docs[doc_idx]
            snippet = chunk.replace("\n", " ").strip()
            if len(snippet) > 280:
                snippet = snippet[:277] + "..."
            hits.append(MemoryHit(source=source, snippet=snippet, score=round(float(score), 4)))
        return hits

    def sources(self) -> List[str]:
        """All chunk identifiers, in indexing order."""
        return [src for src, _ in self._docs]

    def has_source(self, source: str) -> bool:
        """Return True if *source* is a known chunk identifier."""
        return any(src == source for src, _ in self._docs)

    def contains_any(self, text: str, sources: Sequence[str]) -> bool:
        """Return True if any of ``sources`` appears verbatim in ``text``.

        Used by the grader to verify citations are grounded in the corpus
        rather than hallucinated strings.
        """
        if not text:
            return False
        return any(src and src in text for src in sources)

    def count_grounded_citations(self, citations: Iterable[str]) -> int:
        """Count ``memory:<source>`` citations naming a real corpus chunk.

        Non-string entries and citations without the ``memory:`` prefix are
        ignored; ``citations=None`` counts as zero.
        """
        known = set(self.sources())
        n = 0
        for citation in citations or []:
            if not isinstance(citation, str) or not citation.startswith("memory:"):
                continue
            if citation[len("memory:"):] in known:
                n += 1
        return n
# Explicit public surface: only the hit record and the retriever are exported.
__all__ = ["MemoryHit", "Retriever"]