| """ |
| Training Pipeline |
| |
| Three strategies, from simplest to most powerful: |
| |
| 1. Unsupervised — soft-label domain adaptation (CosineSimilarityLoss) |
| 2. Contrastive — adjacent sentences as positives, in-batch negatives (MNRL) |
| 3. Keyword-supervised — keyword→meaning pairs + MNRL |
| |
| All three produce a saved model you can load into ContextualSimilarityEngine. |
| |
| Usage: |
| trainer = CorpusTrainer(corpus_texts=[...]) |
| trainer.train_unsupervised("./my_model") |
| |
| trainer.train_with_keywords( |
| keyword_meanings={"pizza": "school"}, |
| output_path="./my_model", |
| ) |
| """ |
|
|
| import logging |
| import random |
| import re |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation, util |
| from torch.utils.data import DataLoader |
|
|
| logger = logging.getLogger(__name__) |
|
|
| _BASE_DIR = Path(__file__).parent.resolve() |
|
|
|
|
def _validate_output_path(path_str: str) -> str:
    """Reject any output path that escapes the project directory.

    The path is resolved (symlinks and ``..`` collapsed) before the
    containment check; the original, unresolved string is returned
    unchanged for the caller to use.
    """
    candidate = Path(path_str).resolve()
    if candidate.is_relative_to(_BASE_DIR):
        return path_str
    raise ValueError("Output path must be within the project directory.")
|
|
|
|
| class CorpusTrainer: |
| """ |
| Trains/fine-tunes a SentenceTransformer on your corpus. |
| |
| Extracts sentences from your documents on init. Then call one of: |
| - train_unsupervised() — soft-label pairs via current model similarity |
| - train_contrastive() — adjacent sentences as positives (MNRL) |
| - train_with_keywords() — keyword→meaning supervised pairs |
| """ |
|
|
| def __init__( |
| self, |
| corpus_texts: list[str], |
| base_model: str = "all-MiniLM-L6-v2", |
| seed: int = 42, |
| ): |
| self.base_model_name = base_model |
| self.model = SentenceTransformer(base_model) |
| self.rng = random.Random(seed) |
| np.random.seed(seed) |
|
|
| self.sentences = self._extract_sentences(corpus_texts) |
| self.rng.shuffle(self.sentences) |
| self._corpus_texts = corpus_texts |
| logger.info(f"Corpus: {len(self.sentences)} sentences from {len(corpus_texts)} documents") |
|
|
| |
| |
| |
|
|
| def train_unsupervised( |
| self, |
| output_path: str = "./trained_model", |
| epochs: int = 3, |
| batch_size: int = 16, |
| ) -> dict: |
| """ |
| Soft-label domain adaptation using CosineSimilarityLoss. |
| Samples random sentence pairs and uses the model's own similarity |
| scores as training labels — nudging the model toward the corpus |
| distribution without any manual labels. |
| """ |
| _validate_output_path(output_path) |
| t0 = time.time() |
|
|
| n = min(5000, len(self.sentences) * 2) |
| pairs = [] |
| for _ in range(n): |
| a, b = self.rng.sample(self.sentences, 2) |
| vecs = self.model.encode([a, b], normalize_embeddings=True, convert_to_tensor=True) |
| score = float(util.pytorch_cos_sim(vecs[0], vecs[1]).item()) |
| pairs.append(InputExample(texts=[a, b], label=score)) |
|
|
| if not pairs: |
| raise RuntimeError("Not enough sentences to build training pairs.") |
|
|
| loader = DataLoader(pairs, shuffle=True, batch_size=batch_size) |
| train_loss = losses.CosineSimilarityLoss(self.model) |
|
|
| logger.info(f"Unsupervised training: {len(pairs)} pairs, {epochs} epochs") |
| self.model.fit( |
| train_objectives=[(loader, train_loss)], |
| epochs=epochs, |
| show_progress_bar=True, |
| ) |
| self.model.save(output_path) |
|
|
| return self._report("unsupervised", output_path, len(pairs), epochs, time.time() - t0) |
|
|
| |
| |
| |
|
|
| def train_contrastive( |
| self, |
| output_path: str = "./trained_model", |
| epochs: int = 5, |
| batch_size: int = 16, |
| ) -> dict: |
| """ |
| Uses document structure: adjacent sentences become positive pairs. |
| MultipleNegativesRankingLoss provides in-batch negatives automatically. |
| Includes validation and saves the best checkpoint. |
| """ |
| _validate_output_path(output_path) |
| t0 = time.time() |
|
|
| positives = [] |
| for text in self._corpus_texts: |
| sents = self._extract_sentences([text]) |
| for i in range(len(sents) - 1): |
| positives.append(InputExample(texts=[sents[i], sents[i + 1]])) |
|
|
| if not positives: |
| raise RuntimeError("Not enough sentences to build training pairs.") |
|
|
| loader = DataLoader(positives, shuffle=True, batch_size=batch_size) |
| train_loss = losses.MultipleNegativesRankingLoss(self.model) |
| val_eval = self._build_evaluator() |
|
|
| logger.info(f"Contrastive training: {len(positives)} pairs, {epochs} epochs") |
| self.model.fit( |
| train_objectives=[(loader, train_loss)], |
| evaluator=val_eval, |
| epochs=epochs, |
| evaluation_steps=max(1, len(loader) // 2), |
| output_path=output_path, |
| save_best_model=True, |
| show_progress_bar=True, |
| ) |
|
|
| return self._report("contrastive", output_path, len(positives), epochs, time.time() - t0) |
|
|
| |
| |
| |
|
|
| def train_with_keywords( |
| self, |
| keyword_meanings: dict[str, str], |
| output_path: str = "./trained_model", |
| epochs: int = 5, |
| batch_size: int = 16, |
| context_window: int = 2, |
| ) -> dict: |
| """ |
| You provide keyword→meaning mappings (e.g. {"pizza": "school"}). |
| The trainer: |
| 1. Finds every sentence containing each keyword |
| 2. Builds positive pairs: keyword-in-context ↔ meaning-replaced version |
| 3. Uses MNRL (in-batch negatives handle the rest) |
| """ |
| _validate_output_path(output_path) |
| t0 = time.time() |
|
|
| doc_sentences = [self._extract_sentences([t]) for t in self._corpus_texts] |
|
|
| positives = [] |
| for keyword, meaning in keyword_meanings.items(): |
| pattern = re.compile(r"\b" + re.escape(keyword) + r"\b", re.IGNORECASE) |
|
|
| for sents in doc_sentences: |
| for i, sent in enumerate(sents): |
| if not pattern.search(sent): |
| continue |
|
|
| start = max(0, i - context_window) |
| end = min(len(sents), i + context_window + 1) |
| context = " ".join(sents[start:end]) |
|
|
| |
| replaced = pattern.sub(meaning, context) |
| positives.append(InputExample(texts=[context, replaced])) |
|
|
| |
| positives.append(InputExample(texts=[context, f"This is about {meaning}."])) |
|
|
| if not positives: |
| raise RuntimeError( |
| f"No keyword occurrences found in corpus. " |
| f"Keywords searched: {list(keyword_meanings.keys())}" |
| ) |
|
|
| self.rng.shuffle(positives) |
| loader = DataLoader(positives, shuffle=True, batch_size=batch_size) |
| train_loss = losses.MultipleNegativesRankingLoss(self.model) |
| val_eval = self._build_evaluator() |
|
|
| logger.info(f"Keyword training: {len(positives)} pairs, {epochs} epochs, " |
| f"keywords: {list(keyword_meanings.keys())}") |
| self.model.fit( |
| train_objectives=[(loader, train_loss)], |
| evaluator=val_eval, |
| epochs=epochs, |
| evaluation_steps=max(1, len(loader) // 2), |
| output_path=output_path, |
| save_best_model=True, |
| show_progress_bar=True, |
| ) |
|
|
| return self._report("keyword_supervised", output_path, len(positives), epochs, time.time() - t0, |
| extra={"keywords": list(keyword_meanings.keys())}) |
|
|
| |
| |
| |
|
|
| def evaluate_model( |
| self, |
| test_pairs: list[tuple[str, str, float]], |
| trained_model_path: str, |
| ) -> dict: |
| """ |
| Compare base model vs trained model on test pairs. |
| |
| Args: |
| test_pairs: List of (text_a, text_b, expected_similarity). |
| trained_model_path: Path to the trained model. |
| |
| Returns: |
| Dict with per-pair and summary comparison. |
| """ |
| _validate_output_path(trained_model_path) |
| base = SentenceTransformer(self.base_model_name) |
| trained = SentenceTransformer(trained_model_path) |
|
|
| results = [] |
| for text_a, text_b, expected in test_pairs: |
| base_sim = self._compute_sim(base, text_a, text_b) |
| trained_sim = self._compute_sim(trained, text_a, text_b) |
| results.append({ |
| "text_a": text_a[:100], |
| "text_b": text_b[:100], |
| "expected": expected, |
| "base_score": round(base_sim, 4), |
| "trained_score": round(trained_sim, 4), |
| "base_error": round(abs(base_sim - expected), 4), |
| "trained_error": round(abs(trained_sim - expected), 4), |
| }) |
|
|
| base_errors = [r["base_error"] for r in results] |
| trained_errors = [r["trained_error"] for r in results] |
| avg_base = np.mean(base_errors) |
| avg_trained = np.mean(trained_errors) |
|
|
| return { |
| "pairs": results, |
| "summary": { |
| "avg_base_error": round(float(avg_base), 4), |
| "avg_trained_error": round(float(avg_trained), 4), |
| "error_reduction_pct": round( |
| ((avg_base - avg_trained) / avg_base * 100) if avg_base > 0 else 0, 1 |
| ), |
| "improved": sum(1 for r in results if r["trained_error"] < r["base_error"]), |
| "degraded": sum(1 for r in results if r["trained_error"] > r["base_error"]), |
| "total": len(results), |
| }, |
| } |
|
|
| |
| |
| |
|
|
| def _build_evaluator(self): |
| """Build a validation evaluator from random sentence pairs.""" |
| n = min(100, len(self.sentences) // 2) |
| if n < 10: |
| return None |
|
|
| s1, s2, scores = [], [], [] |
| sampled = self.rng.sample(range(len(self.sentences)), min(n * 2, len(self.sentences))) |
| for i in range(0, len(sampled) - 1, 2): |
| a_idx, b_idx = sampled[i], sampled[i + 1] |
| s1.append(self.sentences[a_idx]) |
| s2.append(self.sentences[b_idx]) |
| vecs = self.model.encode([self.sentences[a_idx], self.sentences[b_idx]], |
| normalize_embeddings=True, convert_to_tensor=True) |
| scores.append(float(util.pytorch_cos_sim(vecs[0], vecs[1]).item())) |
|
|
| return evaluation.EmbeddingSimilarityEvaluator(s1, s2, scores, name="val", show_progress_bar=False) |
|
|
| @staticmethod |
| def _compute_sim(model: SentenceTransformer, a: str, b: str) -> float: |
| vecs = model.encode([a, b], normalize_embeddings=True, convert_to_tensor=True) |
| return float(util.pytorch_cos_sim(vecs[0], vecs[1]).item()) |
|
|
| @staticmethod |
| def _extract_sentences(texts: list[str]) -> list[str]: |
| sentences = [] |
| for text in texts: |
| parts = re.split(r"(?<=[.!?])\s+", text.strip()) |
| for s in parts: |
| s = s.strip() |
| if len(s.split()) >= 5: |
| sentences.append(s) |
| return sentences |
|
|
| @staticmethod |
| def _report(strategy, path, pairs, epochs, elapsed, extra=None): |
| report = { |
| "strategy": strategy, |
| "model_path": path, |
| "training_pairs": pairs, |
| "epochs": epochs, |
| "seconds": round(elapsed, 2), |
| } |
| if extra: |
| report.update(extra) |
| logger.info(f"Training complete ({strategy}): {pairs} pairs, {elapsed:.1f}s -> {path}") |
| return report |
|
|