# AutoDataLab2.0 / memory / retriever.py
"""Zero-dependency TF-IDF retriever over the company-memory corpus.
Design choices
--------------
* Pure Python (re + math) so it runs inside OpenEnv/HF Space containers
without pulling sentence-transformers / faiss / langchain.
* Each markdown file is split into paragraph-level chunks, indexed with
token-frequency + inverse-document-frequency, and queried with cosine
over sparse TF-IDF vectors.
* The retriever is the *single source of grounding truth* consumed by
both the specialists (at rollout time) and the grader (at scoring
time), so the reward becomes verifiable instead of fuzzy.
"""
from __future__ import annotations
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple
_WORD_RE = re.compile(r"[a-zA-Z][a-zA-Z0-9_-]{1,}")
def _tokenize(text: str) -> List[str]:
return [tok.lower() for tok in _WORD_RE.findall(text or "")]
@dataclass
class MemoryHit:
    """A single retrieved chunk: its identifier, excerpt, and match score."""

    source: str   # chunk identifier, e.g. "notes/pricing.md#chunk2"
    snippet: str  # flattened, possibly truncated excerpt of the chunk
    score: float  # cosine similarity against the query, rounded

    def as_citation(self) -> str:
        """Stable string form used in ExpertReport.citations and briefs."""
        return "memory:" + self.source
class Retriever:
    """Lightweight TF-IDF retriever over a directory of markdown files.

    The corpus is every ``*.md`` file under ``corpus_dir``; each file is
    split on blank lines into paragraph chunks, and each chunk becomes one
    retrieval unit identified as ``"<relpath>#chunk<i>"``.
    """

    def __init__(self, corpus_dir: Path) -> None:
        self.corpus_dir = Path(corpus_dir)
        # (source-id, chunk-text) pairs, in corpus order.
        self._docs: List[Tuple[str, str]] = []
        # token -> column index in the sparse vectors.
        self._vocab: Dict[str, int] = {}
        # per-document sparse term-frequency vectors (column -> count).
        self._tf: List[Dict[int, float]] = []
        # column index -> number of documents containing that token.
        self._df: Dict[int, int] = {}
        # cached L2 norms of the tf-idf document vectors.
        self._norms: List[float] = []
        self._load()
        self._build_index()

    # -- indexing ----------------------------------------------------------
    def _load(self) -> None:
        """Read every markdown file and split it into paragraph chunks."""
        for path in sorted(self.corpus_dir.rglob("*.md")):
            # Normalise Windows separators so identifiers are portable.
            rel = str(path.relative_to(self.corpus_dir)).replace("\\", "/")
            text = path.read_text(encoding="utf-8")
            chunks = [chunk.strip() for chunk in re.split(r"\n\s*\n", text) if chunk.strip()]
            for i, chunk in enumerate(chunks):
                self._docs.append((f"{rel}#chunk{i}", chunk))

    def _idf(self, idx: int) -> float:
        """Smoothed inverse document frequency for vocabulary index *idx*.

        Single source of truth for the weighting formula — previously this
        expression was hand-duplicated in three places (index norms, query
        norm, and the dot product), which risked the variants drifting apart.
        """
        return math.log((self._num_docs + 1) / (self._df.get(idx, 1) + 1)) + 1.0

    def _build_index(self) -> None:
        """Build the vocabulary, tf/df tables, and cached vector norms."""
        for _, text in self._docs:
            tokens = _tokenize(text)
            tf_doc: Dict[int, float] = {}
            for tok in tokens:
                if tok not in self._vocab:
                    self._vocab[tok] = len(self._vocab)
                idx = self._vocab[tok]
                tf_doc[idx] = tf_doc.get(idx, 0.0) + 1.0
            self._tf.append(tf_doc)
            for idx in tf_doc:
                self._df[idx] = self._df.get(idx, 0) + 1
        self._num_docs = len(self._docs)
        # Cache L2 norms of the tf-idf vectors for cosine similarity;
        # `or 1.0` guards against zero-norm (tokenless) chunks.
        self._norms = []
        for tf_doc in self._tf:
            sq = sum((tf * self._idf(idx)) ** 2 for idx, tf in tf_doc.items())
            self._norms.append(math.sqrt(sq) or 1.0)

    # -- public API --------------------------------------------------------
    def query(self, text: str, k: int = 3) -> List[MemoryHit]:
        """Return the top-*k* chunks most cosine-similar to *text*.

        Chunks with zero overlap are omitted, so fewer than *k* hits (or
        none at all) may be returned.
        """
        if not text or not self._docs:
            return []
        tokens = _tokenize(text)
        if not tokens:
            return []
        # Sparse query term-frequency vector; out-of-vocabulary tokens can
        # never match a document, so they are simply dropped.
        q_tf: Dict[int, float] = {}
        for tok in tokens:
            if tok in self._vocab:
                idx = self._vocab[tok]
                q_tf[idx] = q_tf.get(idx, 0.0) + 1.0
        if not q_tf:
            return []
        q_sq = sum((tf * self._idf(idx)) ** 2 for idx, tf in q_tf.items())
        q_norm = math.sqrt(q_sq) or 1.0
        scores: List[Tuple[float, int]] = []
        for doc_idx, tf_doc in enumerate(self._tf):
            dot = 0.0
            for idx, qt in q_tf.items():
                tfd = tf_doc.get(idx)
                if tfd is None:
                    continue
                idf = self._idf(idx)
                dot += (qt * idf) * (tfd * idf)
            if dot <= 0.0:
                continue
            cosine = dot / (q_norm * self._norms[doc_idx])
            scores.append((cosine, doc_idx))
        # Tuple sort: descending score, ties broken by larger doc index.
        scores.sort(reverse=True)
        hits: List[MemoryHit] = []
        for score, doc_idx in scores[: max(k, 0)]:
            source, chunk = self._docs[doc_idx]
            snippet = chunk.replace("\n", " ").strip()
            if len(snippet) > 280:
                snippet = snippet[:277] + "..."
            hits.append(MemoryHit(source=source, snippet=snippet, score=round(float(score), 4)))
        return hits

    def sources(self) -> List[str]:
        """All chunk identifiers, in corpus order."""
        return [src for src, _ in self._docs]

    def has_source(self, source: str) -> bool:
        """True if *source* exactly matches a chunk identifier in the corpus."""
        return any(src == source for src, _ in self._docs)

    def contains_any(self, text: str, sources: Sequence[str]) -> bool:
        """Return True if any of ``sources`` appears verbatim in ``text``.

        Used by the grader to verify citations are grounded in the corpus
        rather than hallucinated strings.
        """
        if not text:
            return False
        return any(src and src in text for src in sources)

    def count_grounded_citations(self, citations: Iterable[str]) -> int:
        """Count ``memory:<source>`` citations that name real corpus chunks.

        Non-string entries and citations without the ``memory:`` prefix are
        ignored rather than raising.
        """
        known = set(self.sources())
        n = 0
        for citation in citations or []:
            if not isinstance(citation, str) or not citation.startswith("memory:"):
                continue
            if citation[len("memory:"):] in known:
                n += 1
        return n
def contains_any(self, text: str, sources: Sequence[str]) -> bool:
"""Return True if any of ``sources`` appears verbatim in ``text``.
Used by the grader to verify citations are grounded in the corpus
rather than hallucinated strings.
"""
if not text:
return False
return any(src and src in text for src in sources)
def count_grounded_citations(self, citations: Iterable[str]) -> int:
known = set(self.sources())
n = 0
for citation in citations or []:
if not isinstance(citation, str) or not citation.startswith("memory:"):
continue
if citation[len("memory:"):] in known:
n += 1
return n
__all__ = ["MemoryHit", "Retriever"]