# esfiles/training.py
# Author: Besjon Cifliku
# feat: initial project setup (commit db764ae)
"""
Training Pipeline
Three strategies, from simplest to most powerful:
1. Unsupervised — soft-label domain adaptation (CosineSimilarityLoss)
2. Contrastive — adjacent sentences as positives, in-batch negatives (MNRL)
3. Keyword-supervised — keyword→meaning pairs + MNRL
All three produce a saved model you can load into ContextualSimilarityEngine.
Usage:
trainer = CorpusTrainer(corpus_texts=[...])
trainer.train_unsupervised("./my_model")
trainer.train_with_keywords(
keyword_meanings={"pizza": "school"},
output_path="./my_model",
)
"""
import logging
import random
import re
import time
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation, util
from torch.utils.data import DataLoader
# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)
# Project root, used to sandbox model output paths (see _validate_output_path).
_BASE_DIR = Path(__file__).parent.resolve()
def _validate_output_path(path_str: str) -> str:
    """Reject any output path that escapes the project directory.

    Args:
        path_str: Candidate output path. Relative paths resolve against
            the current working directory, exactly as ``Path.resolve`` does.

    Returns:
        The original, unmodified path string when it is acceptable.

    Raises:
        ValueError: If the fully resolved path lies outside the project.
    """
    candidate = Path(path_str).resolve()
    if candidate.is_relative_to(_BASE_DIR):
        return path_str
    raise ValueError("Output path must be within the project directory.")
class CorpusTrainer:
"""
Trains/fine-tunes a SentenceTransformer on your corpus.
Extracts sentences from your documents on init. Then call one of:
- train_unsupervised() — soft-label pairs via current model similarity
- train_contrastive() — adjacent sentences as positives (MNRL)
- train_with_keywords() — keyword→meaning supervised pairs
"""
def __init__(
self,
corpus_texts: list[str],
base_model: str = "all-MiniLM-L6-v2",
seed: int = 42,
):
self.base_model_name = base_model
self.model = SentenceTransformer(base_model)
self.rng = random.Random(seed)
np.random.seed(seed)
self.sentences = self._extract_sentences(corpus_texts)
self.rng.shuffle(self.sentences)
self._corpus_texts = corpus_texts
logger.info(f"Corpus: {len(self.sentences)} sentences from {len(corpus_texts)} documents")
# ------------------------------------------------------------------ #
# Strategy 1: Unsupervised (soft-label domain adaptation)
# ------------------------------------------------------------------ #
def train_unsupervised(
self,
output_path: str = "./trained_model",
epochs: int = 3,
batch_size: int = 16,
) -> dict:
"""
Soft-label domain adaptation using CosineSimilarityLoss.
Samples random sentence pairs and uses the model's own similarity
scores as training labels — nudging the model toward the corpus
distribution without any manual labels.
"""
_validate_output_path(output_path)
t0 = time.time()
n = min(5000, len(self.sentences) * 2)
pairs = []
for _ in range(n):
a, b = self.rng.sample(self.sentences, 2)
vecs = self.model.encode([a, b], normalize_embeddings=True, convert_to_tensor=True)
score = float(util.pytorch_cos_sim(vecs[0], vecs[1]).item())
pairs.append(InputExample(texts=[a, b], label=score))
if not pairs:
raise RuntimeError("Not enough sentences to build training pairs.")
loader = DataLoader(pairs, shuffle=True, batch_size=batch_size)
train_loss = losses.CosineSimilarityLoss(self.model)
logger.info(f"Unsupervised training: {len(pairs)} pairs, {epochs} epochs")
self.model.fit(
train_objectives=[(loader, train_loss)],
epochs=epochs,
show_progress_bar=True,
)
self.model.save(output_path)
return self._report("unsupervised", output_path, len(pairs), epochs, time.time() - t0)
# ------------------------------------------------------------------ #
# Strategy 2: Contrastive (structural pairs + MNRL)
# ------------------------------------------------------------------ #
def train_contrastive(
self,
output_path: str = "./trained_model",
epochs: int = 5,
batch_size: int = 16,
) -> dict:
"""
Uses document structure: adjacent sentences become positive pairs.
MultipleNegativesRankingLoss provides in-batch negatives automatically.
Includes validation and saves the best checkpoint.
"""
_validate_output_path(output_path)
t0 = time.time()
positives = []
for text in self._corpus_texts:
sents = self._extract_sentences([text])
for i in range(len(sents) - 1):
positives.append(InputExample(texts=[sents[i], sents[i + 1]]))
if not positives:
raise RuntimeError("Not enough sentences to build training pairs.")
loader = DataLoader(positives, shuffle=True, batch_size=batch_size)
train_loss = losses.MultipleNegativesRankingLoss(self.model)
val_eval = self._build_evaluator()
logger.info(f"Contrastive training: {len(positives)} pairs, {epochs} epochs")
self.model.fit(
train_objectives=[(loader, train_loss)],
evaluator=val_eval,
epochs=epochs,
evaluation_steps=max(1, len(loader) // 2),
output_path=output_path,
save_best_model=True,
show_progress_bar=True,
)
return self._report("contrastive", output_path, len(positives), epochs, time.time() - t0)
# ------------------------------------------------------------------ #
# Strategy 3: Keyword-supervised
# ------------------------------------------------------------------ #
def train_with_keywords(
self,
keyword_meanings: dict[str, str],
output_path: str = "./trained_model",
epochs: int = 5,
batch_size: int = 16,
context_window: int = 2,
) -> dict:
"""
You provide keyword→meaning mappings (e.g. {"pizza": "school"}).
The trainer:
1. Finds every sentence containing each keyword
2. Builds positive pairs: keyword-in-context ↔ meaning-replaced version
3. Uses MNRL (in-batch negatives handle the rest)
"""
_validate_output_path(output_path)
t0 = time.time()
doc_sentences = [self._extract_sentences([t]) for t in self._corpus_texts]
positives = []
for keyword, meaning in keyword_meanings.items():
pattern = re.compile(r"\b" + re.escape(keyword) + r"\b", re.IGNORECASE)
for sents in doc_sentences:
for i, sent in enumerate(sents):
if not pattern.search(sent):
continue
start = max(0, i - context_window)
end = min(len(sents), i + context_window + 1)
context = " ".join(sents[start:end])
# Positive: context with keyword → same context with meaning substituted
replaced = pattern.sub(meaning, context)
positives.append(InputExample(texts=[context, replaced]))
# Positive: context with keyword → meaning description
positives.append(InputExample(texts=[context, f"This is about {meaning}."]))
if not positives:
raise RuntimeError(
f"No keyword occurrences found in corpus. "
f"Keywords searched: {list(keyword_meanings.keys())}"
)
self.rng.shuffle(positives)
loader = DataLoader(positives, shuffle=True, batch_size=batch_size)
train_loss = losses.MultipleNegativesRankingLoss(self.model)
val_eval = self._build_evaluator()
logger.info(f"Keyword training: {len(positives)} pairs, {epochs} epochs, "
f"keywords: {list(keyword_meanings.keys())}")
self.model.fit(
train_objectives=[(loader, train_loss)],
evaluator=val_eval,
epochs=epochs,
evaluation_steps=max(1, len(loader) // 2),
output_path=output_path,
save_best_model=True,
show_progress_bar=True,
)
return self._report("keyword_supervised", output_path, len(positives), epochs, time.time() - t0,
extra={"keywords": list(keyword_meanings.keys())})
# ------------------------------------------------------------------ #
# Compare base vs trained
# ------------------------------------------------------------------ #
def evaluate_model(
self,
test_pairs: list[tuple[str, str, float]],
trained_model_path: str,
) -> dict:
"""
Compare base model vs trained model on test pairs.
Args:
test_pairs: List of (text_a, text_b, expected_similarity).
trained_model_path: Path to the trained model.
Returns:
Dict with per-pair and summary comparison.
"""
_validate_output_path(trained_model_path)
base = SentenceTransformer(self.base_model_name)
trained = SentenceTransformer(trained_model_path)
results = []
for text_a, text_b, expected in test_pairs:
base_sim = self._compute_sim(base, text_a, text_b)
trained_sim = self._compute_sim(trained, text_a, text_b)
results.append({
"text_a": text_a[:100],
"text_b": text_b[:100],
"expected": expected,
"base_score": round(base_sim, 4),
"trained_score": round(trained_sim, 4),
"base_error": round(abs(base_sim - expected), 4),
"trained_error": round(abs(trained_sim - expected), 4),
})
base_errors = [r["base_error"] for r in results]
trained_errors = [r["trained_error"] for r in results]
avg_base = np.mean(base_errors)
avg_trained = np.mean(trained_errors)
return {
"pairs": results,
"summary": {
"avg_base_error": round(float(avg_base), 4),
"avg_trained_error": round(float(avg_trained), 4),
"error_reduction_pct": round(
((avg_base - avg_trained) / avg_base * 100) if avg_base > 0 else 0, 1
),
"improved": sum(1 for r in results if r["trained_error"] < r["base_error"]),
"degraded": sum(1 for r in results if r["trained_error"] > r["base_error"]),
"total": len(results),
},
}
# ------------------------------------------------------------------ #
# Internals
# ------------------------------------------------------------------ #
def _build_evaluator(self):
"""Build a validation evaluator from random sentence pairs."""
n = min(100, len(self.sentences) // 2)
if n < 10:
return None
s1, s2, scores = [], [], []
sampled = self.rng.sample(range(len(self.sentences)), min(n * 2, len(self.sentences)))
for i in range(0, len(sampled) - 1, 2):
a_idx, b_idx = sampled[i], sampled[i + 1]
s1.append(self.sentences[a_idx])
s2.append(self.sentences[b_idx])
vecs = self.model.encode([self.sentences[a_idx], self.sentences[b_idx]],
normalize_embeddings=True, convert_to_tensor=True)
scores.append(float(util.pytorch_cos_sim(vecs[0], vecs[1]).item()))
return evaluation.EmbeddingSimilarityEvaluator(s1, s2, scores, name="val", show_progress_bar=False)
@staticmethod
def _compute_sim(model: SentenceTransformer, a: str, b: str) -> float:
vecs = model.encode([a, b], normalize_embeddings=True, convert_to_tensor=True)
return float(util.pytorch_cos_sim(vecs[0], vecs[1]).item())
@staticmethod
def _extract_sentences(texts: list[str]) -> list[str]:
sentences = []
for text in texts:
parts = re.split(r"(?<=[.!?])\s+", text.strip())
for s in parts:
s = s.strip()
if len(s.split()) >= 5:
sentences.append(s)
return sentences
@staticmethod
def _report(strategy, path, pairs, epochs, elapsed, extra=None):
report = {
"strategy": strategy,
"model_path": path,
"training_pairs": pairs,
"epochs": epochs,
"seconds": round(elapsed, 2),
}
if extra:
report.update(extra)
logger.info(f"Training complete ({strategy}): {pairs} pairs, {elapsed:.1f}s -> {path}")
return report