| """ |
| Contextual Word Similarity Engine |
| |
| Uses transformer-based sentence embeddings (SentenceTransformers) and FAISS |
| vector search to find and compare contextual meanings of keywords within |
| large documents. Unlike static embeddings (Word2Vec/GloVe), this captures |
| how word meaning changes based on surrounding context. |
| |
| Usage: |
| engine = ContextualSimilarityEngine() |
| engine.add_document("my_doc", text) |
| engine.build_index() |
| results = engine.analyze_keyword("pizza", top_k=10) |
| """ |
|
|
import logging
import re
from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import AgglomerativeClustering
from tqdm import tqdm
|
|
# Module-level logger, named after this module so applications can tune its
# verbosity through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass
class Chunk:
    """A passage of text from a document with metadata.

    Instances compare equal field-by-field (dataclass ``__eq__``).
    """

    # Passage text (whitespace-stripped by the chunker).
    text: str
    # Identifier of the document this passage was cut from.
    doc_id: str
    # 0-based position of this passage among its document's chunks.
    chunk_index: int
    # Character offsets of the chunk window (end exclusive). The chunker computes
    # them on the document text after collapsing runs of 3+ newlines, so they may
    # not line up exactly with the raw input.
    start_char: int
    end_char: int

    def __repr__(self):
        """Short, single-line preview: at most 80 characters of text, newlines flattened.

        The trailing "..." is only shown when the text was actually truncated.
        """
        preview_len = 80
        preview = self.text[:preview_len].replace("\n", " ")
        ellipsis = "..." if len(self.text) > preview_len else ""
        return (
            f"Chunk(doc={self.doc_id!r}, idx={self.chunk_index}, "
            f"text={preview!r}{ellipsis})"
        )
|
|
|
|
@dataclass
class SimilarityResult:
    """One hit returned by a similarity search over the chunk index.

    Attributes:
        chunk: The matching passage.
        score: Inner-product score reported by the FAISS index (cosine
            similarity when the index was built from normalized embeddings).
        rank: 1-based position of this hit in the result list.
    """

    chunk: Chunk
    score: float
    rank: int
|
|
|
|
@dataclass
class KeywordContext:
    """One chunk that contains a keyword, plus where the keyword occurs in it.

    No embedding is stored here; embeddings live on the engine and are looked
    up by the chunk's position in the corpus.
    """

    # The keyword (or phrase) that was searched for, as given by the caller.
    keyword: str
    # The passage in which the keyword was found.
    chunk: Chunk
    # (start, end) character offsets of every match inside ``chunk.text``.
    highlight_positions: list = field(default_factory=list)
|
|
|
|
@dataclass
class KeywordAnalysis:
    """Summary of how a keyword is used across the corpus.

    Attributes:
        keyword: The analyzed keyword.
        total_occurrences: Number of chunks containing the keyword.
        meaning_clusters: One dict per usage cluster, with keys ``cluster_id``,
            ``size``, ``representative_text``, ``contexts`` and
            ``similar_passages``.
        cross_keyword_similarities: Maps other keywords to the cosine
            similarity between this keyword's context centroid and theirs.
            Left empty unless a batch analysis filled it in.
    """

    keyword: str
    total_occurrences: int
    meaning_clusters: list = field(default_factory=list)
    cross_keyword_similarities: dict = field(default_factory=dict)
|
|
|
|
class ContextualSimilarityEngine:
    """
    Engine for contextual word similarity analysis using transformer embeddings.

    Loads documents, chunks them into passages, embeds with a SentenceTransformer
    model, indexes with FAISS, and provides methods to:
      - Find all contextual usages of a keyword
      - Cluster keyword usages into distinct meanings
      - Compare keyword contexts across documents
      - Find passages most similar to a query
      - Batch-analyze multiple keywords
    """

    # Separators tried, in this order, when looking for a clean place to end a
    # chunk. The first separator present in the search region wins (at its
    # rightmost occurrence).
    _SENTENCE_SEPARATORS = (". ", ".\n", "! ", "!\n", "? ", "?\n", "\n\n")

    # How many characters before the hard chunk limit are searched for a
    # sentence boundary.
    _BOUNDARY_SEARCH_CHARS = 100

    # Tokens of 3+ ASCII letters; used to extract context / vocabulary words.
    _WORD_PATTERN = re.compile(r"[a-zA-Z]{3,}")

    _STOPWORDS = frozenset(
        "a an the and or but in on at to for of is it that this was were be been "
        "being have has had do does did will would shall should may might can could "
        "not no nor so if then than too very just about above after again all also "
        "am are as between both by each few from further get got he her here hers "
        "herself him himself his how i its itself me more most my myself no nor "
        "only other our ours ourselves out over own same she some such their theirs "
        "them themselves there these they those through under until up us we what "
        "when where which while who whom why with you your yours yourself yourselves "
        "one two three four five six seven eight nine ten into been being because "
        "during before between against without within along across behind since "
        "upon around among".split()
    )

    # ------------------------------------------------------------------ #
    # Construction
    # ------------------------------------------------------------------ #

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        chunk_size: int = 512,
        chunk_overlap: int = 128,
        device: Optional[str] = None,
        batch_size: int = 64,
    ):
        """
        Args:
            model_name: HuggingFace SentenceTransformer model name.
                - "all-MiniLM-L6-v2": fast, good quality (384-dim)
                - "all-mpnet-base-v2": best quality general-purpose (768-dim)
                - "BAAI/bge-large-en-v1.5": high accuracy, larger (1024-dim)
            chunk_size: Max characters per chunk. Must be positive.
            chunk_overlap: Overlap between consecutive chunks (preserves context
                at boundaries). Must be non-negative. Values >= chunk_size are
                accepted, but consecutive chunks then do not overlap.
            device: PyTorch device ("cpu", "cuda", "mps"). Auto-detected if None.
            batch_size: Batch size for encoding (tune for your GPU memory).

        Raises:
            ValueError: If chunk_size <= 0 or chunk_overlap < 0. Non-positive
                chunk sizes would make chunking loop forever. A negative
                overlap would silently skip text.
        """
        # Validate before the (slow) model load so bad configs fail fast.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}.")
        if chunk_overlap < 0:
            raise ValueError(f"chunk_overlap must be non-negative, got {chunk_overlap}.")
        if chunk_overlap >= chunk_size:
            logger.warning(
                "chunk_overlap (%d) >= chunk_size (%d): consecutive chunks will not overlap.",
                chunk_overlap, chunk_size,
            )

        logger.info("Loading model: %s", model_name)
        self._model_name = model_name
        self.model = SentenceTransformer(model_name, device=device)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.batch_size = batch_size
        self.embedding_dim = self.model.get_sentence_embedding_dimension()

        # Corpus state. `embeddings` and `index` are produced by build_index()
        # and reset to None whenever a document is added.
        self.chunks: list[Chunk] = []
        self.embeddings: Optional[np.ndarray] = None
        self.index: Optional[faiss.IndexFlatIP] = None
        self._doc_ids: set[str] = set()

    # ------------------------------------------------------------------ #
    # Corpus management
    # ------------------------------------------------------------------ #

    def add_document(self, doc_id: str, text: str) -> list[Chunk]:
        """
        Chunk a document and add it to the corpus.

        Adding a document invalidates any previously built index; call
        build_index() again before searching.

        Args:
            doc_id: Unique identifier for this document.
            text: Full document text.

        Returns:
            List of Chunk objects created from this document.

        Raises:
            ValueError: If doc_id was already added.
        """
        if doc_id in self._doc_ids:
            raise ValueError(f"Document '{doc_id}' already added. Use a unique doc_id.")

        # Chunk first so the id is only registered once chunking succeeded.
        new_chunks = self._chunk_text(text, doc_id)
        self._doc_ids.add(doc_id)
        self.chunks.extend(new_chunks)
        logger.info("Added document '%s': %d chunks", doc_id, len(new_chunks))

        # The corpus changed, so existing embeddings/index no longer match it.
        self.embeddings = None
        self.index = None
        return new_chunks

    def add_document_from_file(self, file_path: str, doc_id: Optional[str] = None) -> list[Chunk]:
        """
        Load a UTF-8 text file and add it as a document.

        Args:
            file_path: Path to the file. After resolving (including symlinks) it
                must lie inside the directory that contains this module.
            doc_id: Document id. Falls back to the file name without its
                extension when None or empty.

        Returns:
            The chunks created from the file.

        Raises:
            ValueError: If the path resolves outside the module's directory.
            FileNotFoundError: If the file does not exist.
        """
        path = Path(file_path).resolve()
        base_dir = Path(__file__).parent.resolve()
        # Confine reads to the project directory (presumably so user-supplied
        # paths cannot reach arbitrary files).
        if not path.is_relative_to(base_dir):
            raise ValueError("File path must be within the project directory.")
        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
        text = path.read_text(encoding="utf-8")
        return self.add_document(doc_id or path.stem, text)

    def _chunk_text(self, text: str, doc_id: str) -> list[Chunk]:
        """
        Split text into overlapping chunks, breaking at sentence boundaries
        when possible to preserve semantic coherence.

        Runs of 3+ newlines are first collapsed to two. ``start_char`` and
        ``end_char`` index into that normalized text, and ``end_char`` never
        exceeds its length. The stored chunk text is whitespace-stripped.
        """
        text = re.sub(r"\n{3,}", "\n\n", text)
        text_len = len(text)

        chunks: list[Chunk] = []
        start = 0
        while start < text_len:
            end = min(start + self.chunk_size, text_len)

            if end < text_len:
                # Try to end the chunk on a sentence/paragraph break that falls
                # within the last _BOUNDARY_SEARCH_CHARS characters of the window.
                region_start = max(end - self._BOUNDARY_SEARCH_CHARS, start)
                search_region = text[region_start:end]
                for sep in self._SENTENCE_SEPARATORS:
                    last_break = search_region.rfind(sep)
                    if last_break != -1:
                        end = region_start + last_break + len(sep)
                        break

            chunk_text = text[start:end].strip()
            if chunk_text:
                chunks.append(Chunk(
                    text=chunk_text,
                    doc_id=doc_id,
                    chunk_index=len(chunks),
                    start_char=start,
                    end_char=end,
                ))

            if end >= text_len:
                break
            # Step back by the overlap, but always move forward. A break found
            # near the start of the window (or overlap >= chunk_size) would
            # otherwise leave `start` stuck or moving backwards: an infinite loop.
            next_start = end - self.chunk_overlap
            start = next_start if next_start > start else end

        return chunks

    # ------------------------------------------------------------------ #
    # Indexing and search
    # ------------------------------------------------------------------ #

    def build_index(self, normalize: bool = True, show_progress: bool = True) -> None:
        """
        Embed all chunks and build a FAISS index for fast similarity search.

        Args:
            normalize: L2-normalize embeddings (enables cosine similarity via
                inner product). query() always normalizes the query vector, so
                with normalize=False its scores are no longer cosine similarities.
            show_progress: Show a progress bar during encoding.

        Raises:
            RuntimeError: If no chunks have been added.
        """
        if not self.chunks:
            raise RuntimeError("No documents loaded. Call add_document() first.")

        logger.info("Encoding %d chunks...", len(self.chunks))
        texts = [c.text for c in self.chunks]
        self.embeddings = self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
            normalize_embeddings=normalize,
        )

        # Flat (exhaustive) inner-product index; FAISS expects float32 input.
        self.index = faiss.IndexFlatIP(self.embedding_dim)
        self.index.add(self.embeddings.astype(np.float32))
        logger.info("Index built: %d vectors, dim=%s", self.index.ntotal, self.embedding_dim)

    def query(self, text: str, top_k: int = 10) -> list[SimilarityResult]:
        """
        Find the most similar chunks to a query text.

        Args:
            text: Query string (sentence, phrase, or keyword in context).
            top_k: Number of results to return (an empty list when <= 0).

        Returns:
            List of SimilarityResult sorted by descending similarity score.

        Raises:
            RuntimeError: If build_index() has not been called.
        """
        self._ensure_index()
        query_vec = self.model.encode(
            [text], normalize_embeddings=True, convert_to_numpy=True
        ).astype(np.float32)
        return self._search(query_vec, top_k)

    def _search(self, query_vec: np.ndarray, top_k: int) -> list[SimilarityResult]:
        """Search the index with a single (1, dim) float32 vector and wrap the hits.

        Returns at most ``top_k`` results, ranked 1..k in FAISS order
        (descending inner product). Returns an empty list when ``top_k <= 0``.
        """
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            # FAISS requires k > 0; an empty request has an empty answer.
            return []
        scores, indices = self.index.search(query_vec, k)
        return [
            SimilarityResult(chunk=self.chunks[idx], score=float(score), rank=rank)
            for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), start=1)
            if idx != -1  # FAISS pads missing results with -1
        ]

    def compare_texts(self, text_a: str, text_b: str) -> float:
        """
        Compute cosine similarity between two texts directly.

        Returns:
            Similarity score in [-1, 1] (typically [0, 1] for natural language).
        """
        vecs = self.model.encode(
            [text_a, text_b], normalize_embeddings=True, convert_to_tensor=True
        )
        return float(util.pytorch_cos_sim(vecs[0], vecs[1]).item())

    # ------------------------------------------------------------------ #
    # Keyword analysis
    # ------------------------------------------------------------------ #

    def find_keyword_contexts(
        self, keyword: str, case_sensitive: bool = False
    ) -> list[KeywordContext]:
        """
        Find all chunks containing a keyword and return them as KeywordContext objects.

        A match must not be directly preceded or followed by a word character.
        This works for keywords that begin or end with punctuation too (e.g.
        "C++", "#tag"), which a plain ``\\b...\\b`` pattern never matches.

        Args:
            keyword: The word or phrase to search for.
            case_sensitive: Whether matching is case-sensitive.

        Returns:
            List of KeywordContext with chunk and highlight positions.

        Raises:
            ValueError: If keyword is empty/whitespace or longer than 200 characters.
        """
        if len(keyword) > 200:
            raise ValueError("Keyword must be 200 characters or fewer.")
        if not keyword.strip():
            # An empty pattern would match (zero-width) in every chunk.
            raise ValueError("Keyword must be a non-empty string.")
        flags = 0 if case_sensitive else re.IGNORECASE
        pattern = re.compile(r"(?<!\w)" + re.escape(keyword) + r"(?!\w)", flags)

        contexts = []
        for chunk in self.chunks:
            positions = [m.span() for m in pattern.finditer(chunk.text)]
            if positions:
                contexts.append(KeywordContext(
                    keyword=keyword,
                    chunk=chunk,
                    highlight_positions=positions,
                ))
        return contexts

    def _chunk_positions(self, contexts: list[KeywordContext]) -> list[int]:
        """Return the position in self.chunks of each context's chunk.

        An identity map is built once (O(n)), replacing a ``list.index`` call per
        context (O(n^2) overall). Chunks that are not the exact objects stored
        in self.chunks fall back to the equality-based ``list.index`` lookup.
        """
        position_by_id = {id(chunk): i for i, chunk in enumerate(self.chunks)}
        positions = []
        for ctx in contexts:
            pos = position_by_id.get(id(ctx.chunk))
            positions.append(pos if pos is not None else self.chunks.index(ctx.chunk))
        return positions

    @staticmethod
    def _group_by_label(labels: list[int]) -> list[tuple[int, list[int]]]:
        """Group item positions by cluster label.

        Returns (label, positions) pairs sorted by label; positions are ascending.
        """
        groups: dict[int, list[int]] = {}
        for position, label in enumerate(labels):
            groups.setdefault(label, []).append(position)
        return sorted(groups.items())

    def analyze_keyword(
        self,
        keyword: str,
        top_k: int = 10,
        cluster_threshold: float = 0.35,
        case_sensitive: bool = False,
    ) -> KeywordAnalysis:
        """
        Analyze all contextual usages of a keyword across the corpus.

        Finds every chunk containing the keyword, takes their embeddings,
        clusters them by semantic similarity (agglomerative clustering), and
        returns a structured analysis with distinct meaning groups.

        Args:
            keyword: Word or phrase to analyze.
            top_k: Max similar chunks to return per meaning cluster.
            cluster_threshold: Distance threshold for clustering (lower = more clusters).
                0.35 works well for clearly distinct meanings; raise to 0.5+ to merge similar ones.
            case_sensitive: Whether keyword matching is case-sensitive.

        Returns:
            KeywordAnalysis with meaning clusters and similarity info.

        Raises:
            RuntimeError: If build_index() has not been called.
        """
        self._ensure_index()
        contexts = self.find_keyword_contexts(keyword, case_sensitive)
        if not contexts:
            return KeywordAnalysis(keyword=keyword, total_occurrences=0)

        kw_embeddings = self.embeddings[self._chunk_positions(contexts)]
        labels = self._cluster_embeddings(kw_embeddings, threshold=cluster_threshold)

        meaning_clusters = []
        for cluster_id, member_ids in self._group_by_label(labels):
            member_contexts = [contexts[i] for i in member_ids]

            # Search the corpus with the normalized mean embedding of the cluster.
            centroid = kw_embeddings[member_ids].mean(axis=0, keepdims=True).astype(np.float32)
            faiss.normalize_L2(centroid)

            meaning_clusters.append({
                "cluster_id": cluster_id,
                "size": len(member_ids),
                "representative_text": member_contexts[0].chunk.text[:200],
                "contexts": member_contexts,
                "similar_passages": self._search(centroid, top_k),
            })

        return KeywordAnalysis(
            keyword=keyword,
            total_occurrences=len(contexts),
            meaning_clusters=meaning_clusters,
        )

    def batch_analyze_keywords(
        self,
        keywords: list[str],
        top_k: int = 10,
        cluster_threshold: float = 0.35,
        compare_across: bool = True,
    ) -> dict[str, KeywordAnalysis]:
        """
        Analyze multiple keywords and optionally compute cross-keyword similarities.

        Args:
            keywords: List of keywords to analyze.
            top_k: Results per cluster.
            cluster_threshold: Clustering distance threshold.
            compare_across: If True, compute pairwise similarity between keyword contexts.

        Returns:
            Dict mapping keyword -> KeywordAnalysis.
        """
        results = {}
        for kw in tqdm(keywords, desc="Analyzing keywords"):
            results[kw] = self.analyze_keyword(kw, top_k, cluster_threshold)

        if compare_across and len(keywords) > 1:
            self._compute_cross_keyword_similarities(results)
        return results

    def _compute_cross_keyword_similarities(
        self, analyses: dict[str, KeywordAnalysis]
    ) -> None:
        """Fill each analysis' ``cross_keyword_similarities`` in place.

        Every keyword with occurrences gets a centroid: the L2-normalized mean
        of its context embeddings. The similarity of two keywords is the dot
        product of their centroids. Keywords without occurrences are skipped.
        """
        keyword_centroids: dict[str, np.ndarray] = {}
        for kw, analysis in analyses.items():
            kw_contexts = [
                ctx for cluster in analysis.meaning_clusters for ctx in cluster["contexts"]
            ]
            if not kw_contexts:
                continue
            centroid = self.embeddings[self._chunk_positions(kw_contexts)].mean(axis=0)
            norm = np.linalg.norm(centroid)
            if norm > 0:
                centroid = centroid / norm
            keyword_centroids[kw] = centroid

        for kw_a, vec_a in keyword_centroids.items():
            analyses[kw_a].cross_keyword_similarities = {
                kw_b: float(np.dot(vec_a, vec_b))
                for kw_b, vec_b in keyword_centroids.items()
                if kw_b != kw_a
            }

    def match_keyword_to_meaning(
        self,
        keyword: str,
        candidate_meanings: list[str],
    ) -> list[dict]:
        """
        Given a keyword and a list of candidate meanings (words/phrases),
        find which meaning each occurrence of the keyword is closest to.

        This is the core "pizza means school" use case: you provide the keyword
        "pizza" and candidates ["pizza (food)", "school", "homework"], and this
        method tells you which meaning each usage of "pizza" maps to.

        Chunk vectors come from the embeddings stored by build_index(), so no
        chunk is re-encoded. Scores are cosine similarities.

        Args:
            keyword: The keyword to analyze (e.g. "pizza").
            candidate_meanings: List of meaning descriptions (e.g. ["food", "school"]).

        Returns:
            List of dicts with keys: chunk, best_match, best_score, all_scores.

        Raises:
            RuntimeError: If build_index() has not been called.
            ValueError: If the keyword occurs but candidate_meanings is empty.
        """
        self._ensure_index()
        contexts = self.find_keyword_contexts(keyword)
        if not contexts:
            return []
        if not candidate_meanings:
            raise ValueError("candidate_meanings must contain at least one meaning.")

        candidate_vecs = self.model.encode(
            candidate_meanings, normalize_embeddings=True, convert_to_numpy=True
        ).astype(np.float32)

        # Re-normalize the stored vectors: build_index(normalize=False) keeps
        # them raw. The 1e-12 floor guards against division by zero.
        chunk_vecs = self.embeddings[self._chunk_positions(contexts)].astype(np.float32)
        norms = np.linalg.norm(chunk_vecs, axis=1, keepdims=True)
        chunk_vecs = chunk_vecs / np.maximum(norms, 1e-12)
        score_matrix = chunk_vecs @ candidate_vecs.T  # (n_contexts, n_candidates)

        results = []
        for ctx, row in zip(contexts, score_matrix):
            score_dict = {
                meaning: float(row[i]) for i, meaning in enumerate(candidate_meanings)
            }
            best = max(score_dict, key=score_dict.get)
            results.append({
                "chunk": ctx.chunk,
                "best_match": best,
                "best_score": score_dict[best],
                "all_scores": score_dict,
            })
        return results

    def infer_keyword_meanings(
        self,
        keyword: str,
        context_window: int = 120,
        top_words: int = 8,
        cluster_threshold: float = 0.35,
        max_meanings: int = 10,
    ) -> dict:
        """
        Infer what a keyword likely means based on its surrounding context words.

        Finds all occurrences, clusters them by semantic similarity, then extracts
        the most distinctive co-occurring words for each meaning cluster.

        Args:
            keyword: The keyword to analyze.
            context_window: Characters around each keyword occurrence to examine.
            top_words: Number of associated words to return per meaning.
            cluster_threshold: Distance threshold for clustering.
            max_meanings: Maximum number of meaning clusters to return.

        Returns:
            Dict with keyword, total_occurrences, and meanings list (sorted by
            descending confidence = share of occurrences in the cluster).

        Raises:
            RuntimeError: If build_index() has not been called.
        """
        self._ensure_index()
        contexts = self.find_keyword_contexts(keyword)
        if not contexts:
            return {
                "keyword": keyword,
                "total_occurrences": 0,
                "meanings": [],
            }

        kw_embeddings = self.embeddings[self._chunk_positions(contexts)]
        labels = self._cluster_embeddings(kw_embeddings, threshold=cluster_threshold)

        total = len(contexts)
        kw_lower = keyword.lower()

        # Per-context counts of non-stopword words near each keyword occurrence,
        # plus corpus-wide totals used to measure how cluster-specific a word is.
        global_word_counts: Counter = Counter()
        context_word_counts: list[Counter] = []
        for ctx in contexts:
            text = ctx.chunk.text
            local_counts: Counter = Counter()
            for start, end in ctx.highlight_positions:
                window_text = text[max(0, start - context_window):end + context_window].lower()
                for word_match in self._WORD_PATTERN.finditer(window_text):
                    w = word_match.group()
                    if w == kw_lower or w in self._STOPWORDS:
                        continue
                    local_counts[w] += 1
            global_word_counts.update(local_counts)
            context_word_counts.append(local_counts)

        groups = self._group_by_label(labels)
        num_clusters = len(groups)

        meanings = []
        for cluster_id, member_ids in groups:
            count = len(member_ids)

            cluster_word_counts: Counter = Counter()
            for i in member_ids:
                cluster_word_counts.update(context_word_counts[i])
            # Hoisted out of the per-word loop (was recomputed for every word).
            cluster_total = max(sum(cluster_word_counts.values()), 1)

            # Score = term frequency within the cluster, boosted by how much of
            # the word's global usage falls inside this cluster.
            word_scores: dict[str, float] = {}
            for w, cluster_count in cluster_word_counts.items():
                tf = cluster_count / cluster_total
                distinctiveness = (
                    cluster_count / global_word_counts[w] if num_clusters > 1 else 1.0
                )
                word_scores[w] = tf * (0.5 + 0.5 * distinctiveness)

            sorted_words = sorted(word_scores.items(), key=lambda x: -x[1])[:top_words]
            meanings.append({
                "cluster_id": cluster_id,
                "occurrences": count,
                "confidence": round(count / total, 3),
                "associated_words": [
                    {"word": w, "score": round(s, 4)} for w, s in sorted_words
                ],
                "example_contexts": self._example_contexts(contexts, member_ids),
            })

        meanings.sort(key=lambda m: -m["confidence"])
        return {
            "keyword": keyword,
            "total_occurrences": total,
            "meanings": meanings[:max_meanings],
        }

    @staticmethod
    def _example_contexts(
        contexts: list[KeywordContext], member_ids: list[int], limit: int = 3, pad: int = 80
    ) -> list[dict]:
        """Build short snippets around the first match of up to `limit` members.

        Each snippet spans `pad` characters either side of the match. It gets a
        leading/trailing "..." when it was cut from a longer chunk.
        """
        examples = []
        for i in member_ids[:limit]:
            ctx = contexts[i]
            if not ctx.highlight_positions:
                continue
            text = ctx.chunk.text
            start, end = ctx.highlight_positions[0]
            snippet_start = max(0, start - pad)
            snippet_end = min(len(text), end + pad)
            snippet = text[snippet_start:snippet_end].strip()
            if snippet_start > 0:
                snippet = "..." + snippet
            if snippet_end < len(text):
                snippet = snippet + "..."
            examples.append({
                "doc_id": ctx.chunk.doc_id,
                "snippet": snippet,
            })
        return examples

    def _cluster_embeddings(
        self, embeddings: np.ndarray, threshold: float = 0.35
    ) -> list[int]:
        """Cluster embeddings using agglomerative clustering with cosine distance.

        Returns one integer label per row. An empty input gives [] and a single
        row gives [0]; the clusterer needs at least two samples.
        """
        if len(embeddings) == 0:
            return []
        if len(embeddings) == 1:
            return [0]

        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=threshold,
            metric="cosine",
            linkage="average",
        )
        labels = clustering.fit_predict(embeddings)
        return labels.tolist()

    def similar_words(self, word: str, top_k: int = 10) -> list[dict]:
        """
        Find corpus words whose standalone embeddings are closest to `word`.

        Extracts unique words (3+ ASCII letters, stopwords excluded) from the
        corpus, encodes them together with `word`, and ranks them by cosine
        similarity. Note that the whole vocabulary is re-encoded on every call.

        Args:
            word: Target word.
            top_k: Number of similar words to return.

        Returns:
            List of {"word": str, "score": float} sorted by descending similarity.

        Raises:
            RuntimeError: If build_index() has not been called.
        """
        self._ensure_index()
        word_lower = word.lower()

        vocab = {
            match.group().lower()
            for chunk in self.chunks
            for match in self._WORD_PATTERN.finditer(chunk.text)
        }
        vocab.discard(word_lower)
        vocab -= self._STOPWORDS
        if not vocab:
            return []

        vocab_list = sorted(vocab)
        logger.info("Similar words: encoding %d vocabulary words for '%s'", len(vocab_list), word)

        # Row 0 is the query word; the remaining rows follow vocab_list order.
        embeddings = self.model.encode(
            [word] + vocab_list,
            batch_size=self.batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        query_vec = embeddings[0:1]
        vocab_vecs = embeddings[1:]

        scores = (vocab_vecs @ query_vec.T).flatten()
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [
            {"word": vocab_list[i], "score": round(float(scores[i]), 4)}
            for i in top_indices
        ]

    # ------------------------------------------------------------------ #
    # Introspection
    # ------------------------------------------------------------------ #

    def _ensure_index(self):
        """Raise RuntimeError unless build_index() (or load()) produced an index."""
        if self.index is None:
            raise RuntimeError("Index not built. Call build_index() first.")

    def get_stats(self) -> dict:
        """Return corpus statistics."""
        return {
            "total_chunks": len(self.chunks),
            "total_documents": len(self._doc_ids),
            "document_ids": sorted(self._doc_ids),
            "index_built": self.index is not None,
            "embedding_dim": self.embedding_dim,
            "model_name": self._model_name,
        }

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #

    def save(self, directory: str) -> dict:
        """
        Save the full engine state (chunks, embeddings, FAISS index) to disk.

        Writes chunks.pkl and meta.json, plus embeddings.npy / index.faiss when
        they exist. If they do not exist, any copies left in `directory` by an
        earlier save are deleted, so load() cannot pair new chunks with a stale
        index.

        Args:
            directory: Path to save directory (created if needed).

        Returns:
            Stats dict with what was saved.
        """
        import json
        import pickle

        save_dir = Path(directory)
        save_dir.mkdir(parents=True, exist_ok=True)

        with open(save_dir / "chunks.pkl", "wb") as f:
            pickle.dump(self.chunks, f)

        meta = {
            "model_name": self._model_name,
            "chunk_size": self.chunk_size,
            "chunk_overlap": self.chunk_overlap,
            "batch_size": self.batch_size,
            "embedding_dim": self.embedding_dim,
            "doc_ids": sorted(self._doc_ids),
        }
        with open(save_dir / "meta.json", "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)

        emb_path = save_dir / "embeddings.npy"
        idx_path = save_dir / "index.faiss"
        if self.embeddings is not None:
            np.save(emb_path, self.embeddings)
        else:
            emb_path.unlink(missing_ok=True)

        saved_index = self.index is not None
        if saved_index:
            faiss.write_index(self.index, str(idx_path))
        else:
            idx_path.unlink(missing_ok=True)

        logger.info("Engine saved to %s: %d chunks, %d docs, index=%s",
                    directory, len(self.chunks), len(self._doc_ids), saved_index)
        return {
            "directory": str(save_dir),
            "chunks": len(self.chunks),
            "documents": len(self._doc_ids),
            "index_saved": saved_index,
        }

    @classmethod
    def load(cls, directory: str, device: Optional[str] = None) -> "ContextualSimilarityEngine":
        """
        Load a previously saved engine state from disk.

        SECURITY: chunks.pkl is read with pickle, which can execute arbitrary
        code. Only load directories you wrote yourself or otherwise trust.

        If the saved embeddings/index do not match the number of saved chunks,
        they are discarded (with a warning) and build_index() must be called.

        Args:
            directory: Path to the saved state directory.
            device: PyTorch device override.

        Returns:
            A restored ContextualSimilarityEngine instance.

        Raises:
            FileNotFoundError: If `directory` is not a directory.
        """
        import json
        import pickle

        save_dir = Path(directory)
        if not save_dir.is_dir():
            raise FileNotFoundError(f"No saved state at {directory}")

        with open(save_dir / "meta.json", encoding="utf-8") as f:
            meta = json.load(f)

        engine = cls(
            model_name=meta["model_name"],
            chunk_size=meta["chunk_size"],
            chunk_overlap=meta["chunk_overlap"],
            device=device,
            batch_size=meta["batch_size"],
        )

        # NOTE(security): pickle.load on untrusted data is unsafe; see docstring.
        with open(save_dir / "chunks.pkl", "rb") as f:
            engine.chunks = pickle.load(f)
        engine._doc_ids = set(meta["doc_ids"])

        emb_path = save_dir / "embeddings.npy"
        idx_path = save_dir / "index.faiss"
        if emb_path.exists():
            engine.embeddings = np.load(emb_path)
        if idx_path.exists():
            engine.index = faiss.read_index(str(idx_path))

        # A directory may hold embeddings/index from an earlier save that no
        # longer match chunks.pkl; using them would return wrong chunks.
        n_chunks = len(engine.chunks)
        stale = (
            (engine.index is not None and engine.index.ntotal != n_chunks)
            or (engine.embeddings is not None and len(engine.embeddings) != n_chunks)
        )
        if stale:
            logger.warning(
                "Saved embeddings/index in %s do not match %d chunks; discarding them. "
                "Call build_index() to rebuild.", directory, n_chunks,
            )
            engine.embeddings = None
            engine.index = None

        logger.info("Engine loaded from %s: %d chunks, %d docs, index=%s",
                    directory, len(engine.chunks), len(engine._doc_ids), engine.index is not None)
        return engine
|
|