| from __future__ import annotations |
|
|
| import json |
| import os |
| from typing import List |
|
|
| from models import RetrievedChunk |
| from utils import clean_math_text, score_token_overlap |
|
|
| try: |
| import numpy as np |
| except Exception: |
| np = None |
|
|
| try: |
| from sentence_transformers import SentenceTransformer |
| except Exception: |
| SentenceTransformer = None |
|
|
|
|
class RetrievalEngine:
    """Rank corpus chunks for a query using semantic and lexical scoring.

    Rows are read from a JSONL file at construction. When the optional
    ``sentence_transformers`` package imported successfully and the model
    loads, every row is embedded once up front and queries are ranked by a
    weighted blend of embedding dot product and ``score_token_overlap``.
    Otherwise ranking uses ``score_token_overlap`` alone. A topic/intent
    bonus from ``_topic_bonus`` is added in both modes.

    Attributes:
        data_path: Path the rows were loaded from.
        rows: Row dicts with string ``"text"``, ``"topic"`` and ``"source"``.
        encoder: The loaded sentence-transformer model, or ``None``.
        embeddings: Row embeddings aligned with ``rows``, or ``None``.
    """

    # Weights of the semantic and lexical parts of the blended score.
    SEMANTIC_WEIGHT = 0.7
    LEXICAL_WEIGHT = 0.3

    # Row topics that earn the extra algebra bonus (exact match only).
    _ALGEBRA_TOPICS = frozenset({"algebra", "linear equations", "equations"})

    # Requested topic -> substrings that mark a row topic as related to it.
    _NUMBER_KEYWORDS = ("number", "divisible", "remainder", "prime", "factor")
    _TOPIC_KEYWORDS = {
        "percent": ("percent",),
        "number_theory": _NUMBER_KEYWORDS,
        "number_properties": _NUMBER_KEYWORDS,
        "geometry": ("geometry", "circle", "triangle", "area", "perimeter"),
        "probability": ("probability",),
        "statistics": ("statistics", "mean", "median", "average", "distribution"),
    }

    # Intents asking for worked guidance, and the row-topic substrings that
    # receive a small boost for those intents.
    _GUIDED_INTENTS = frozenset(
        {"method", "step_by_step", "full_working", "hint", "walkthrough", "instruction"}
    )
    _GUIDED_TOPIC_KEYWORDS = (
        "algebra",
        "percent",
        "fractions",
        "word_problems",
        "general",
        "ratio",
        "probability",
        "statistics",
    )

    def __init__(self, data_path: str = "data/gmat_hf_chunks.jsonl"):
        """Load the corpus and, when possible, pre-compute row embeddings.

        Args:
            data_path: Path of the JSONL corpus. A missing file produces an
                engine with no rows, whose ``search`` always returns ``[]``.
        """
        self.data_path = data_path
        self.rows = self._load_rows(data_path)
        self.encoder = None
        self.embeddings = None

        if SentenceTransformer is not None and self.rows:
            try:
                self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
                self.embeddings = self.encoder.encode(
                    [row["text"] for row in self.rows],
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                )
            except Exception:
                # Deliberate best effort: if the model cannot be loaded or
                # the corpus cannot be encoded, reset both attributes so
                # search() uses the lexical-only path instead of failing.
                self.encoder = None
                self.embeddings = None

    @staticmethod
    def _first_text(item: dict, keys, default: str) -> str:
        """Return the first truthy ``item[key]`` as a string, else ``default``."""
        for key in keys:
            value = item.get(key)
            if value:
                return value if isinstance(value, str) else str(value)
        return default

    def _load_rows(self, data_path: str) -> List[dict]:
        """Read the JSONL corpus into normalised row dicts.

        Blank lines, lines that are not valid JSON, and JSON values that are
        not objects are skipped. Every returned row has string ``"text"``,
        ``"topic"`` and ``"source"`` entries.

        Args:
            data_path: Path of the JSONL file.

        Returns:
            The list of rows; empty when the file does not exist.
        """
        rows: List[dict] = []
        if not os.path.exists(data_path):
            return rows

        with open(data_path, "r", encoding="utf-8") as handle:
            for raw_line in handle:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except (ValueError, RecursionError):
                    continue
                # A line such as `[1, 2]` or `"text"` is valid JSON but has
                # no fields to read; skip it rather than crash on `.get`.
                if not isinstance(item, dict):
                    continue

                rows.append(
                    {
                        # A null/missing text becomes "" so later string
                        # handling never receives None.
                        "text": self._first_text(item, ("text",), ""),
                        "topic": self._first_text(
                            item, ("topic", "topic_guess", "section"), "general"
                        ),
                        "source": self._first_text(
                            item, ("source", "source_name", "source_file"), "local_corpus"
                        ),
                    }
                )
        return rows

    def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
        """Compute the additive score bonus for a row's topic.

        All three arguments are stripped and lower-cased before comparison.

        Args:
            desired_topic: Topic requested by the caller (may be empty).
            row_topic: Topic stored on the candidate row.
            intent: Caller intent, e.g. ``"answer"`` or ``"hint"``.

        Returns:
            The sum of: 1.25 when the requested topic is a substring of the
            row topic; 1.0 when the row topic is related to the requested
            topic (exact match against the algebra set, or a keyword
            substring for the other known topics); and 0.25 when the intent
            asks for guidance and the row topic contains one of the guided
            keywords.
        """
        desired_topic = (desired_topic or "").strip().lower()
        row_topic = str(row_topic or "").strip().lower()
        intent = (intent or "").strip().lower()

        bonus = 0.0

        if desired_topic and desired_topic in row_topic:
            bonus += 1.25

        if desired_topic == "algebra" and row_topic in self._ALGEBRA_TOPICS:
            bonus += 1.0

        related_keywords = self._TOPIC_KEYWORDS.get(desired_topic, ())
        if any(word in row_topic for word in related_keywords):
            bonus += 1.0

        if intent in self._GUIDED_INTENTS and any(
            word in row_topic for word in self._GUIDED_TOPIC_KEYWORDS
        ):
            bonus += 0.25

        return bonus

    def _select_candidates(self, normalized_topic: str):
        """Pick the rows to score for an already-normalised topic.

        Rows whose topic equals the requested topic are preferred. If there
        are none, rows whose topic contains, or is contained in, the requested
        topic are used. If nothing matches, or no topic was requested, every
        row is a candidate.

        Returns:
            ``(indices, rows)`` where ``indices`` is ``None`` when all rows
            are candidates, otherwise the positions of ``rows`` in
            ``self.rows``.
        """
        if not normalized_topic:
            return None, self.rows

        exact: List[int] = []
        partial: List[int] = []
        for index, row in enumerate(self.rows):
            row_topic = str(row.get("topic") or "").strip().lower()
            if not row_topic:
                # "" is a substring of every string; without this guard a
                # blank topic would count as a partial match for any request.
                continue
            if row_topic == normalized_topic:
                exact.append(index)
            elif normalized_topic in row_topic or row_topic in normalized_topic:
                partial.append(index)

        chosen = exact or partial
        if not chosen:
            return None, self.rows
        return chosen, [self.rows[index] for index in chosen]

    def search(
        self,
        query: str,
        topic: str = "",
        intent: str = "answer",
        k: int = 3,
    ) -> "List[RetrievedChunk]":
        """Return the ``k`` best-scoring chunks for ``query``.

        Args:
            query: Free-text query; passed through ``clean_math_text`` first.
            topic: Optional topic used to narrow candidates and to compute
                the topic bonus. Case and surrounding whitespace are ignored.
            intent: Caller intent forwarded to ``_topic_bonus``.
            k: Maximum number of results. Zero or a negative value yields an
                empty list.

        Returns:
            Up to ``k`` ``RetrievedChunk`` objects ordered by descending
            score. Rows with equal scores keep their corpus order.
        """
        if not self.rows or (k is not None and k <= 0):
            return []

        combined_query = clean_math_text(query)
        normalized_topic = (topic or "").strip().lower()
        candidate_indices, candidate_rows = self._select_candidates(normalized_topic)

        scores = []

        if self.encoder is not None and self.embeddings is not None and np is not None:
            try:
                query_vector = self.encoder.encode(
                    [combined_query],
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                )[0]

                if candidate_indices is None:
                    candidate_embeddings = self.embeddings
                else:
                    candidate_embeddings = self.embeddings[candidate_indices]

                # Both sides were encoded with normalize_embeddings=True, so
                # this is a dot product of normalised vectors.
                semantic_scores = candidate_embeddings @ query_vector

                for row, semantic in zip(candidate_rows, semantic_scores.tolist()):
                    lexical = score_token_overlap(combined_query, row["text"])
                    bonus = self._topic_bonus(normalized_topic, row["topic"], intent)
                    total = (
                        self.SEMANTIC_WEIGHT * semantic
                        + self.LEXICAL_WEIGHT * lexical
                        + bonus
                    )
                    scores.append((total, row))
            except Exception:
                # Drop any partial semantic results so the lexical-only
                # fallback below scores every candidate consistently.
                scores = []

        if not scores:
            for row in candidate_rows:
                lexical = score_token_overlap(combined_query, row["text"])
                bonus = self._topic_bonus(normalized_topic, row["topic"], intent)
                scores.append((lexical + bonus, row))

        # Sort on the score only: rows are dicts and cannot be compared, and
        # the stable sort keeps corpus order among ties.
        scores.sort(key=lambda pair: pair[0], reverse=True)

        return [
            RetrievedChunk(
                text=row["text"],
                topic=row["topic"],
                source=row["source"],
                score=float(score),
            )
            for score, row in scores[:k]
        ]