Spaces:
Running
Running
| """CLI: Pooled retrieval annotation tool (TREC-style). | |
| Runs BM25, vector, and hybrid retrieval on each question, pools unique | |
| chunks, and presents them for human relevance judgment. | |
| Usage: | |
| python scripts/annotate.py | |
| python scripts/annotate.py --question "What is LoRA?" | |
| python scripts/annotate.py --top-k 10 | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from src.config import PROJECT_ROOT, get_config | |
| from src.ingestion.embeddings import EmbeddingGenerator | |
| from src.retrieval.bm25_retriever import BM25Retriever | |
| from src.retrieval.hybrid_retriever import HybridRetriever | |
| from src.retrieval.vector_retriever import VectorRetriever | |
| from src.storage.chroma_store import ChromaStore | |
| from src.storage.sqlite_db import SQLiteDB | |
# Root logger: INFO level, timestamped and module-tagged messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Input question pool and output relevance-judgment file, both under data/.
QUESTIONS_PATH = PROJECT_ROOT / "data" / "questions.json"
EVAL_SET_PATH = PROJECT_ROOT / "data" / "eval_set.json"
| # ββ Retrieval helpers ββββββββββββββββββββββββββββββββββββββββββββββββ | |
def init_retrievers(config):
    """Build the storage and retrieval stack from config.

    Returns:
        (db, bm25, vector, hybrid) — the SQLite handle plus the three
        retrievers used for pooled retrieval.
    """
    sqlite_db = SQLiteDB(config.sqlite_db_path)
    chroma_store = ChromaStore(config.chroma_db_path)
    embedder = EmbeddingGenerator(config.embedding_model)

    keyword_retriever = BM25Retriever(sqlite_db)
    keyword_retriever.build_index()  # BM25 needs its index built before searching
    dense_retriever = VectorRetriever(chroma_store, embedder)
    fused_retriever = HybridRetriever(keyword_retriever, dense_retriever)
    return sqlite_db, keyword_retriever, dense_retriever, fused_retriever
def run_pooled_retrieval(
    question: str,
    bm25: BM25Retriever,
    vector: VectorRetriever,
    hybrid: HybridRetriever,
    top_k: int = 10,
) -> tuple[list[dict], dict[str, list]]:
    """Query all three retrieval methods and pool their unique chunk IDs.

    Returns:
        (pooled_ids, method_pools): pooled_ids is the deduplicated list of
        chunk IDs in first-seen order (BM25, then vector, then hybrid);
        method_pools maps each method's pool name to the chunk IDs it
        retrieved.
    """
    raw_results = {
        "bm25_top10": bm25.search(question, top_k=top_k),
        "vector_top10": vector.search(question, top_k=top_k),
        "hybrid_top10": hybrid.search(question, top_k=top_k),
    }
    method_pools = {
        pool_name: [hit["chunk_id"] for hit in hits]
        for pool_name, hits in raw_results.items()
    }
    # dict.fromkeys deduplicates while preserving first-seen order.
    pooled_ids = list(dict.fromkeys(
        cid for ids in method_pools.values() for cid in ids
    ))
    return pooled_ids, method_pools
def resolve_chunks(db: SQLiteDB, chunk_ids: list) -> dict:
    """Map each requested chunk ID to its chunk record (with paper metadata).

    IDs that no longer resolve in the database are silently dropped.
    """
    by_id = {row["id"]: row for row in db.get_all_chunks()}
    return {cid: by_id[cid] for cid in chunk_ids if cid in by_id}
def chunk_methods(chunk_id, method_pools: dict) -> list[str]:
    """Return short method names ("bm25", "vector", "hybrid") whose pools
    contain chunk_id, in pool-dict order."""
    return [
        pool_name.split("_")[0]
        for pool_name, retrieved_ids in method_pools.items()
        if chunk_id in retrieved_ids
    ]
| # ββ Annotation state βββββββββββββββββββββββββββββββββββββββββββββββββ | |
def load_eval_set() -> list[dict]:
    """Load saved judgments from EVAL_SET_PATH, or [] if none exist yet."""
    if not EVAL_SET_PATH.exists():
        return []
    return json.loads(EVAL_SET_PATH.read_text(encoding="utf-8"))
def save_eval_set(eval_set: list[dict]) -> None:
    """Write all judgments to EVAL_SET_PATH, creating parent dirs as needed."""
    EVAL_SET_PATH.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(eval_set, indent=2, ensure_ascii=False)
    EVAL_SET_PATH.write_text(payload, encoding="utf-8")
| def find_entry(eval_set: list[dict], question_id: str) -> dict | None: | |
| for entry in eval_set: | |
| if entry.get("id") == question_id: | |
| return entry | |
| return None | |
| # ββ Display helpers ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def truncate(text: str, max_len: int = 500) -> str:
    """Clip text to max_len characters, appending "..." only when clipped."""
    if len(text) > max_len:
        return text[:max_len] + "..."
    return text
def display_chunk(
    idx: int,
    total: int,
    chunk: dict,
    methods: list[str],
    show_full: bool = False,
):
    """Print a single chunk for annotation.

    Args:
        idx: 1-based position of this chunk within the pool.
        total: Total number of pooled chunks for the current question.
        chunk: Chunk record; read keys: title, venue, year, paper_id, id,
            chunk_type, chunk_text (missing keys render as '?').
        methods: Short method names that retrieved this chunk.
        show_full: When True, print the whole chunk text instead of the
            500-character preview.
    """
    print(f"\n{'='*60}")
    print(f" [{idx}/{total}] Methods: {', '.join(methods)}")
    print(f" Title: {chunk.get('title', '?')}")
    print(f" Venue: {chunk.get('venue', '?')} | Year: {chunk.get('year', '?')}")
    print(f" Paper ID: {chunk.get('paper_id', '?')}")
    print(f" Chunk ID: {chunk.get('id', '?')} | Type: {chunk.get('chunk_type', '?')}")
    print(f"{'β'*60}")
    text = chunk.get("chunk_text", "")
    # Short texts (<= 500 chars) are always shown in full, even without 'f'.
    if show_full or len(text) <= 500:
        print(text)
    else:
        print(truncate(text, 500))
        print(f" [{len(text)} chars total β press 'f' to show full]")
    print(f"{'β'*60}")
    # Legend for the single-key actions read by the annotation loop.
    print(" (y) relevant (n) not relevant (s) skip (f) full text (q) quit")
| # ββ Annotation loop ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def annotate_question(
    question: dict,
    db: SQLiteDB,
    bm25: BM25Retriever,
    vector: VectorRetriever,
    hybrid: HybridRetriever,
    eval_set: list[dict],
    top_k: int = 10,
) -> bool:
    """Annotate one question. Returns False if user quit.

    Pools retrieval results from all three methods, then walks the unique
    chunks one at a time collecting y/n/s judgments from stdin. Judgments
    (including a partial pass on quit) are persisted via _save_annotation.

    Fix: pressing 'f' previously printed the full text and then the loop
    immediately re-printed the truncated view on top of it. We now set a
    per-chunk show_full flag and let the next loop iteration render the
    chunk once, in full.
    """
    qid = question["id"]
    qtext = question["question"]
    print(f"\n{'#'*60}")
    print(f" Question [{qid}]: {qtext}")
    print(f" Type: {question.get('type', '?')}")
    kw = question.get("expected_keywords", [])
    if kw:
        print(f" Expected keywords: {', '.join(kw)}")
    print(f"{'#'*60}")
    print("\nRunning retrieval (BM25 + vector + hybrid)...")
    pooled_ids, method_pools = run_pooled_retrieval(
        qtext, bm25, vector, hybrid, top_k=top_k,
    )
    resolved = resolve_chunks(db, pooled_ids)
    # Keep pool order, dropping IDs that no longer resolve in the DB.
    ordered_ids = [cid for cid in pooled_ids if cid in resolved]
    total = len(ordered_ids)
    if total == 0:
        print(" No chunks retrieved. Skipping.")
        return True
    print(f"\nPooled {total} unique chunks from {top_k * 3} candidates.\n")
    relevant_chunk_ids = []
    irrelevant_chunk_ids = []
    skipped_chunk_ids = []
    i = 0
    show_full = False  # per-chunk; reset whenever we advance to the next chunk
    while i < total:
        cid = ordered_ids[i]
        chunk = resolved[cid]
        methods = chunk_methods(cid, method_pools)
        display_chunk(i + 1, total, chunk, methods, show_full=show_full)
        action = input(" > ").strip().lower()
        if action == "y":
            relevant_chunk_ids.append(cid)
            i += 1
            show_full = False
        elif action == "n":
            irrelevant_chunk_ids.append(cid)
            i += 1
            show_full = False
        elif action == "s":
            skipped_chunk_ids.append(cid)
            i += 1
            show_full = False
        elif action == "f":
            # Re-render this chunk in full on the next iteration; don't
            # advance so the user can judge after seeing the full text.
            show_full = True
        elif action == "q":
            # Save partial progress before quitting
            _save_annotation(
                eval_set, question, method_pools,
                relevant_chunk_ids, irrelevant_chunk_ids, skipped_chunk_ids,
                resolved,
            )
            return False
        else:
            print(" Invalid. Use y/n/s/f/q.")
    # Save completed annotation
    _save_annotation(
        eval_set, question, method_pools,
        relevant_chunk_ids, irrelevant_chunk_ids, skipped_chunk_ids,
        resolved,
    )
    return True
def _save_annotation(
    eval_set: list[dict],
    question: dict,
    method_pools: dict,
    relevant_chunk_ids: list,
    irrelevant_chunk_ids: list,
    skipped_chunk_ids: list,
    resolved: dict,
):
    """Assemble one eval-set entry for this question and persist the set.

    Replaces any existing entry with the same question id, otherwise
    appends, then writes the whole eval set to disk and prints a summary.
    """
    # Unique paper IDs behind the relevant chunks, first-seen order.
    relevant_paper_ids = []
    for cid in relevant_chunk_ids:
        if cid in resolved:
            pid = resolved[cid]["paper_id"]
            if pid not in relevant_paper_ids:
                relevant_paper_ids.append(pid)

    entry = {
        "id": question["id"],
        "question": question["question"],
        "type": question.get("type", ""),
        "expected_keywords": question.get("expected_keywords", []),
        # Chunk IDs are stringified so the JSON schema stays uniform.
        "relevant_chunk_ids": [str(cid) for cid in relevant_chunk_ids],
        "irrelevant_chunk_ids": [str(cid) for cid in irrelevant_chunk_ids],
        "skipped_chunk_ids": [str(cid) for cid in skipped_chunk_ids],
        "relevant_paper_ids": relevant_paper_ids,
        "pooled_from": {
            method: [str(cid) for cid in ids]
            for method, ids in method_pools.items()
        },
        "annotated_at": datetime.now(timezone.utc).isoformat(),
    }

    # Replace-or-append by question id.
    existing = find_entry(eval_set, question["id"])
    if existing is None:
        eval_set.append(entry)
    else:
        eval_set[eval_set.index(existing)] = entry
    save_eval_set(eval_set)

    print(
        f"\n Saved: {len(relevant_chunk_ids)} relevant, "
        f"{len(irrelevant_chunk_ids)} irrelevant, "
        f"{len(skipped_chunk_ids)} skipped"
    )
| # ββ Main βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def main():
    """CLI entry point: parse args, build retrievers, run annotation loop(s)."""
    parser = argparse.ArgumentParser(
        description="Pooled retrieval annotation (TREC-style)"
    )
    parser.add_argument(
        "--question", type=str, default=None,
        help="Annotate a single ad-hoc question (bypasses questions.json)",
    )
    parser.add_argument("--top-k", type=int, default=10, help="Results per method")
    parser.add_argument(
        "--force", action="store_true",
        help="Re-annotate questions that already have judgments",
    )
    args = parser.parse_args()
    config = get_config()
    db, bm25, vector, hybrid = init_retrievers(config)
    eval_set = load_eval_set()
    if args.question:
        # Ad-hoc single question: synthesize a timestamped id so the result
        # can still be stored in the eval set alongside file-based questions.
        q = {
            "id": f"adhoc_{datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
            "question": args.question,
            "type": "factual",
            "expected_keywords": [],
        }
        annotate_question(q, db, bm25, vector, hybrid, eval_set, top_k=args.top_k)
        return
    # Load questions from file
    if not QUESTIONS_PATH.exists():
        print(f"No questions file found at {QUESTIONS_PATH}")
        print("Run: python scripts/write_questions.py")
        sys.exit(1)
    with open(QUESTIONS_PATH, encoding="utf-8") as f:
        questions = json.load(f)
    if not questions:
        print("Questions file is empty. Add questions first.")
        sys.exit(1)
    # Questions already judged are skipped unless --force re-annotation.
    annotated_ids = {e["id"] for e in eval_set}
    for q in questions:
        qid = q["id"]
        if qid in annotated_ids and not args.force:
            print(f"\n [{qid}] already annotated β skipping (use --force to redo)")
            continue
        # annotate_question returns False when the user quits ('q');
        # partial progress has already been saved by then.
        if not annotate_question(q, db, bm25, vector, hybrid, eval_set, top_k=args.top_k):
            print("\nAnnotation paused. Progress saved.")
            break
    print(f"\nDone. {len(eval_set)} entries in {EVAL_SET_PATH}")
# Script entry point guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()