diff --git a/scripts/__pycache__/evaluate.cpython-312.pyc b/scripts/__pycache__/evaluate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e37015780682a5620e73d4c2444b7be6a43446c6 Binary files /dev/null and b/scripts/__pycache__/evaluate.cpython-312.pyc differ diff --git a/src/__pycache__/__init__.cpython-312.pyc b/src/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b00b59efed738237b221b46f523925674c715ea0 Binary files /dev/null and b/src/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/__pycache__/__init__.cpython-314.pyc b/src/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3780f12d3d970ae043e50496351ee141016693f1 Binary files /dev/null and b/src/__pycache__/__init__.cpython-314.pyc differ diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/api/__pycache__/main.cpython-312.pyc b/src/api/__pycache__/main.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..719793ff64b6a2c77852c65c47e545e8fca54c4e Binary files /dev/null and b/src/api/__pycache__/main.cpython-312.pyc differ diff --git a/src/api/__pycache__/middleware.cpython-312.pyc b/src/api/__pycache__/middleware.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..211c1bd85e9416151a86ee2bb8f208f7cd844cdf Binary files /dev/null and b/src/api/__pycache__/middleware.cpython-312.pyc differ diff --git a/src/api/__pycache__/schemas.cpython-312.pyc b/src/api/__pycache__/schemas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a845334fddc2b28f3f15f05222e50d81bc351f1 Binary files /dev/null and b/src/api/__pycache__/schemas.cpython-312.pyc differ diff --git a/src/api/middleware.py b/src/api/middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..3d304b48e9d8fc3f9a7fbf4a84dbeb1660f79b2c --- /dev/null +++ b/src/api/middleware.py @@ -0,0 +1,67 @@ +""" +API middleware for request logging, rate limiting, and error handling. 
+""" + +from fastapi import Request +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from loguru import logger +import time +from collections import defaultdict, deque + + +class RequestLoggingMiddleware(BaseHTTPMiddleware): + """Logs all incoming requests with timing information.""" + + async def dispatch(self, request: Request, call_next): + start_time = time.time() + path = request.url.path + method = request.method + + logger.info(f"→ {method} {path}") + + try: + response = await call_next(request) + except Exception as e: + logger.error(f"✗ {method} {path} - Error: {e}") + raise + + elapsed = (time.time() - start_time) * 1000 # ms + logger.info(f"← {method} {path} - {response.status_code} ({elapsed:.1f}ms)") + + return response + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """Simple in-memory rate limiting.""" + + def __init__(self, app, max_requests_per_minute: int = 60): + super().__init__(app) + self.max_requests = max_requests_per_minute + self.window = 60 # seconds + # Track requests per client IP: {ip: deque([timestamp, ...])} + self.requests: dict = defaultdict(deque) + + async def dispatch(self, request: Request, call_next): + # Get client IP + client_ip = request.client.host if request.client else "unknown" + now = time.time() + + # Clean old entries + timestamps = self.requests[client_ip] + while timestamps and timestamps[0] < now - self.window: + timestamps.popleft() + + # Check rate limit + if len(timestamps) >= self.max_requests: + logger.warning(f"Rate limited: {client_ip} ({len(timestamps)} requests in {self.window}s)") + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded. Please wait before making more requests."}, + ) + + # Record this request + timestamps.append(now) + + response = await call_next(request) + return response diff --git a/src/api/schemas.py b/src/api/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..11ee53ce86f53bac062766cac4301db5f8a20493 --- /dev/null +++ b/src/api/schemas.py @@ -0,0 +1,21 @@ +""" +Pydantic schemas for API request/response validation. 
+""" + +from pydantic import BaseModel, Field +from typing import Optional, Dict + + +class CorrectionRequest(BaseModel): + text: str = Field(..., min_length=10, max_length=5000, description="Raw dyslectic text to correct.") + master_copy: Optional[str] = Field(None, description="Optional master copy to match style toward.") + style_alpha: float = Field(0.6, ge=0.0, le=1.0, description="Weight given to user's own style (0=full master, 1=full user).") + + +class CorrectionResponse(BaseModel): + original: str + corrected: str + style_similarity: float + awl_coverage: float + readability: Dict[str, float] + changes_summary: str diff --git a/src/evaluation/__init__.py b/src/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/evaluation/__pycache__/__init__.cpython-314.pyc b/src/evaluation/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2591f86bc8e3436177df5a59fce88f835753e0b9 Binary files /dev/null and b/src/evaluation/__pycache__/__init__.cpython-314.pyc differ diff --git a/src/evaluation/__pycache__/authorship_verifier.cpython-312.pyc b/src/evaluation/__pycache__/authorship_verifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9789c7ac6d22975339d17c8c17d67946c2b168c8 Binary files /dev/null and b/src/evaluation/__pycache__/authorship_verifier.cpython-312.pyc differ diff --git a/src/evaluation/__pycache__/errant_evaluator.cpython-312.pyc b/src/evaluation/__pycache__/errant_evaluator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..963a45589c359c0d0deb2a04a2e760bb4d205007 Binary files /dev/null and b/src/evaluation/__pycache__/errant_evaluator.cpython-312.pyc differ diff --git a/src/evaluation/__pycache__/gleu_scorer.cpython-312.pyc b/src/evaluation/__pycache__/gleu_scorer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a44227842560518f32376fe638a63ee73471dec9 Binary files /dev/null and b/src/evaluation/__pycache__/gleu_scorer.cpython-312.pyc differ diff --git a/src/evaluation/__pycache__/gleu_scorer.cpython-314.pyc b/src/evaluation/__pycache__/gleu_scorer.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..709c6c64e38447bebfb2c5f3b34c6fd0f892cd7e Binary files /dev/null and b/src/evaluation/__pycache__/gleu_scorer.cpython-314.pyc differ diff --git a/src/evaluation/__pycache__/style_metrics.cpython-312.pyc b/src/evaluation/__pycache__/style_metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cb7503604d52df6d9a85e19daa7253afd42e907 Binary files /dev/null and b/src/evaluation/__pycache__/style_metrics.cpython-312.pyc differ diff --git a/src/evaluation/__pycache__/style_metrics.cpython-314.pyc b/src/evaluation/__pycache__/style_metrics.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5638dd1563801f58cb4e24d651fd8253a244f5f1 Binary files /dev/null and b/src/evaluation/__pycache__/style_metrics.cpython-314.pyc differ diff --git a/src/evaluation/authorship_verifier.py b/src/evaluation/authorship_verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..c98ae7cb0b5c610c0fe31344c66cc27b1a420681 --- /dev/null +++ b/src/evaluation/authorship_verifier.py @@ -0,0 +1,50 @@ +""" +Authorship verification module. 
Scores whether the corrected output could plausibly have been written
by the same author as the input, using sentence-embedding similarity
as a lightweight proxy for authorship.
Target: > 0.80 same-author probability.
"""

from loguru import logger
import torch.nn.functional as F


class AuthorshipVerifier:
    """Verifies authorship consistency between input and output text."""

    # Default is a genuine sentence-transformers checkpoint; a bare encoder
    # such as "roberta-base" would load with an untrained pooling head.
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        try:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(model_name, device="cpu")
            logger.info(f"AuthorshipVerifier loaded with {model_name}")
        except Exception as e:
            logger.warning(f"Failed to load authorship model: {e}")
            self.model = None

    def verify(self, text_a: str, text_b: str) -> float:
        """Return probability that both texts were written by the same author.

        Uses sentence embedding similarity as a proxy for authorship.
        Higher cosine similarity suggests same author.
        """
        if self.model is None:
            return 0.5  # Neutral score if model unavailable

        if not text_a or not text_b:
            return 0.5

        try:
            embeddings = self.model.encode([text_a, text_b], convert_to_tensor=True)
            sim = F.cosine_similarity(
                embeddings[0].unsqueeze(0),
                embeddings[1].unsqueeze(0),
            )
            # Cosine similarity lies in [-1, 1]; shift and rescale to [0, 1]
            prob = (sim.item() + 1.0) / 2.0
            return prob
        except Exception as e:
            logger.warning(f"Authorship verification failed: {e}")
            return 0.5
diff --git a/src/evaluation/errant_evaluator.py b/src/evaluation/errant_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..56d81f484591079b280bd9a0f2fe9f95d8cfe648
--- /dev/null
+++ b/src/evaluation/errant_evaluator.py
@@ -0,0 +1,82 @@
+"""
+ERRANT-based grammatical error evaluation.
+Uses the ERRANT toolkit for standardised GEC evaluation with
+precision, recall, and F0.5 scores.
+"""
+
+from typing import List, Dict
+from loguru import logger
+
+
+class ERRANTEvaluator:
+    """Evaluates grammar correction quality using ERRANT annotations."""
+
+    def __init__(self):
+        try:
+            import errant
+            self.annotator = errant.load("en")
+            logger.info("ERRANT annotator loaded")
+        except Exception as e:
+            logger.warning(f"ERRANT failed to load: {e}. 
Evaluation will use fallback.")
+            self.annotator = None
+
+    def evaluate(
+        self,
+        sources: List[str],
+        predictions: List[str],
+        references: List[str],
+    ) -> Dict[str, float]:
+        """Compute ERRANT precision, recall, F0.5."""
+        if self.annotator is None:
+            logger.warning("ERRANT not available, returning zero scores")
+            return {"precision": 0.0, "recall": 0.0, "f0.5": 0.0}
+
+        tp = 0
+        fp = 0
+        fn = 0
+
+        for src, pred, ref in zip(sources, predictions, references):
+            try:
+                # Parse source and annotate edits
+                orig = self.annotator.parse(src)
+                cor_pred = self.annotator.parse(pred)
+                cor_ref = self.annotator.parse(ref)
+
+                # Get edit annotations
+                pred_edits = self.annotator.annotate(orig, cor_pred)
+                ref_edits = self.annotator.annotate(orig, cor_ref)
+
+                # Convert to comparable sets of (start, end, correction, type)
+                pred_set = set()
+                for e in pred_edits:
+                    pred_set.add((e.o_start, e.o_end, e.c_str, e.type))
+
+                ref_set = set()
+                for e in ref_edits:
+                    ref_set.add((e.o_start, e.o_end, e.c_str, e.type))
+
+                # Count TP, FP, FN
+                tp += len(pred_set & ref_set)
+                fp += len(pred_set - ref_set)
+                fn += len(ref_set - pred_set)
+
+            except Exception as e:
+                logger.debug(f"ERRANT annotation failed for a sample: {e}")
+                continue
+
+        # Compute precision, recall, F0.5
+        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+
+        # F0.5 weighs precision higher than recall (β=0.5)
+        beta = 0.5
+        if precision + recall > 0:
+            f_score = (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)
+        else:
+            f_score = 0.0
+
+        return {
+            "precision": precision,
+            "recall": recall,
+            "f0.5": f_score,
+        }
diff --git a/src/evaluation/gleu_scorer.py b/src/evaluation/gleu_scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed1427dab91aba20c194f8cea3e5260f10b58f30
--- /dev/null
+++ b/src/evaluation/gleu_scorer.py
@@ -0,0 +1,68 @@
+"""
+GLEU (Generalized Language Evaluation Understanding) score.
+Preferred over BLEU for grammatical error correction tasks.
+Also computes BERTScore for semantic similarity evaluation.
+"""
+
+import sacrebleu
+from bert_score import score as bert_score_fn
+from typing import List, Tuple
+from loguru import logger
+
+
+class GLEUScorer:
+    """Computes GLEU and BERTScore metrics for GEC evaluation."""
+
+    def compute_gleu(
+        self,
+        predictions: List[str],
+        references: List[str],
+    ) -> float:
+        """Corpus-level GLEU score (0-100).
+
+        GLEU adapts BLEU to GEC: it rewards n-grams that match the
+        reference and penalises n-grams carried over from the erroneous
+        source, so missed and spurious corrections both lower the score. 
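To make the F0.5 arithmetic in `evaluate` concrete, a worked example with hypothetical edit counts:

```python
# Suppose the system proposed 10 edits, 8 of which match the reference,
# and the reference contains 4 further edits the system missed.
tp, fp, fn = 8, 2, 4

precision = tp / (tp + fp)  # 0.80
recall = tp / (tp + fn)     # ~0.67
beta = 0.5                  # precision-weighted, as in evaluate()
f05 = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
print(round(f05, 3))        # 0.769
```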
+ """ + if not predictions or not references: + return 0.0 + + # sacrebleu expects references as a list of lists + refs = [references] + + # Use BLEU with smoothing as GLEU approximation + # sacrebleu doesn't have a native GLEU, so we use smoothed BLEU + bleu = sacrebleu.corpus_bleu( + predictions, + refs, + smooth_method="exp", + smooth_value=0.1, + ) + return bleu.score + + def compute_bert_score( + self, + predictions: List[str], + references: List[str], + lang: str = "en", + ) -> Tuple[float, float, float]: + """Returns (precision, recall, F1) as averages over the batch.""" + if not predictions or not references: + return (0.0, 0.0, 0.0) + + try: + P, R, F1 = bert_score_fn( + predictions, + references, + lang=lang, + verbose=False, + device="cpu", # CPU-optimised + ) + return ( + P.mean().item(), + R.mean().item(), + F1.mean().item(), + ) + except Exception as e: + logger.warning(f"BERTScore computation failed: {e}") + return (0.0, 0.0, 0.0) diff --git a/src/evaluation/style_metrics.py b/src/evaluation/style_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..da8d93ac66815bc24e33dbc3a593da060d9325e7 --- /dev/null +++ b/src/evaluation/style_metrics.py @@ -0,0 +1,81 @@ +""" +Measures style preservation between input and output. + +Key metrics: + - Style Vector Cosine Similarity (target: > 0.85) + - AWL Coverage Score (target: > 0.25) + - Authorship Verification Score (target: > 0.80) +""" + +import torch +import torch.nn.functional as F +from typing import List, Tuple +from ..style.fingerprinter import StyleFingerprinter +from ..vocabulary.awl_loader import AWLLoader +from loguru import logger +import numpy as np + + +class StyleEvaluator: + """Evaluates style preservation and academic vocabulary coverage.""" + + def __init__(self, fingerprinter: StyleFingerprinter, awl: AWLLoader): + self.fingerprinter = fingerprinter + self.awl = awl + + def style_similarity(self, text_a: str, text_b: str) -> float: + """Cosine similarity between style vectors. Target: > 0.85.""" + vec_a = self.fingerprinter.extract_vector(text_a) + vec_b = self.fingerprinter.extract_vector(text_b) + + if vec_a.dim() == 1: + vec_a = vec_a.unsqueeze(0) + if vec_b.dim() == 1: + vec_b = vec_b.unsqueeze(0) + + sim = F.cosine_similarity(vec_a, vec_b, dim=-1) + return sim.item() + + def awl_coverage(self, text: str) -> float: + """Fraction of content words in AWL. 
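Since `compute_gleu` falls back to smoothed BLEU via sacrebleu, its core call reduces to the snippet below (scores are on a 0-100 scale):

```python
import sacrebleu

predictions = ["She went to the store yesterday."]
references = [["She went to the store yesterday."]]  # one reference stream

# Exponential smoothing keeps short segments from zeroing out the score.
bleu = sacrebleu.corpus_bleu(predictions, references, smooth_method="exp")
print(bleu.score)  # 100.0 for an exact match
```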
Target: > 0.25.""" + if not text or not text.strip(): + return 0.0 + + words = text.lower().split() + # Filter to content words (longer than 3 chars, alphabetic) + content_words = [w for w in words if len(w) > 3 and w.isalpha()] + + if not content_words: + return 0.0 + + awl_count = sum(1 for w in content_words if self.awl.is_academic(w)) + return awl_count / len(content_words) + + def evaluate_batch( + self, + inputs: List[str], + outputs: List[str], + references: List[str], + ) -> dict: + """Compute style and AWL metrics for a batch.""" + style_sims = [] + awl_coverages = [] + ref_style_sims = [] + + for inp, out, ref in zip(inputs, outputs, references): + # Style similarity between input and output (preservation) + style_sims.append(self.style_similarity(inp, out)) + + # AWL coverage of output + awl_coverages.append(self.awl_coverage(out)) + + # Style similarity between output and reference + ref_style_sims.append(self.style_similarity(out, ref)) + + return { + "style_similarity_mean": float(np.mean(style_sims)), + "style_similarity_std": float(np.std(style_sims)), + "awl_coverage_mean": float(np.mean(awl_coverages)), + "awl_coverage_std": float(np.std(awl_coverages)), + "ref_style_similarity_mean": float(np.mean(ref_style_sims)), + } diff --git a/src/inference/__init__.py b/src/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/inference/__pycache__/__init__.cpython-314.pyc b/src/inference/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73e3358b4843f1035fd5c845d7b71cb11ba96616 Binary files /dev/null and b/src/inference/__pycache__/__init__.cpython-314.pyc differ diff --git a/src/inference/__pycache__/corrector.cpython-312.pyc b/src/inference/__pycache__/corrector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62be411332f9058aa5e5a2210d7be1665077c2f1 Binary files /dev/null and b/src/inference/__pycache__/corrector.cpython-312.pyc differ diff --git a/src/inference/__pycache__/corrector.cpython-314.pyc b/src/inference/__pycache__/corrector.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c135b94fe4e85803fa057f2e1f11d12c7933998a Binary files /dev/null and b/src/inference/__pycache__/corrector.cpython-314.pyc differ diff --git a/src/inference/__pycache__/postprocessor.cpython-312.pyc b/src/inference/__pycache__/postprocessor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eda00d978c8fa2909027a82998f124e7468240a4 Binary files /dev/null and b/src/inference/__pycache__/postprocessor.cpython-312.pyc differ diff --git a/src/inference/__pycache__/postprocessor.cpython-314.pyc b/src/inference/__pycache__/postprocessor.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1009cd0d4515368596224ef621bdc749601e36f Binary files /dev/null and b/src/inference/__pycache__/postprocessor.cpython-314.pyc differ diff --git a/src/inference/corrector.py b/src/inference/corrector.py new file mode 100644 index 0000000000000000000000000000000000000000..aa277c53dc0cccf63e0e5b96efd5a91683f60f70 --- /dev/null +++ b/src/inference/corrector.py @@ -0,0 +1,283 @@ +""" +End-to-end inference pipeline. +Accepts raw dyslectic text (and optionally a master copy), +returns corrected academic text with metadata. 
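Looking back at `awl_coverage`: a hand-worked sketch of the content-word filter. The AWL membership shown is illustrative, since the loaded word list decides it:

```python
words = "we analyse the data to establish whether the method is valid".split()
content = [w for w in words if len(w) > 3 and w.isalpha()]
print(content)
# ['analyse', 'data', 'establish', 'whether', 'method', 'valid']
# If the AWL lookup accepts 4 of these 6, coverage = 4 / 6 ≈ 0.67,
# comfortably above the 0.25 target.
```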
+""" + +from ..preprocessing.pipeline import PreprocessingPipeline +from ..style.fingerprinter import StyleFingerprinter +from ..vocabulary.lexical_substitution import LexicalElevator, RegisterFilter +from ..model.base_model import load_model_and_tokenizer +from ..model.style_conditioner import StyleConditioner, prepend_style_prefix +from ..model.generation_utils import generate_correction +from .postprocessor import PostProcessor +from ..evaluation.style_metrics import StyleEvaluator +from ..vocabulary.awl_loader import AWLLoader +import torch +from typing import Optional +from dataclasses import dataclass +from loguru import logger +import yaml + + +TASK_PREFIX = ( + "Correct the following text for grammar, spelling, and clarity. " + "Maintain the author's original tone and writing style. " + "Elevate vocabulary to academic register. " + "Do NOT change the meaning or add new information. " + "Preserve named entities exactly. " + "Text to correct: " +) + + +@dataclass +class CorrectionResult: + original: str + corrected: str + preprocessed: str + style_similarity: float + awl_coverage: float + readability: dict + changes_summary: str + + +class AcademicCorrector: + """Full inference pipeline: preprocess → fingerprint → generate → elevate → filter.""" + + def __init__(self, config: dict): + logger.info("Initialising AcademicCorrector...") + + model_cfg = config.get("model", {}) + gen_cfg = config.get("generation", {}) + vocab_cfg = config.get("vocabulary", {}) + style_cfg = config.get("style_conditioner", {}) + + # 1. Load model and tokenizer + model_key = model_cfg.get("key", "flan-t5-small") + checkpoint = model_cfg.get("checkpoint_path", None) + use_lora = model_cfg.get("use_lora", False) + + if checkpoint and use_lora: + # PEFT adapter checkpoint: load base model + apply adapter + import os + try: + from peft import PeftModel + logger.info(f"Loading base model '{model_key}' + PEFT adapter from '{checkpoint}'") + self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer( + model_key, quantize=False, use_lora=False + ) + self.model = PeftModel.from_pretrained(self.model, checkpoint) + logger.info(f"PEFT adapter loaded from {checkpoint}") + except Exception as e: + logger.warning(f"PEFT loading failed ({e}), loading base model only") + self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer( + model_key, quantize=False, use_lora=False + ) + elif checkpoint: + # Full model checkpoint (merged weights) + try: + from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint) + self.tokenizer = AutoTokenizer.from_pretrained(checkpoint) + self.is_seq2seq = True + logger.info(f"Loaded full model from checkpoint: {checkpoint}") + except Exception: + logger.warning(f"Checkpoint not found, loading base model: {model_key}") + self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer( + model_key, quantize=False, use_lora=False + ) + else: + self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer( + model_key, quantize=False, use_lora=False + ) + + self.model.eval() + self.generation_config = gen_cfg + + # 2. Preprocessor + self.preprocessor = PreprocessingPipeline() + + # 3. Style fingerprinter + fp_cfg = config.get("fingerprinter", {}) + self.fingerprinter = StyleFingerprinter( + spacy_model=fp_cfg.get("spacy_model", "en_core_web_sm"), + awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"), + ) + + # 4. 
Style conditioner — auto-detect hidden dim from loaded model + if hasattr(self.model.config, "d_model"): + auto_hidden_dim = self.model.config.d_model + elif hasattr(self.model.config, "hidden_size"): + auto_hidden_dim = self.model.config.hidden_size + else: + auto_hidden_dim = 512 # Safe default for T5-Small + logger.info(f"Auto-detected model hidden dim: {auto_hidden_dim}") + + self.conditioner = StyleConditioner( + style_dim=style_cfg.get("style_dim", 512), + model_hidden_dim=style_cfg.get("model_hidden_dim", auto_hidden_dim), + n_prefix_tokens=style_cfg.get("n_prefix_tokens", 10), + ) + self.conditioner.eval() + + # 5. Vocabulary elevator + try: + self.elevator = LexicalElevator( + awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"), + spacy_model="en_core_web_sm", + mlm_model=vocab_cfg.get("mlm_model", "bert-large-uncased"), + sem_model=vocab_cfg.get("sem_model", "all-mpnet-base-v2"), + ) + except Exception as e: + logger.warning(f"Lexical elevator init failed: {e}, elevation disabled") + self.elevator = None + + # 6. Register filter + self.register_filter = RegisterFilter() + + # 7. Post-processor + self.postprocessor = PostProcessor() + + # 8. Evaluator + awl = AWLLoader(primary_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt")) + self.evaluator = StyleEvaluator(self.fingerprinter, awl) + + logger.info("AcademicCorrector initialised successfully") + + def correct( + self, + raw_text: str, + master_copy: Optional[str] = None, + style_alpha: float = 0.6, + ) -> CorrectionResult: + """ + Full correction pipeline: + 1. Pre-process (spell correct + parse) + 2. Style fingerprint + 3. Generate with style conditioning + 4. Academic vocabulary elevation + 5. Register filter + 6. Compute quality metrics + """ + # Step 1: Pre-process + logger.info("Step 1: Preprocessing...") + doc = self.preprocessor.process(raw_text) + + # Step 2: Style fingerprint + logger.info("Step 2: Extracting style fingerprint...") + user_style = self.fingerprinter.extract_vector(doc.corrected_text) + + if master_copy: + master_style = self.fingerprinter.extract_vector(master_copy) + target_style = self.fingerprinter.blend_vectors(user_style, master_style, alpha=style_alpha) + else: + target_style = user_style + + # Step 3: Generate correction (sentence-chunked) + # The model was trained on max_input_length=128 tokens. + # Split text into sentence groups that fit within that window, + # process each chunk, then reassemble. 
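Before moving on: the `blend_vectors` call in Step 2 comes from the fingerprinter, which is not part of this diff. The sketch below shows the convex combination it is assumed to perform, matching the schema note that `style_alpha` weights the user's own style:

```python
import torch

def blend_vectors(user_vec: torch.Tensor, master_vec: torch.Tensor,
                  alpha: float = 0.6) -> torch.Tensor:
    """Assumed behaviour: alpha=1.0 keeps the user's style, 0.0 the master's."""
    return alpha * user_vec + (1.0 - alpha) * master_vec
```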
+ logger.info("Step 3: Generating correction (chunked)...") + + MAX_INPUT_TOKENS = 128 + # Measure how many tokens the task prefix uses + prefix_tokens = len(self.tokenizer.encode(TASK_PREFIX, add_special_tokens=False)) + budget = MAX_INPUT_TOKENS - prefix_tokens - 2 # 2 for special tokens + + # Split into sentences using spaCy (already loaded for fingerprinting) + sent_doc = self.fingerprinter.nlp(doc.corrected_text) + sentences = [sent.text.strip() for sent in sent_doc.sents if sent.text.strip()] + + # Group sentences into chunks that fit the token budget + chunks = [] + current_chunk = [] + current_tokens = 0 + + for sent in sentences: + sent_tokens = len(self.tokenizer.encode(sent, add_special_tokens=False)) + if current_tokens + sent_tokens > budget and current_chunk: + chunks.append(" ".join(current_chunk)) + current_chunk = [sent] + current_tokens = sent_tokens + else: + current_chunk.append(sent) + current_tokens += sent_tokens + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + logger.info(f" Split into {len(chunks)} chunks from {len(sentences)} sentences") + + corrected_chunks = [] + device = next(self.model.parameters()).device + + for i, chunk in enumerate(chunks): + chunk_input = TASK_PREFIX + chunk + inputs = self.tokenizer( + chunk_input, + max_length=MAX_INPUT_TOKENS, + truncation=True, + return_tensors="pt", + ) + + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + chunk_output = generate_correction( + self.model, + self.tokenizer, + input_ids, + attention_mask, + self.generation_config, + ) + corrected_chunks.append(chunk_output) + logger.debug(f" Chunk {i+1}/{len(chunks)}: {len(chunk.split())} → {len(chunk_output.split())} words") + + generated = " ".join(corrected_chunks) + + # Step 4: Post-process + logger.info("Step 4: Post-processing...") + generated = self.postprocessor.clean(generated) + generated = self.postprocessor.restore_entities( + generated, + [e.text for e in doc.entities], + doc.protected_spans, + ) + + # Step 5: Vocabulary elevation + logger.info("Step 5: Vocabulary elevation...") + if self.elevator: + try: + generated = self.elevator.elevate(generated, doc.protected_spans) + except Exception as e: + logger.warning(f"Vocabulary elevation failed: {e}") + + # Step 6: Register filter + logger.info("Step 6: Register filtering...") + generated = self.register_filter.apply(generated) + + # Final formatting + generated = self.postprocessor.format_output(generated) + + # Step 7: Compute quality metrics + logger.info("Step 7: Computing metrics...") + style_sim = self.evaluator.style_similarity(raw_text, generated) + awl_cov = self.evaluator.awl_coverage(generated) + + # Build changes summary + changes = [] + if doc.original_text != doc.corrected_text: + changes.append("Spelling/grammar corrections applied") + if generated != doc.corrected_text: + changes.append("Text restructured and elevated") + changes_summary = "; ".join(changes) if changes else "No changes needed" + + return CorrectionResult( + original=raw_text, + corrected=generated, + preprocessed=doc.corrected_text, + style_similarity=style_sim, + awl_coverage=awl_cov, + readability=doc.readability, + changes_summary=changes_summary, + ) diff --git a/src/inference/postprocessor.py b/src/inference/postprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..fac475a9e13ff97384ad0e0ff0c667a203510e10 --- /dev/null +++ b/src/inference/postprocessor.py @@ -0,0 +1,118 @@ +""" +Post-processing utilities for generated text. 
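The greedy sentence-packing in Step 3 can be isolated into a few lines; this standalone version mirrors the chunking loop with a toy whitespace "tokeniser":

```python
def pack_sentences(sentences, token_len, budget):
    chunks, current, used = [], [], 0
    for sent in sentences:
        n = token_len(sent)
        if used + n > budget and current:
            chunks.append(" ".join(current))
            current, used = [sent], n
        else:
            current.append(sent)
            used += n
    if current:
        chunks.append(" ".join(current))
    return chunks

sents = ["One two three four.", "Five six seven.", "Eight nine."]
print(pack_sentences(sents, lambda s: len(s.split()), budget=8))
# ['One two three four. Five six seven.', 'Eight nine.']
```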
+Handles cleanup, formatting, and final quality checks.
+"""
+
+import re
+from typing import List, Tuple
+from loguru import logger
+
+
+class PostProcessor:
+    """Cleans and formats generated text after model output."""
+
+    # Common generation artifacts to remove
+    ARTIFACTS = [
+        r'<pad>',   # T5/BART special tokens
+        r'<s>',
+        r'</s>',
+        r'<unk>',
+        r'\[PAD\]',
+        r'\[CLS\]',
+        r'\[SEP\]',
+        '<|endoftext|>',  # plain string so the join below escapes the pipes
+    ]
+
+    def __init__(self):
+        # Compile artifact removal regex
+        self._artifact_pattern = re.compile(
+            '|'.join(re.escape(a) if not a.startswith('\\') else a for a in self.ARTIFACTS),
+            re.IGNORECASE
+        )
+
+    def clean(self, text: str) -> str:
+        """Remove generation artifacts and normalise whitespace."""
+        if not text:
+            return ""
+
+        # Remove generation artifacts
+        result = self._artifact_pattern.sub('', text)
+
+        # Replace em dashes and en dashes with commas
+        result = result.replace('—', ',')
+        result = result.replace('–', ',')
+
+        # Normalise whitespace
+        result = re.sub(r'\s+', ' ', result)
+        result = result.strip()
+
+        # Fix common post-generation spacing issues
+        result = re.sub(r'\s+([.,!?;:])', r'\1', result)  # Remove space before punctuation
+        result = re.sub(r'([.,!?;:])([A-Za-z])', r'\1 \2', result)  # Add space after punctuation
+        result = re.sub(r'\(\s+', '(', result)  # Remove space after opening paren
+        result = re.sub(r'\s+\)', ')', result)  # Remove space before closing paren
+
+        # Fix multiple punctuation
+        result = re.sub(r'\.{2,}', '.', result)
+        result = re.sub(r'\?{2,}', '?', result)
+        result = re.sub(r'!{2,}', '!', result)
+
+        return result
+
+    def restore_entities(
+        self,
+        text: str,
+        original_entities: List[str],
+        protected_spans: List[Tuple[int, int]],
+    ) -> str:
+        """Restore named entities that may have been altered during generation.
+
+        Uses exact, then case-insensitive, matching to locate each entity in
+        the generated text and restores the original surface form.
+        """
+        if not original_entities:
+            return text
+
+        result = text
+        for entity in original_entities:
+            # Check if entity is already present in correct form
+            if entity in result:
+                continue
+
+            # Try case-insensitive match
+            pattern = re.compile(re.escape(entity), re.IGNORECASE)
+            if pattern.search(result):
+                result = pattern.sub(entity, result, count=1)
+                logger.debug(f"Restored entity: {entity}")
+
+        return result
+
+    def format_output(self, text: str) -> str:
+        """Apply final formatting (capitalisation, punctuation, spacing)."""
+        if not text:
+            return ""
+
+        result = text.strip()
+
+        # Ensure first letter is capitalised
+        if result and result[0].islower():
+            result = result[0].upper() + result[1:]
+
+        # Ensure text ends with punctuation
+        if result and result[-1] not in '.!?':
+            result += '.'
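A quick check of the cleanup rules above (the output was traced by hand against the regexes):

```python
pp = PostProcessor()
print(pp.clean("The results <pad> show , that growth improved !!"))
# "The results show, that growth improved!"
```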
+ + # Capitalise after sentence-ending punctuation + result = re.sub( + r'([.!?]\s+)([a-z])', + lambda m: m.group(1) + m.group(2).upper(), + result + ) + + # Fix "i" → "I" when standalone + result = re.sub(r'\bi\b', 'I', result) + + # Remove trailing whitespace from lines + result = '\n'.join(line.rstrip() for line in result.split('\n')) + + return result diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/model/__pycache__/__init__.cpython-312.pyc b/src/model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7d4d3e48e4d3b1a34d8b01aaee58b89ce921027 Binary files /dev/null and b/src/model/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/model/__pycache__/__init__.cpython-314.pyc b/src/model/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee3b025bcc53e9eb85c678490cfcbea8886d75a0 Binary files /dev/null and b/src/model/__pycache__/__init__.cpython-314.pyc differ diff --git a/src/model/__pycache__/base_model.cpython-312.pyc b/src/model/__pycache__/base_model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa822a2839a13143a2f2461ebef23e6c6132f557 Binary files /dev/null and b/src/model/__pycache__/base_model.cpython-312.pyc differ diff --git a/src/model/__pycache__/base_model.cpython-314.pyc b/src/model/__pycache__/base_model.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..693923c38004e383c3e1bf607c06aa00acb91e56 Binary files /dev/null and b/src/model/__pycache__/base_model.cpython-314.pyc differ diff --git a/src/model/__pycache__/generation_utils.cpython-312.pyc b/src/model/__pycache__/generation_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04d29bb104e002c6192067b1d85a0bbb4e382ef1 Binary files /dev/null and b/src/model/__pycache__/generation_utils.cpython-312.pyc differ diff --git a/src/model/__pycache__/generation_utils.cpython-314.pyc b/src/model/__pycache__/generation_utils.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5854db2e2afea36ab4825bc968c0f997d30f20b1 Binary files /dev/null and b/src/model/__pycache__/generation_utils.cpython-314.pyc differ diff --git a/src/model/__pycache__/lora_adapter.cpython-312.pyc b/src/model/__pycache__/lora_adapter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eceda5cea65401d40d8d2b34b7ece3dc017970f Binary files /dev/null and b/src/model/__pycache__/lora_adapter.cpython-312.pyc differ diff --git a/src/model/__pycache__/style_conditioner.cpython-312.pyc b/src/model/__pycache__/style_conditioner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e872017c9168654f6e6c620f9d9548c23ad4ea30 Binary files /dev/null and b/src/model/__pycache__/style_conditioner.cpython-312.pyc differ diff --git a/src/model/__pycache__/style_conditioner.cpython-314.pyc b/src/model/__pycache__/style_conditioner.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7338117b6f406bec2d1bf068f7e8056bb87dfdb Binary files /dev/null and b/src/model/__pycache__/style_conditioner.cpython-314.pyc differ diff --git a/src/model/base_model.py b/src/model/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..74e0a0b1707f2cd1a4827d4b82b1e8d8bae64771 --- 
/dev/null +++ b/src/model/base_model.py @@ -0,0 +1,135 @@ +""" +Loads and wraps the base pretrained model. +Supported architectures: + - google/flan-t5-xl (recommended, 3B) + - google/flan-t5-large (780M, resource-constrained) + - facebook/bart-large (400M, excellent denoiser) + - meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality) +""" + +from transformers import ( + AutoTokenizer, AutoModelForSeq2SeqLM, + AutoModelForCausalLM, BitsAndBytesConfig +) +from peft import get_peft_model, LoraConfig, TaskType +import torch +from loguru import logger + + +ENCODER_DECODER_MODELS = { + "flan-t5-xl": "google/flan-t5-xl", + "flan-t5-large": "google/flan-t5-large", + "flan-t5-base": "google/flan-t5-base", + "flan-t5-small": "google/flan-t5-small", + "bart-large": "facebook/bart-large", +} + +DECODER_ONLY_MODELS = { + "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct", +} + + +def load_model_and_tokenizer(model_key: str, quantize: bool = False, use_lora: bool = True, + lora_config_dict: dict = None): + """ + Load a pretrained model with optional LoRA and quantization. + + Args: + model_key: Key from ENCODER_DECODER_MODELS or DECODER_ONLY_MODELS + quantize: Whether to use 4-bit quantization + use_lora: Whether to apply LoRA adapters + lora_config_dict: Optional dict with LoRA hyperparams (r, lora_alpha, etc.) + + Returns: + Tuple of (model, tokenizer, is_seq2seq) + """ + # Determine model type and HuggingFace identifier + is_seq2seq = model_key in ENCODER_DECODER_MODELS + is_causal = model_key in DECODER_ONLY_MODELS + + if not is_seq2seq and not is_causal: + raise ValueError( + f"Unknown model key: '{model_key}'. " + f"Available: {list(ENCODER_DECODER_MODELS.keys()) + list(DECODER_ONLY_MODELS.keys())}" + ) + + model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key) + logger.info(f"Loading model: {model_name} (seq2seq={is_seq2seq}, quantize={quantize}, lora={use_lora})") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Configure quantization if requested + model_kwargs = { + "torch_dtype": torch.float32, # CPU-optimised: use float32 for stability + } + + # Detect device + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + # Use bfloat16 if Ampere+, else float16 + if torch.cuda.get_device_capability()[0] >= 8: + model_kwargs["torch_dtype"] = torch.bfloat16 + else: + model_kwargs["torch_dtype"] = torch.float16 + + if quantize and device == "cuda": + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=model_kwargs["torch_dtype"], + bnb_4bit_use_double_quant=True, + ) + model_kwargs["quantization_config"] = bnb_config + logger.info("Using 4-bit NF4 quantization") + elif quantize and device == "cpu": + logger.warning("Quantization requested but no GPU available, skipping") + + # Load model + if is_seq2seq: + model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs) + else: + model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + + # Move to device if not quantized (quantized models are already on device) + if not quantize or device == "cpu": + model = model.to(device) + + logger.info(f"Model loaded on {device} with dtype {model_kwargs.get('torch_dtype')}") + + # Apply LoRA if requested + if use_lora: + lora_cfg = lora_config_dict or {} + task_type = TaskType.SEQ_2_SEQ_LM if is_seq2seq else TaskType.CAUSAL_LM + + # Default target modules based on 
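A typical call into `load_model_and_tokenizer` above; the hyperparameter values are illustrative:

```python
model, tokenizer, is_seq2seq = load_model_and_tokenizer(
    "flan-t5-small",
    quantize=False,  # skipped on CPU anyway, per the device check above
    use_lora=True,
    lora_config_dict={"r": 8, "lora_alpha": 16},
)
assert is_seq2seq  # every flan-t5 variant is an encoder-decoder model
```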
model architecture + default_targets = { + "flan-t5-xl": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"], + "flan-t5-large": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"], + "flan-t5-base": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"], + "flan-t5-small": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"], + "bart-large": ["q_proj", "v_proj", "k_proj", "out_proj"], + "llama-3.1-8b": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + } + + lora_config = LoraConfig( + task_type=task_type, + r=lora_cfg.get("r", 16), + lora_alpha=lora_cfg.get("lora_alpha", 32), + lora_dropout=lora_cfg.get("lora_dropout", 0.05), + target_modules=lora_cfg.get("target_modules", default_targets.get(model_key, ["q", "v"])), + bias="none", + ) + + model = get_peft_model(model, lora_config) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in model.parameters()) + logger.info( + f"LoRA applied: {trainable_params:,} trainable params / {total_params:,} total " + f"({100 * trainable_params / total_params:.2f}%)" + ) + + return model, tokenizer, is_seq2seq diff --git a/src/model/generation_utils.py b/src/model/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..724c0a2df2292d76addfa9f57276e167590780e0 --- /dev/null +++ b/src/model/generation_utils.py @@ -0,0 +1,106 @@ +""" +Generation utilities for text correction. +Handles beam search, constrained decoding, and post-generation cleanup. +""" + +import torch +from transformers import PreTrainedModel, PreTrainedTokenizer +from typing import Dict, Optional, List +from loguru import logger + + +def generate_correction( + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + generation_config: Dict, +) -> str: + """Generate corrected text from input tokens.""" + # Build generation kwargs from config + gen_kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "max_new_tokens": generation_config.get("max_new_tokens", 512), + "num_beams": generation_config.get("num_beams", 5), + "length_penalty": generation_config.get("length_penalty", 1.0), + "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3), + "min_length": generation_config.get("min_length", 10), + "early_stopping": generation_config.get("early_stopping", True), + } + + # Optional sampling parameters + if generation_config.get("do_sample", False): + gen_kwargs["do_sample"] = True + gen_kwargs["temperature"] = generation_config.get("temperature", 0.7) + gen_kwargs["top_p"] = generation_config.get("top_p", 0.9) + else: + gen_kwargs["do_sample"] = False + + with torch.no_grad(): + output_ids = model.generate(**gen_kwargs) + + # Decode, skipping special tokens + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + return generated_text.strip() + + +def batch_generate( + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + texts: List[str], + generation_config: Dict, + max_length: int = 512, +) -> List[str]: + """Generate corrections for a batch of texts.""" + if not texts: + return [] + + results = [] + # Process in mini-batches to manage memory on CPU + batch_size = generation_config.get("batch_size", 4) + + for i in range(0, len(texts), batch_size): + batch_texts = texts[i:i + batch_size] + + # Tokenise batch + inputs = tokenizer( + batch_texts, + max_length=max_length, + padding=True, + truncation=True, + return_tensors="pt", + ) + + # Move to model device + device = 
next(model.parameters()).device + inputs = {k: v.to(device) for k, v in inputs.items()} + + # Generate + gen_kwargs = { + "max_new_tokens": generation_config.get("max_new_tokens", 512), + "num_beams": generation_config.get("num_beams", 5), + "length_penalty": generation_config.get("length_penalty", 1.0), + "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3), + "early_stopping": generation_config.get("early_stopping", True), + } + + if generation_config.get("do_sample", False): + gen_kwargs["do_sample"] = True + gen_kwargs["temperature"] = generation_config.get("temperature", 0.7) + + with torch.no_grad(): + output_ids = model.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + **gen_kwargs, + ) + + # Decode each output + for output in output_ids: + text = tokenizer.decode(output, skip_special_tokens=True) + results.append(text.strip()) + + logger.debug(f"Generated batch {i // batch_size + 1}: {len(batch_texts)} texts") + + return results diff --git a/src/model/lora_adapter.py b/src/model/lora_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..8d2b69016d55a47e6aec84931e05b66f6795417d --- /dev/null +++ b/src/model/lora_adapter.py @@ -0,0 +1,54 @@ +""" +LoRA adapter configuration and management. +Wraps PEFT LoRA utilities for applying parameter-efficient +fine-tuning to the base model. +""" + +from peft import LoraConfig, TaskType, get_peft_model +from typing import List, Optional +from loguru import logger + + +def create_lora_config( + task_type: TaskType, + r: int = 16, + lora_alpha: int = 32, + target_modules: Optional[List[str]] = None, + lora_dropout: float = 0.05, +) -> LoraConfig: + """Create a LoRA configuration for the given task type.""" + if target_modules is None: + target_modules = ["q", "v"] + + config = LoraConfig( + task_type=task_type, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + target_modules=target_modules, + bias="none", + inference_mode=False, + ) + logger.info(f"Created LoRA config: r={r}, alpha={lora_alpha}, dropout={lora_dropout}") + return config + + +def apply_lora(model, lora_config: LoraConfig): + """Apply LoRA adapters to a model and return the wrapped model.""" + peft_model = get_peft_model(model, lora_config) + trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad) + total = sum(p.numel() for p in peft_model.parameters()) + logger.info(f"LoRA applied: {trainable:,}/{total:,} trainable params ({100*trainable/total:.2f}%)") + return peft_model + + +def merge_lora_weights(model): + """Merge LoRA weights into the base model for inference. + + After merging, the model behaves like a regular model with + LoRA modifications baked in, removing the adapter overhead. + """ + logger.info("Merging LoRA weights into base model...") + merged = model.merge_and_unload() + logger.info("LoRA weights merged successfully") + return merged diff --git a/src/model/style_conditioner.py b/src/model/style_conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..6b176be45ee8095bfd0053b94fa9222a4a8eb529 --- /dev/null +++ b/src/model/style_conditioner.py @@ -0,0 +1,74 @@ +""" +Injects the style vector into the model via soft prompt conditioning. +The style vector is projected to the model's hidden dimension and +prepended to the input token embeddings as virtual tokens. + +This technique is called "prefix tuning" / "style prefix injection". 
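Putting the three helpers in `lora_adapter.py` together, a sketch of the intended adapter lifecycle; the base model and training loop are assumed:

```python
from peft import TaskType

lora_cfg = create_lora_config(TaskType.SEQ_2_SEQ_LM, r=16, target_modules=["q", "v"])
peft_model = apply_lora(base_model, lora_cfg)

# ... fine-tune peft_model here ...

# Bake the adapters in before serving so inference pays no adapter overhead.
inference_model = merge_lora_weights(peft_model)
```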
+It biases the model's attention toward the desired output style +without modifying the base model weights. + +For Flan-T5: injects into encoder input embeddings +For BART: injects into encoder input embeddings +For Llama: prepends to the full input context +""" + +import torch +import torch.nn as nn + + +class StyleConditioner(nn.Module): + """ + Projects a 512-dim style vector to n_prefix_tokens virtual tokens + in the model's embedding space. + """ + + def __init__( + self, + style_dim: int = 512, + model_hidden_dim: int = 512, # T5-Small=512, Base=768, Large=1024, XL=2048 + n_prefix_tokens: int = 10, # Number of virtual prefix tokens + ): + super().__init__() + self.style_dim = style_dim + self.model_hidden_dim = model_hidden_dim + self.n_prefix_tokens = n_prefix_tokens + + # Project style vector to prefix embeddings + # style_dim → n_prefix_tokens * model_hidden_dim + total_output_dim = n_prefix_tokens * model_hidden_dim + self.projection = nn.Sequential( + nn.Linear(style_dim, total_output_dim), + nn.Tanh(), + ) + + def forward(self, style_vector: torch.Tensor) -> torch.Tensor: + """ + Args: + style_vector: [batch_size, 512] + Returns: + prefix_embeddings: [batch_size, n_prefix_tokens, model_hidden_dim] + """ + # Project: [batch, 512] → [batch, n_prefix * hidden_dim] + projected = self.projection(style_vector) + + # Reshape: [batch, n_prefix * hidden_dim] → [batch, n_prefix, hidden_dim] + batch_size = style_vector.size(0) + prefix_embeddings = projected.view(batch_size, self.n_prefix_tokens, self.model_hidden_dim) + + return prefix_embeddings + + +def prepend_style_prefix( + input_embeddings: torch.Tensor, + style_prefix: torch.Tensor, +) -> torch.Tensor: + """ + Concatenates style prefix to input embeddings along sequence dimension. + + Args: + input_embeddings: [batch, seq_len, hidden_dim] + style_prefix: [batch, n_prefix, hidden_dim] + Returns: + [batch, n_prefix + seq_len, hidden_dim] + """ + return torch.cat([style_prefix, input_embeddings], dim=1) diff --git a/src/preprocessing/__init__.py b/src/preprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/preprocessing/__pycache__/__init__.cpython-312.pyc b/src/preprocessing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ab54542e1cf18a2887279f74f2eaf352a120125 Binary files /dev/null and b/src/preprocessing/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/preprocessing/__pycache__/__init__.cpython-314.pyc b/src/preprocessing/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..216a0387c098d74e34d7c87650447d52a4fd9b1f Binary files /dev/null and b/src/preprocessing/__pycache__/__init__.cpython-314.pyc differ diff --git a/src/preprocessing/__pycache__/dependency_parser.cpython-312.pyc b/src/preprocessing/__pycache__/dependency_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c435e85eb715a5a3470feb5450550d9f084d98 Binary files /dev/null and b/src/preprocessing/__pycache__/dependency_parser.cpython-312.pyc differ diff --git a/src/preprocessing/__pycache__/dyslexia_simulator.cpython-312.pyc b/src/preprocessing/__pycache__/dyslexia_simulator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..596f38f92d05aaa2dae0fa1e4e8ca9a030acc462 Binary files /dev/null and b/src/preprocessing/__pycache__/dyslexia_simulator.cpython-312.pyc differ diff --git 
a/src/preprocessing/__pycache__/dyslexia_simulator.cpython-314.pyc b/src/preprocessing/__pycache__/dyslexia_simulator.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a4c1c22252674f3262dfe9694b2abad34b60be7 Binary files /dev/null and b/src/preprocessing/__pycache__/dyslexia_simulator.cpython-314.pyc differ diff --git a/src/preprocessing/__pycache__/ner_tagger.cpython-312.pyc b/src/preprocessing/__pycache__/ner_tagger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85ea200e90384d39bcf3f4c38e6727668f17bfcb Binary files /dev/null and b/src/preprocessing/__pycache__/ner_tagger.cpython-312.pyc differ diff --git a/src/preprocessing/__pycache__/pipeline.cpython-312.pyc b/src/preprocessing/__pycache__/pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36c7e6a5131bb6f373bd4c3447c18a9daf49eb34 Binary files /dev/null and b/src/preprocessing/__pycache__/pipeline.cpython-312.pyc differ diff --git a/src/preprocessing/__pycache__/pipeline.cpython-314.pyc b/src/preprocessing/__pycache__/pipeline.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..907e1d90e14df0bc82a5c83cc178ca7b7ace70df Binary files /dev/null and b/src/preprocessing/__pycache__/pipeline.cpython-314.pyc differ diff --git a/src/preprocessing/__pycache__/sentence_segmenter.cpython-312.pyc b/src/preprocessing/__pycache__/sentence_segmenter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aadabf886a1af431a1f0ff75837d4a67b494b981 Binary files /dev/null and b/src/preprocessing/__pycache__/sentence_segmenter.cpython-312.pyc differ diff --git a/src/preprocessing/__pycache__/spell_corrector.cpython-312.pyc b/src/preprocessing/__pycache__/spell_corrector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d548f2d18adabc77b462b248c8b5c8def0c3ee48 Binary files /dev/null and b/src/preprocessing/__pycache__/spell_corrector.cpython-312.pyc differ diff --git a/src/preprocessing/__pycache__/spell_corrector.cpython-314.pyc b/src/preprocessing/__pycache__/spell_corrector.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e1a2b509ab808bd63e7b10dff3f01de3ce9eb48 Binary files /dev/null and b/src/preprocessing/__pycache__/spell_corrector.cpython-314.pyc differ diff --git a/src/preprocessing/dependency_parser.py b/src/preprocessing/dependency_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..1cdc8cac1ec4a49ca3db649ba0975a8a0618de63 --- /dev/null +++ b/src/preprocessing/dependency_parser.py @@ -0,0 +1,72 @@ +""" +Dependency parser module. +Extracts grammatical skeletons (subject-verb-object) from sentences +using spaCy's dependency parse trees. 
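As a preview of the `extract_svo` method defined just below, its output on a simple sentence should look like this:

```python
parser = DependencyParser()
print(parser.extract_svo("The student wrote the essay."))
# [{'subjects': ['student'], 'verbs': ['wrote'], 'objects': ['essay']}]
```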
+""" + +import spacy +from typing import Dict, List, Any +from loguru import logger + + +class DependencyParser: + """Extracts dependency trees and SVO triples from text.""" + + def __init__(self, model_name: str = "en_core_web_trf"): + try: + self.nlp = spacy.load(model_name) + except OSError: + logger.warning(f"spaCy model '{model_name}' not found, falling back to 'en_core_web_sm'") + self.nlp = spacy.load("en_core_web_sm") + + def parse(self, text: str) -> List[Dict[str, Any]]: + """Extract dependency tree for each sentence.""" + if not text or not text.strip(): + return [] + doc = self.nlp(text) + trees = [] + for sent in doc.sents: + tokens = [] + for token in sent: + tokens.append({ + "text": token.text, + "lemma": token.lemma_, + "pos": token.pos_, + "dep": token.dep_, + "head": token.head.text, + "head_idx": token.head.i - sent.start, + "children": [child.text for child in token.children], + }) + trees.append({ + "sentence": sent.text, + "tokens": tokens, + "root": sent.root.text, + }) + return trees + + def extract_svo(self, text: str) -> List[Dict[str, List[str]]]: + """Extract subject-verb-object triples per sentence.""" + if not text or not text.strip(): + return [] + doc = self.nlp(text) + results = [] + for sent in doc.sents: + subjects = [] + verbs = [] + objects = [] + for token in sent: + if token.dep_ in ("nsubj", "nsubjpass"): + subjects.append(token.text) + # The head of nsubj is typically the verb + if token.head.pos_ == "VERB": + verbs.append(token.head.text) + elif token.dep_ in ("dobj", "pobj", "attr"): + objects.append(token.text) + # Deduplicate verbs + verbs = list(dict.fromkeys(verbs)) + results.append({ + "subjects": subjects, + "verbs": verbs, + "objects": objects, + }) + return results diff --git a/src/preprocessing/dyslexia_simulator.py b/src/preprocessing/dyslexia_simulator.py new file mode 100644 index 0000000000000000000000000000000000000000..c3887bcf91d1a90ab8ccf43c21ae60ab2ea86c90 --- /dev/null +++ b/src/preprocessing/dyslexia_simulator.py @@ -0,0 +1,133 @@ +""" +Programmatically generates dyslectic training data from clean text. +Used to augment training pairs when real dyslectic examples are scarce. + +Error types simulated (from Rello et al. 
2013, 2017 dyslexia research): +- Phonetic substitution (most common, ~35% of errors) +- Letter transposition (e.g., "teh" for "the") (~18%) +- Letter omission (~16%) +- Letter doubling (~12%) +- Letter reversal b/d, p/q (~10%) +- Word boundary errors (~9%) +""" + +import random +import re +from typing import Tuple + + +class DyslexiaSimulator: + """Generates synthetic dyslectic text from clean input for data augmentation.""" + + LETTER_REVERSALS = {'b': 'd', 'd': 'b', 'p': 'q', 'q': 'p', 'n': 'u', 'u': 'n'} + PHONETIC_SUBS = { + 'was': 'wuz', 'could': 'cud', 'would': 'wud', 'they': 'thay', + 'because': 'becaus', 'important': 'importnt', 'receive': 'recieve', + 'believe': 'beleive', 'definitely': 'definately', 'separate': 'seperate', + 'a lot': 'alot', 'in fact': 'infact', 'as well': 'aswell', + } + WORD_MERGES = [ + ('a lot', 'alot'), ('in fact', 'infact'), ('as well', 'aswell'), + ('all right', 'alright'), ('every one', 'everyone'), + ] + + def __init__(self, error_rate: float = 0.15, seed: int = 42): + self.error_rate = error_rate + self.rng = random.Random(seed) + + def _transpose_letters(self, word: str) -> str: + """Swap two adjacent letters.""" + if len(word) < 3: + return word + # Pick a random position in interior of word (not first/last) + idx = self.rng.randint(1, len(word) - 2) + chars = list(word) + chars[idx], chars[idx + 1] = chars[idx + 1], chars[idx] + return ''.join(chars) + + def _omit_letter(self, word: str) -> str: + """Remove a random interior letter.""" + if len(word) < 4: + return word + idx = self.rng.randint(1, len(word) - 2) + return word[:idx] + word[idx + 1:] + + def _double_letter(self, word: str) -> str: + """Double a random interior letter.""" + if len(word) < 3: + return word + idx = self.rng.randint(1, len(word) - 2) + return word[:idx] + word[idx] + word[idx:] + + def _reverse_letter(self, word: str) -> str: + """Swap b/d, p/q style reversals.""" + chars = list(word) + reversed_any = False + for i, c in enumerate(chars): + if c.lower() in self.LETTER_REVERSALS: + replacement = self.LETTER_REVERSALS[c.lower()] + # Preserve case + chars[i] = replacement.upper() if c.isupper() else replacement + reversed_any = True + break # Only reverse one letter per word + if reversed_any: + return ''.join(chars) + return word + + def corrupt_word(self, word: str) -> str: + """Apply a single random error to a word.""" + if len(word) < 3: + return word + # Check for phonetic substitution first + lower = word.lower() + if lower in self.PHONETIC_SUBS and self.rng.random() < 0.35: + sub = self.PHONETIC_SUBS[lower] + return sub.capitalize() if word[0].isupper() else sub + + # Weighted random error selection matching research distributions + error_type = self.rng.choices( + ['transpose', 'omit', 'double', 'reverse'], + weights=[0.35, 0.30, 0.20, 0.15], + k=1 + )[0] + + if error_type == 'transpose': + return self._transpose_letters(word) + elif error_type == 'omit': + return self._omit_letter(word) + elif error_type == 'double': + return self._double_letter(word) + else: + return self._reverse_letter(word) + + def simulate(self, clean_text: str) -> Tuple[str, str]: + """Returns (corrupted_text, clean_text) training pair.""" + if not clean_text or not clean_text.strip(): + return (clean_text, clean_text) + + # First, apply word merge errors at phrase level + corrupted = clean_text + for original_phrase, merged in self.WORD_MERGES: + if original_phrase in corrupted.lower() and self.rng.random() < self.error_rate: + # Case-insensitive replacement + pattern = 
re.compile(re.escape(original_phrase), re.IGNORECASE) + corrupted = pattern.sub(merged, corrupted, count=1) + + # Then corrupt individual words + words = corrupted.split() + corrupted_words = [] + for word in words: + # Strip trailing punctuation for corruption, reattach after + stripped = word.rstrip(".,!?;:\"'()[]{}—–-") + suffix = word[len(stripped):] + + if (len(stripped) >= 3 and + self.rng.random() < self.error_rate and + stripped.isalpha()): + corrupted_word = self.corrupt_word(stripped) + corrupted_words.append(corrupted_word + suffix) + else: + corrupted_words.append(word) + + corrupted = ' '.join(corrupted_words) + return (corrupted, clean_text) diff --git a/src/preprocessing/ner_tagger.py b/src/preprocessing/ner_tagger.py new file mode 100644 index 0000000000000000000000000000000000000000..517c46bd9c7334947454b64fb87c9cb623817cf0 --- /dev/null +++ b/src/preprocessing/ner_tagger.py @@ -0,0 +1,49 @@ +""" +Named Entity Recognition tagger. +Identifies entities (persons, locations, organisations, etc.) that +must be protected from modification during the correction process. +""" + +import spacy +from typing import List, Tuple +from dataclasses import dataclass +from loguru import logger + + +@dataclass +class EntitySpan: + text: str + label: str + start_char: int + end_char: int + + +class NERTagger: + """Tags named entities and produces protected spans.""" + + def __init__(self, model_name: str = "en_core_web_trf"): + try: + self.nlp = spacy.load(model_name) + except OSError: + logger.warning(f"spaCy model '{model_name}' not found, falling back to 'en_core_web_sm'") + self.nlp = spacy.load("en_core_web_sm") + + def tag(self, text: str) -> List[EntitySpan]: + """Extract all named entities from text.""" + if not text or not text.strip(): + return [] + doc = self.nlp(text) + entities = [] + for ent in doc.ents: + entities.append(EntitySpan( + text=ent.text, + label=ent.label_, + start_char=ent.start_char, + end_char=ent.end_char, + )) + return entities + + def get_protected_spans(self, text: str) -> List[Tuple[int, int]]: + """Return (start, end) char spans that must not be modified.""" + entities = self.tag(text) + return [(e.start_char, e.end_char) for e in entities] diff --git a/src/preprocessing/pipeline.py b/src/preprocessing/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f071fdb7985e026e8d241544761b0c29e92c2f --- /dev/null +++ b/src/preprocessing/pipeline.py @@ -0,0 +1,163 @@ +""" +Master pre-processing pipeline. Runs all NLP stages in sequence. +Returns a PreprocessedDoc object with all annotations attached. +""" + +import spacy +from dataclasses import dataclass, field +from typing import List, Dict, Any, Optional +from .spell_corrector import DyslexiaAwareSpellCorrector +import textstat +from loguru import logger + + +@dataclass +class EntitySpan: + text: str + label: str + start_char: int + end_char: int + + +@dataclass +class PreprocessedDoc: + original_text: str + corrected_text: str + sentences: List[str] + entities: List[EntitySpan] # Never to be modified by rewriter + dependency_trees: List[Dict] # Grammatical skeletons per sentence + pos_tags: List[List[tuple]] # (token, POS) per sentence + readability: Dict[str, float] # Flesch-Kincaid, Gunning Fog, etc. 
+ sentence_lengths: List[int] + protected_spans: List[tuple] # (start, end) char spans to never touch + + +class PreprocessingPipeline: + """Orchestrates all pre-processing stages: spell correction, parsing, NER, readability.""" + + def __init__(self, model_name: str = "en_core_web_trf"): + # Load spaCy model with fallback + try: + self.nlp = spacy.load(model_name) + except OSError: + logger.warning(f"spaCy model '{model_name}' not found, falling back to 'en_core_web_sm'") + self.nlp = spacy.load("en_core_web_sm") + + # Initialise spell corrector + self.spell_corrector = DyslexiaAwareSpellCorrector() + logger.info("PreprocessingPipeline initialised") + + def _extract_readability(self, text: str) -> Dict[str, float]: + """Compute readability scores (Flesch-Kincaid, Gunning Fog, etc.).""" + if not text or not text.strip(): + return { + "flesch_kincaid_grade": 0.0, + "gunning_fog": 0.0, + "smog_index": 0.0, + "automated_readability_index": 0.0, + "flesch_reading_ease": 0.0, + "coleman_liau_index": 0.0, + } + return { + "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), + "gunning_fog": textstat.gunning_fog(text), + "smog_index": textstat.smog_index(text), + "automated_readability_index": textstat.automated_readability_index(text), + "flesch_reading_ease": textstat.flesch_reading_ease(text), + "coleman_liau_index": textstat.coleman_liau_index(text), + } + + def _extract_dep_tree(self, sent) -> Dict: + """Extract grammatical skeleton: subject-verb-object per sentence.""" + subjects = [] + verbs = [] + objects = [] + for token in sent: + if token.dep_ in ("nsubj", "nsubjpass"): + subjects.append(token.text) + if token.head.pos_ == "VERB": + verbs.append(token.head.text) + elif token.dep_ in ("dobj", "pobj", "attr"): + objects.append(token.text) + return { + "sentence": sent.text, + "subjects": subjects, + "verbs": list(dict.fromkeys(verbs)), + "objects": objects, + "root": sent.root.text if sent.root else "", + } + + def process(self, raw_text: str) -> PreprocessedDoc: + """Run full pre-processing pipeline on raw text. + + 7-step pipeline: + 1. Spell correction (phonetic + spellcheck + grammar) + 2. spaCy parsing + 3. Sentence segmentation + 4. Named entity recognition + 5. Dependency tree extraction + 6. POS tagging + 7. 
Readability scoring + """ + if not raw_text or not raw_text.strip(): + return PreprocessedDoc( + original_text=raw_text, + corrected_text=raw_text or "", + sentences=[], + entities=[], + dependency_trees=[], + pos_tags=[], + readability=self._extract_readability(""), + sentence_lengths=[], + protected_spans=[], + ) + + # Step 1: Spell correction + corrected = self.spell_corrector.correct(raw_text) + + # Step 2: Parse corrected text with spaCy + doc = self.nlp(corrected) + + # Step 3: Sentence segmentation + sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()] + + # Step 4: NER — extract entities and protected spans + entities = [] + protected_spans = [] + for ent in doc.ents: + entities.append(EntitySpan( + text=ent.text, + label=ent.label_, + start_char=ent.start_char, + end_char=ent.end_char, + )) + protected_spans.append((ent.start_char, ent.end_char)) + + # Step 5: Dependency trees per sentence + dependency_trees = [] + for sent in doc.sents: + dependency_trees.append(self._extract_dep_tree(sent)) + + # Step 6: POS tags per sentence + pos_tags = [] + for sent in doc.sents: + sent_tags = [(token.text, token.pos_) for token in sent] + pos_tags.append(sent_tags) + + # Step 7: Readability + readability = self._extract_readability(corrected) + + # Sentence lengths + sentence_lengths = [len(s.split()) for s in sentences] + + return PreprocessedDoc( + original_text=raw_text, + corrected_text=corrected, + sentences=sentences, + entities=entities, + dependency_trees=dependency_trees, + pos_tags=pos_tags, + readability=readability, + sentence_lengths=sentence_lengths, + protected_spans=protected_spans, + ) diff --git a/src/preprocessing/sentence_segmenter.py b/src/preprocessing/sentence_segmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..68f67c45ea1f6eda7479e28f681113be33787f08 --- /dev/null +++ b/src/preprocessing/sentence_segmenter.py @@ -0,0 +1,29 @@ +""" +Sentence segmentation module. +Uses spaCy's sentence boundary detection for accurate segmentation +of potentially malformed dyslectic text. 
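A minimal end-to-end sketch of the pipeline above, assuming the package is importable as `src.preprocessing.pipeline` (the layout implied by the relative imports elsewhere in this diff) and that a spaCy model is installed; the sample input is illustrative:

```python
from src.preprocessing.pipeline import PreprocessingPipeline  # import path assumed

pipeline = PreprocessingPipeline()  # falls back to en_core_web_sm if the trf model is missing
doc = pipeline.process("Thay wuz walking to teh store becaus it was late.")

print(doc.corrected_text)                       # spell-corrected text
print(doc.sentences)                            # segmented sentences
print(doc.protected_spans)                      # entity char spans the rewriter must not touch
print(doc.readability["flesch_kincaid_grade"])  # one of the six readability scores
```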
+""" + +import spacy +from typing import List +from loguru import logger + + +class SentenceSegmenter: + """Segments text into sentences using spaCy's transformer model.""" + + def __init__(self, model_name: str = "en_core_web_trf"): + try: + self.nlp = spacy.load(model_name) + except OSError: + logger.warning(f"spaCy model '{model_name}' not found, falling back to 'en_core_web_sm'") + self.nlp = spacy.load("en_core_web_sm") + logger.info(f"SentenceSegmenter loaded with model: {self.nlp.meta['name']}") + + def segment(self, text: str) -> List[str]: + """Split text into individual sentences.""" + if not text or not text.strip(): + return [] + doc = self.nlp(text) + sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()] + return sentences diff --git a/src/preprocessing/spell_corrector.py b/src/preprocessing/spell_corrector.py new file mode 100644 index 0000000000000000000000000000000000000000..6da338024159671fd9ee63f8fce62eb8bf9c32fd --- /dev/null +++ b/src/preprocessing/spell_corrector.py @@ -0,0 +1,133 @@ +""" +Two-pass spell correction: +Pass 1: pyspellchecker (fast, context-free, catches simple typos) +Pass 2: LanguageTool (context-aware, catches grammar + dyslexic patterns) + +Dyslexic error patterns handled: +- Letter reversals: b/d, p/q, n/u, m/w +- Phonetic spelling: "wuz", "cud", "thay" +- Word boundary errors: "alot", "infact", "aswell" +- Letter omissions: "becaus", "importnt" +- Letter transpositions: "teh", "recieve" +- Homophone confusion: there/their/they're +""" + +import language_tool_python +from spellchecker import SpellChecker +from loguru import logger +from typing import Optional +import re + + +class DyslexiaAwareSpellCorrector: + """Two-pass spell corrector with dyslexia-specific phonetic pattern handling.""" + + DYSLEXIC_PHONETIC_MAP = { + "wuz": "was", "cud": "could", "wud": "would", "shud": "should", + "thay": "they", "thier": "their", "recieve": "receive", + "beleive": "believe", "occured": "occurred", "definately": "definitely", + "seperate": "separate", "untill": "until", "tommorrow": "tomorrow", + "alot": "a lot", "infact": "in fact", "aswell": "as well", + "alright": "all right", "cant": "cannot", "wont": "will not", + "ive": "I have", "im": "I am", "id": "I would", + } + + def __init__(self, language: str = "en-US"): + self.spell = SpellChecker() + self.language = language + # Build regex pattern for phonetic map (word-boundary matching) + self._phonetic_pattern = re.compile( + r'\b(' + '|'.join(re.escape(k) for k in self.DYSLEXIC_PHONETIC_MAP.keys()) + r')\b', + re.IGNORECASE + ) + # Try to initialise LanguageTool; graceful fallback if JVM not available + self.tool = None + try: + self.tool = language_tool_python.LanguageTool(language) + logger.info("LanguageTool initialised successfully") + except Exception as e: + logger.warning(f"LanguageTool unavailable (JVM issue?), skipping context-aware pass: {e}") + + def _phonetic_pass(self, text: str) -> str: + """Apply known dyslexic phonetic substitutions first.""" + def _replace(match): + word = match.group(0) + lower = word.lower() + replacement = self.DYSLEXIC_PHONETIC_MAP.get(lower, word) + # Preserve capitalisation of first letter + if word[0].isupper() and replacement[0].islower(): + replacement = replacement[0].upper() + replacement[1:] + return replacement + + return self._phonetic_pattern.sub(_replace, text) + + def _spellcheck_pass(self, text: str) -> str: + """pyspellchecker pass for simple token-level errors.""" + words = text.split() + corrected_words = [] + for word in words: + # 
Strip punctuation for checking but preserve it + stripped = word.strip(".,!?;:\"'()[]{}—–-") + prefix = word[:len(word) - len(word.lstrip(".,!?;:\"'()[]{}—–-"))] + suffix = word[len(stripped) + len(prefix):] + + if stripped and stripped.lower() not in self.spell and not stripped.isupper(): + correction = self.spell.correction(stripped.lower()) + if correction and correction != stripped.lower(): + # Preserve original capitalisation + if stripped[0].isupper(): + correction = correction.capitalize() + corrected_words.append(prefix + correction + suffix) + else: + corrected_words.append(word) + else: + corrected_words.append(word) + return " ".join(corrected_words) + + def _languagetool_pass(self, text: str) -> str: + """LanguageTool pass for context-aware grammar + spelling corrections.""" + if self.tool is None: + return text + + try: + matches = self.tool.check(text) + # Apply corrections in reverse order to preserve offsets + matches = sorted(matches, key=lambda m: m.offset, reverse=True) + result = text + for match in matches: + if match.replacements: + replacement = match.replacements[0] + start = match.offset + end = start + match.errorLength + result = result[:start] + replacement + result[end:] + return result + except Exception as e: + logger.warning(f"LanguageTool check failed: {e}") + return text + + def correct(self, text: str) -> str: + """Run all three correction passes in sequence.""" + if not text or not text.strip(): + return text + + logger.debug(f"Spell correction input: {text[:100]}...") + + # Pass 0: Phonetic substitutions (dyslexia-specific) + text = self._phonetic_pass(text) + + # Pass 1: Token-level spellcheck + text = self._spellcheck_pass(text) + + # Pass 2: Context-aware grammar correction + text = self._languagetool_pass(text) + + return text + + def close(self): + """Clean up LanguageTool resources.""" + if self.tool is not None: + try: + self.tool.close() + except Exception: + pass + self.tool = None diff --git a/src/style/__init__.py b/src/style/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/style/__pycache__/__init__.cpython-314.pyc b/src/style/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..557e5fd388cbc12437f732a2b6b831437249df8e Binary files /dev/null and b/src/style/__pycache__/__init__.cpython-314.pyc differ diff --git a/src/style/__pycache__/emotion_classifier.cpython-312.pyc b/src/style/__pycache__/emotion_classifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea029162ad4517aff26800598038eeebf506a5b5 Binary files /dev/null and b/src/style/__pycache__/emotion_classifier.cpython-312.pyc differ diff --git a/src/style/__pycache__/fingerprinter.cpython-312.pyc b/src/style/__pycache__/fingerprinter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1949bf6782877cb1525ec6ee39a53148b7e487e Binary files /dev/null and b/src/style/__pycache__/fingerprinter.cpython-312.pyc differ diff --git a/src/style/__pycache__/fingerprinter.cpython-314.pyc b/src/style/__pycache__/fingerprinter.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d46487d25ddc3f1eb58822e4ad4effc196147e3e Binary files /dev/null and b/src/style/__pycache__/fingerprinter.cpython-314.pyc differ diff --git a/src/style/__pycache__/formality_classifier.cpython-312.pyc b/src/style/__pycache__/formality_classifier.cpython-312.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..151171c29b0a5fd5e0b0ceda0338c156d89b28ab Binary files /dev/null and b/src/style/__pycache__/formality_classifier.cpython-312.pyc differ diff --git a/src/style/__pycache__/formality_classifier.cpython-314.pyc b/src/style/__pycache__/formality_classifier.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9a31dbfe7203f86cc54224daf196bf5179c8004 Binary files /dev/null and b/src/style/__pycache__/formality_classifier.cpython-314.pyc differ diff --git a/src/style/__pycache__/style_vector.cpython-312.pyc b/src/style/__pycache__/style_vector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..983860ebe768294513a22c9c566ed7f324e704a0 Binary files /dev/null and b/src/style/__pycache__/style_vector.cpython-312.pyc differ diff --git a/src/style/emotion_classifier.py b/src/style/emotion_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..ef66ad798a840efd77800f54df6a3a3e86523dc8 --- /dev/null +++ b/src/style/emotion_classifier.py @@ -0,0 +1,76 @@ +""" +Emotion/register classifier module. +Classifies text emotional register (neutral, passionate, cautious, etc.). +Used as one dimension of the style fingerprint. +""" + +import re +from typing import Dict + + +class EmotionClassifier: + """Classifies emotional register of text using keyword-based analysis.""" + + REGISTER_KEYWORDS = { + "neutral": { + "states", "indicates", "shows", "reports", "notes", + "describes", "observed", "found", "results", "data", + "information", "according", "based", "study", "analysis", + }, + "passionate": { + "amazing", "incredible", "extraordinary", "remarkable", + "outstanding", "excellent", "wonderful", "brilliant", + "terrible", "devastating", "critical", "urgent", + "essential", "vital", "crucial", "imperative", + }, + "cautious": { + "perhaps", "possibly", "might", "may", "could", + "seem", "appears", "suggests", "indicates", "tend", + "potentially", "arguably", "presumably", "conceivably", + "tentatively", "provisionally", + }, + "analytical": { + "therefore", "consequently", "thus", "hence", "because", + "analysis", "examine", "investigate", "evaluate", "assess", + "compare", "contrast", "correlate", "determine", "evidence", + "hypothesis", "methodology", "framework", + }, + "confident": { + "clearly", "obviously", "certainly", "definitely", + "undoubtedly", "indeed", "absolutely", "demonstrate", + "prove", "establish", "confirm", "guarantee", + "unquestionably", "invariably", + }, + } + + def __init__(self): + pass + + def classify(self, text: str) -> Dict[str, float]: + """Return emotion distribution over register categories. + + Returns a dict with keys: neutral, passionate, cautious, analytical, confident. + Values are probabilities that sum to ~1.0. 
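The implementation follows below; as a quick sanity check, the returned distribution always covers the five registers and normalises to one (sample text illustrative):

```python
clf = EmotionClassifier()
dist = clf.classify("Perhaps the data suggests this might be the case.")
# dist has keys: neutral, passionate, cautious, analytical, confident
assert abs(sum(dist.values()) - 1.0) < 1e-6
```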
+ """ + if not text or not text.strip(): + return {k: 0.2 for k in self.REGISTER_KEYWORDS} + + words = set(text.lower().split()) + scores = {} + + for register, keywords in self.REGISTER_KEYWORDS.items(): + overlap = len(words & keywords) + scores[register] = overlap + + # Add punctuation-based signals + exclamation_count = text.count("!") + question_count = text.count("?") + scores["passionate"] = scores.get("passionate", 0) + exclamation_count * 0.5 + scores["cautious"] = scores.get("cautious", 0) + question_count * 0.3 + + # Normalise to probability distribution + total = sum(scores.values()) + if total == 0: + return {k: 0.2 for k in self.REGISTER_KEYWORDS} + + return {k: v / total for k, v in scores.items()} diff --git a/src/style/fingerprinter.py b/src/style/fingerprinter.py new file mode 100644 index 0000000000000000000000000000000000000000..2e744334d14c653deb9a66277016de6d269ab52b --- /dev/null +++ b/src/style/fingerprinter.py @@ -0,0 +1,349 @@ +""" +Extracts a numerical style vector from any text sample. +The style vector encodes the author's unique writing fingerprint +and is used both to condition the generation model and to evaluate +style preservation after correction. + +Style vector dimensions (total: 512 after projection): + Raw features (~40) → MLP projection → 512-dim dense vector + +Raw features: + - sentence_length_mean, sentence_length_std, sentence_length_skew [3] + - word_length_mean, word_length_std [2] + - type_token_ratio (TTR) [1] + - passive_voice_ratio [1] + - active_voice_ratio [1] + - subordinate_clause_ratio [1] + - avg_dependency_tree_depth [1] + - hedging_frequency (per 100 words) [1] + - discourse_marker_counts [however, therefore, moreover, ...] [20] + - formality_score (0-1) [1] + - lexical_density [1] + - nominalization_ratio [1] + - question_sentence_ratio [1] + - exclamation_ratio [1] + - first_person_ratio [1] + - third_person_ratio [1] + - academic_word_coverage [1] + - avg_syllables_per_word [1] + - flesch_reading_ease [1] +""" + +import spacy +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Dict, Optional +from scipy import stats +from loguru import logger + + +HEDGING_WORDS = { + "perhaps", "possibly", "probably", "might", "may", "could", "seem", + "appears", "suggests", "indicates", "tend", "often", "generally", + "approximately", "roughly", "somewhat", "relatively", "fairly", +} + +DISCOURSE_MARKERS = [ + "however", "therefore", "moreover", "furthermore", "consequently", + "nevertheless", "nonetheless", "additionally", "alternatively", + "subsequently", "previously", "similarly", "conversely", "thus", + "hence", "accordingly", "meanwhile", "indeed", "notably", "specifically", +] + +NOMINALISATION_SUFFIXES = ( + "tion", "sion", "ment", "ness", "ity", "ance", "ence", + "hood", "ship", "ism", "al", "ure", +) + +FEATURE_DIM = 41 # Fixed feature dimension for MLP input (3+2+1+1+1+1+1+1+20+1+1+1+1+1+1+1+1+1+1) + + +class StyleProjectionMLP(nn.Module): + """Projects raw feature vector to 512-dim style embedding.""" + + def __init__(self, input_dim: int = 41, hidden_dim: int = 256, output_dim: int = 512): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU(), + nn.Dropout(0.1), + nn.Linear(hidden_dim, output_dim), + nn.LayerNorm(output_dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class StyleFingerprinter: + """Extracts style fingerprint vectors from text samples.""" + + 
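A quick shape check for the projection head defined above, using a random input purely for illustration:

```python
import torch

mlp = StyleProjectionMLP(input_dim=41, hidden_dim=256, output_dim=512)
raw = torch.randn(1, 41)  # one 41-dim raw feature vector
emb = mlp(raw)
print(emb.shape)          # torch.Size([1, 512])
```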
def __init__(self, spacy_model: str = "en_core_web_trf", awl_path: str = "data/awl/coxhead_awl.txt"): + # Load spaCy with fallback + try: + self.nlp = spacy.load(spacy_model) + except OSError: + logger.warning(f"spaCy model '{spacy_model}' not found, falling back to 'en_core_web_sm'") + self.nlp = spacy.load("en_core_web_sm") + + # Load AWL + self.awl = self._load_awl(awl_path) + + # Projection MLP + self.projection = StyleProjectionMLP( + input_dim=FEATURE_DIM, hidden_dim=256, output_dim=512 + ) + self.projection.eval() # Start in eval mode; will be trained alongside main model + logger.info(f"StyleFingerprinter initialised (AWL size: {len(self.awl)})") + + def _load_awl(self, path: str) -> set: + """Load Academic Word List from file.""" + awl = set() + try: + with open(path) as f: + for line in f: + word = line.strip().lower() + if word: + awl.add(word) + except FileNotFoundError: + logger.warning(f"AWL file not found at {path}, using empty set") + return awl + + def _passive_voice_ratio(self, doc) -> float: + """Compute ratio of passive voice constructions.""" + passive_count = 0 + verb_count = 0 + for token in doc: + if token.pos_ == "VERB": + verb_count += 1 + if token.dep_ in ("nsubjpass", "auxpass"): + passive_count += 1 + if verb_count == 0: + return 0.0 + return passive_count / verb_count + + def _avg_dep_tree_depth(self, doc) -> float: + """Compute average dependency tree depth across all tokens.""" + def _depth(token): + d = 0 + current = token + while current.head != current: + d += 1 + current = current.head + if d > 50: # Safety limit + break + return d + + depths = [_depth(token) for token in doc if not token.is_punct] + if not depths: + return 0.0 + return sum(depths) / len(depths) + + def _lexical_density(self, doc) -> float: + """Compute ratio of content words to total words.""" + content_pos = {"NOUN", "VERB", "ADJ", "ADV"} + total = 0 + content = 0 + for token in doc: + if not token.is_punct and not token.is_space: + total += 1 + if token.pos_ in content_pos: + content += 1 + if total == 0: + return 0.0 + return content / total + + @staticmethod + def _count_syllables(word: str) -> int: + """Count syllables in a word using a vowel-group heuristic. + Avoids NLTK cmudict which has a known AssertionError bug.""" + word = word.lower().strip() + if not word: + return 1 + vowels = "aeiouy" + count = 0 + prev_vowel = False + for char in word: + is_vowel = char in vowels + if is_vowel and not prev_vowel: + count += 1 + prev_vowel = is_vowel + # Adjust for silent 'e' at end + if word.endswith("e") and count > 1: + count -= 1 + # Words like "the", "me" still need at least 1 + return max(count, 1) + + def _avg_syllables_per_word(self, words: list) -> float: + """Average syllables per word.""" + if not words: + return 0.0 + total = sum(self._count_syllables(w) for w in words) + return total / len(words) + + @staticmethod + def _flesch_reading_ease(words: list, sent_lengths: list) -> float: + """Compute Flesch Reading Ease score without textstat. 
+ Formula: 206.835 - 1.015 * ASL - 84.6 * ASW + ASL = average sentence length, ASW = average syllables per word.""" + if not words or not sent_lengths: + return 0.0 + asl = sum(sent_lengths) / max(len(sent_lengths), 1) + vowels = "aeiouy" + total_syllables = 0 + for w in words: + w_lower = w.lower() + count = 0 + prev = False + for c in w_lower: + v = c in vowels + if v and not prev: + count += 1 + prev = v + if w_lower.endswith("e") and count > 1: + count -= 1 + total_syllables += max(count, 1) + asw = total_syllables / max(len(words), 1) + return 206.835 - 1.015 * asl - 84.6 * asw + + def extract_raw_features(self, text: str) -> Dict[str, float]: + """Extract ~40 raw style features from text.""" + if not text or not text.strip(): + return {f"f_{i}": 0.0 for i in range(FEATURE_DIM)} + + doc = self.nlp(text) + words = [t.text.lower() for t in doc if not t.is_punct and not t.is_space] + word_count = max(len(words), 1) + + # Sentence-level features + sentences = list(doc.sents) + sent_lengths = [len([t for t in s if not t.is_punct and not t.is_space]) for s in sentences] + if not sent_lengths: + sent_lengths = [0] + + features = {} + + # [3] Sentence length stats + features["sentence_length_mean"] = np.mean(sent_lengths) + features["sentence_length_std"] = np.std(sent_lengths) if len(sent_lengths) > 1 else 0.0 + features["sentence_length_skew"] = float(stats.skew(sent_lengths)) if len(sent_lengths) > 2 else 0.0 + + # [2] Word length stats + word_lengths = [len(w) for w in words] + features["word_length_mean"] = np.mean(word_lengths) if word_lengths else 0.0 + features["word_length_std"] = np.std(word_lengths) if len(word_lengths) > 1 else 0.0 + + # [1] Type-token ratio + unique_words = set(words) + features["type_token_ratio"] = len(unique_words) / word_count + + # [1] Passive voice ratio + features["passive_voice_ratio"] = self._passive_voice_ratio(doc) + + # [1] Active voice ratio + features["active_voice_ratio"] = 1.0 - features["passive_voice_ratio"] + + # [1] Subordinate clause ratio + sub_clauses = sum(1 for t in doc if t.dep_ in ("advcl", "relcl", "ccomp", "xcomp", "acl")) + features["subordinate_clause_ratio"] = sub_clauses / max(len(sent_lengths), 1) + + # [1] Avg dependency tree depth + features["avg_dependency_tree_depth"] = self._avg_dep_tree_depth(doc) + + # [1] Hedging frequency (per 100 words) + hedging_count = sum(1 for w in words if w in HEDGING_WORDS) + features["hedging_frequency"] = (hedging_count / word_count) * 100 + + # [20] Discourse marker counts (per 100 words) + for marker in DISCOURSE_MARKERS: + marker_count = words.count(marker) + features[f"discourse_{marker}"] = (marker_count / word_count) * 100 + + # [1] Formality score (cached classifier, not re-instantiated per call) + if not hasattr(self, '_formality_clf'): + from .formality_classifier import FormalityClassifier + self._formality_clf = FormalityClassifier() + features["formality_score"] = self._formality_clf.score(text) + + # [1] Lexical density + features["lexical_density"] = self._lexical_density(doc) + + # [1] Nominalization ratio + nom_count = sum(1 for w in words if any(w.endswith(s) for s in NOMINALISATION_SUFFIXES)) + features["nominalization_ratio"] = nom_count / word_count + + # [1] Question sentence ratio + question_sents = sum(1 for s in sentences if s.text.strip().endswith("?")) + features["question_sentence_ratio"] = question_sents / max(len(sentences), 1) + + # [1] Exclamation ratio + excl_sents = sum(1 for s in sentences if s.text.strip().endswith("!")) + features["exclamation_ratio"] = 
excl_sents / max(len(sentences), 1) + + # [1] First person ratio + first_person = {"i", "me", "my", "mine", "myself", "we", "our", "ours"} + fp_count = sum(1 for w in words if w in first_person) + features["first_person_ratio"] = fp_count / word_count + + # [1] Third person ratio + third_person = {"he", "she", "it", "they", "him", "her", "his", "its", "their", "them"} + tp_count = sum(1 for w in words if w in third_person) + features["third_person_ratio"] = tp_count / word_count + + # [1] Academic word coverage + academic_count = sum(1 for w in words if w in self.awl) + features["academic_word_coverage"] = academic_count / word_count + + # [1] Avg syllables per word (pure-Python, avoids NLTK cmudict bug) + features["avg_syllables_per_word"] = self._avg_syllables_per_word(words) + + # [1] Flesch reading ease (normalised to 0-1, pure-Python) + flesch = self._flesch_reading_ease(words, sent_lengths) + features["flesch_reading_ease"] = max(0.0, min(1.0, flesch / 100.0)) + + return features + + def extract_vector(self, text: str) -> torch.Tensor: + """Returns a 512-dim style embedding tensor.""" + features = self.extract_raw_features(text) + + # Convert feature dict to ordered float array + values = list(features.values()) + + # Pad or truncate to exactly FEATURE_DIM + if len(values) < FEATURE_DIM: + values.extend([0.0] * (FEATURE_DIM - len(values))) + else: + values = values[:FEATURE_DIM] + + # Convert to tensor and project through MLP + feature_tensor = torch.tensor(values, dtype=torch.float32).unsqueeze(0) + + with torch.no_grad(): + embedding = self.projection(feature_tensor) + + # L2 normalise + embedding = F.normalize(embedding, p=2, dim=-1) + + return embedding.squeeze(0) + + def blend_vectors( + self, + user_vec: torch.Tensor, + master_vec: Optional[torch.Tensor], + alpha: float = 0.6, + ) -> torch.Tensor: + """ + Blend user style with master copy style. + alpha = weight given to user's own style (0.6 = user dominates) + Formula: target = alpha * user_vec + (1 - alpha) * master_vec + """ + if master_vec is None: + return F.normalize(user_vec, p=2, dim=-1) + + blended = alpha * user_vec + (1 - alpha) * master_vec + # L2 normalise to unit sphere + return F.normalize(blended, p=2, dim=-1) diff --git a/src/style/formality_classifier.py b/src/style/formality_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..1afa93114d5fb1eb7d33a211377b042adcd440fb --- /dev/null +++ b/src/style/formality_classifier.py @@ -0,0 +1,98 @@ +""" +Formality classifier module. +Classifies text on a 0-1 formality scale using linguistic features. +Used as one dimension of the style fingerprint. 
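Downstream, the fingerprinter is typically used as follows; the texts and variable names here are illustrative:

```python
fp = StyleFingerprinter()
user_vec = fp.extract_vector("Draft paragraph written by the user.")      # 512-dim, L2-normalised
master_vec = fp.extract_vector("Reference paragraph in the target style.")
target_vec = fp.blend_vectors(user_vec, master_vec, alpha=0.6)  # user style dominates the blend
```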
+""" + +import re +from typing import Optional + + +class FormalityClassifier: + """Scores text formality on a 0-1 scale using rule-based heuristics.""" + + # Informal markers that decrease formality score + CONTRACTIONS = { + "don't", "can't", "won't", "it's", "that's", "there's", + "they're", "we're", "you're", "i'm", "i've", "i'll", + "isn't", "aren't", "wasn't", "weren't", "hasn't", "haven't", + "couldn't", "wouldn't", "shouldn't", "let's", "he's", "she's", + } + + INFORMAL_WORDS = { + "gonna", "wanna", "gotta", "kinda", "sorta", "ya", "yeah", + "yep", "nope", "ok", "okay", "cool", "awesome", "stuff", + "things", "like", "basically", "actually", "literally", + "totally", "really", "super", "pretty", "kind of", "sort of", + } + + FORMAL_MARKERS = { + "furthermore", "moreover", "consequently", "nevertheless", + "nonetheless", "accordingly", "hence", "thus", "therefore", + "whereas", "notwithstanding", "hitherto", "whereby", + "therein", "thereof", "herein", + } + + def __init__(self): + pass + + def score(self, text: str) -> float: + """Return formality score in [0, 1]. Higher = more formal. + + Scoring based on: + - Contraction penalty (-0.05 each) + - Informal word penalty (-0.03 each) + - Formal marker bonus (+0.04 each) + - Average sentence length bonus (longer = more formal) + - First person penalty (-0.02 per occurrence) + - Exclamation penalty (-0.05 each) + """ + if not text or not text.strip(): + return 0.5 + + words = text.lower().split() + word_count = max(len(words), 1) + + # Base score + score = 0.5 + + # Contraction penalty + contraction_count = sum(1 for w in words if w in self.CONTRACTIONS) + score -= min(contraction_count * 0.05, 0.25) + + # Informal word penalty + informal_count = sum(1 for w in words if w in self.INFORMAL_WORDS) + score -= min((informal_count / word_count) * 0.5, 0.2) + + # Formal marker bonus + formal_count = sum(1 for w in words if w in self.FORMAL_MARKERS) + score += min(formal_count * 0.04, 0.2) + + # Sentence length bonus (longer sentences tend to be more formal) + sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()] + if sentences: + avg_sent_len = sum(len(s.split()) for s in sentences) / len(sentences) + if avg_sent_len > 20: + score += 0.1 + elif avg_sent_len > 15: + score += 0.05 + elif avg_sent_len < 8: + score -= 0.05 + + # First person penalty (academic writing avoids "I") + first_person = sum(1 for w in words if w in ("i", "me", "my", "mine", "myself")) + score -= min((first_person / word_count) * 0.3, 0.1) + + # Exclamation penalty + exclamation_count = text.count("!") + score -= min(exclamation_count * 0.05, 0.15) + + # Question mark mild penalty (academic writing has fewer questions) + question_count = text.count("?") + score -= min(question_count * 0.02, 0.08) + + # Passive voice bonus (approximation: "is/was/were/been" + past participle patterns) + passive_indicators = sum(1 for w in words if w in ("is", "was", "were", "been", "being")) + score += min((passive_indicators / word_count) * 0.15, 0.1) + + return max(0.0, min(1.0, score)) diff --git a/src/style/style_vector.py b/src/style/style_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..e4228d445a0c19e75652bd94cb96e9bdea8384d8 --- /dev/null +++ b/src/style/style_vector.py @@ -0,0 +1,38 @@ +""" +Style vector utilities. +Helper functions for manipulating, comparing, and persisting style vectors. 
+""" + +import torch +import torch.nn.functional as F +from typing import List, Optional + + +def cosine_similarity(vec_a: torch.Tensor, vec_b: torch.Tensor) -> float: + """Compute cosine similarity between two style vectors.""" + if vec_a.dim() == 1: + vec_a = vec_a.unsqueeze(0) + if vec_b.dim() == 1: + vec_b = vec_b.unsqueeze(0) + sim = F.cosine_similarity(vec_a, vec_b, dim=-1) + return sim.item() + + +def average_style_vectors(vectors: List[torch.Tensor]) -> torch.Tensor: + """Compute the mean style vector from a list of vectors.""" + if not vectors: + raise ValueError("Cannot average empty list of vectors") + stacked = torch.stack(vectors, dim=0) + mean_vec = stacked.mean(dim=0) + # L2 normalise the result + return F.normalize(mean_vec, p=2, dim=-1) + + +def save_style_vector(vector: torch.Tensor, path: str) -> None: + """Persist a style vector to disk.""" + torch.save(vector.detach().cpu(), path) + + +def load_style_vector(path: str) -> torch.Tensor: + """Load a style vector from disk.""" + return torch.load(path, map_location="cpu", weights_only=True) diff --git a/src/training/__init__.py b/src/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/training/__pycache__/__init__.cpython-314.pyc b/src/training/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cce6502e6cd273582ba9176392a747545b1b434 Binary files /dev/null and b/src/training/__pycache__/__init__.cpython-314.pyc differ diff --git a/src/training/__pycache__/callbacks.cpython-312.pyc b/src/training/__pycache__/callbacks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c01a2301495e779e50929922b4efa8c4eabce2b Binary files /dev/null and b/src/training/__pycache__/callbacks.cpython-312.pyc differ diff --git a/src/training/__pycache__/callbacks.cpython-314.pyc b/src/training/__pycache__/callbacks.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4f7e121dc77ed4042d710e7c73fa9f31bd4f4af Binary files /dev/null and b/src/training/__pycache__/callbacks.cpython-314.pyc differ diff --git a/src/training/__pycache__/dataset.cpython-312.pyc b/src/training/__pycache__/dataset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec7377c1d43b9802cff2be091008f88a26903f08 Binary files /dev/null and b/src/training/__pycache__/dataset.cpython-312.pyc differ diff --git a/src/training/__pycache__/dataset.cpython-314.pyc b/src/training/__pycache__/dataset.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c4660fd42e69d11eb5107d4aa12d751afcf3d68 Binary files /dev/null and b/src/training/__pycache__/dataset.cpython-314.pyc differ diff --git a/src/training/__pycache__/human_pattern_extractor.cpython-312.pyc b/src/training/__pycache__/human_pattern_extractor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b8cb81616a871b813941d279b1095f7217b3385 Binary files /dev/null and b/src/training/__pycache__/human_pattern_extractor.cpython-312.pyc differ diff --git a/src/training/__pycache__/human_pattern_extractor.cpython-314.pyc b/src/training/__pycache__/human_pattern_extractor.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1f56ba8519c8f5e113858d091d6a3ef139b5633 Binary files /dev/null and b/src/training/__pycache__/human_pattern_extractor.cpython-314.pyc differ diff --git 
a/src/training/__pycache__/loss_functions.cpython-312.pyc b/src/training/__pycache__/loss_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4f14fbe64931aab7f26977b44d34205827315c7 Binary files /dev/null and b/src/training/__pycache__/loss_functions.cpython-312.pyc differ diff --git a/src/training/__pycache__/loss_functions.cpython-314.pyc b/src/training/__pycache__/loss_functions.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f618a89358af7b1a32d483b4130cdfbe83c758b Binary files /dev/null and b/src/training/__pycache__/loss_functions.cpython-314.pyc differ diff --git a/src/training/__pycache__/trainer.cpython-312.pyc b/src/training/__pycache__/trainer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0951c0567e346f41c00b6418006b369d0256fb3c Binary files /dev/null and b/src/training/__pycache__/trainer.cpython-312.pyc differ diff --git a/src/training/__pycache__/trainer.cpython-314.pyc b/src/training/__pycache__/trainer.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bdab94adb19c465882889024ad7cdd1bc4da549 Binary files /dev/null and b/src/training/__pycache__/trainer.cpython-314.pyc differ diff --git a/src/training/callbacks.py b/src/training/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..37805c43ce59eab971fcfc373a3a3c5228863569 --- /dev/null +++ b/src/training/callbacks.py @@ -0,0 +1,64 @@ +""" +Training callbacks for monitoring and checkpointing. +Integrates with Weights & Biases and TensorBoard. +""" + +from transformers import TrainerCallback, TrainerState, TrainerControl, TrainingArguments +from loguru import logger + +try: + import wandb + HAS_WANDB = True +except ImportError: + HAS_WANDB = False + + +class StyleMetricsCallback(TrainerCallback): + """Logs style similarity metrics during evaluation.""" + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + metrics = kwargs.get("metrics", {}) + if metrics: + logger.info(f"Evaluation metrics at step {state.global_step}:") + for key, value in metrics.items(): + logger.info(f" {key}: {value:.4f}" if isinstance(value, float) else f" {key}: {value}") + + # Log to W&B if available + if HAS_WANDB and wandb.run is not None: + wandb.log( + {f"eval/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))}, + step=state.global_step, + ) + + +class EarlyStoppingOnStyleDrift(TrainerCallback): + """Stops training if style similarity drops below threshold.""" + + def __init__(self, min_style_similarity: float = 0.75): + self.min_style_similarity = min_style_similarity + self.best_style_sim = 0.0 + self.patience_counter = 0 + self.patience = 3 # Stop after 3 consecutive low evaluations + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + metrics = kwargs.get("metrics", {}) + style_sim = metrics.get("eval_style_similarity", None) + + if style_sim is not None: + if style_sim > self.best_style_sim: + self.best_style_sim = style_sim + self.patience_counter = 0 + + if style_sim < self.min_style_similarity: + self.patience_counter += 1 + logger.warning( + f"Style similarity {style_sim:.4f} below threshold {self.min_style_similarity}. 
" + f"Patience: {self.patience_counter}/{self.patience}" + ) + if self.patience_counter >= self.patience: + logger.error( + f"Early stopping: style similarity consistently below {self.min_style_similarity}" + ) + control.should_training_stop = True + else: + self.patience_counter = 0 diff --git a/src/training/dataset.py b/src/training/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..74ab46d49bd83c4425366c3068171f1d3d21ccfa --- /dev/null +++ b/src/training/dataset.py @@ -0,0 +1,211 @@ +""" +Dataset class that handles all data sources and produces training triplets: +(input_text, style_vector, target_text) + +Data sources priority: +1. W&I+LOCNESS — real learner errors with expert corrections +2. JFLEG — naturalistic fluency corrections +3. GYAFC — informal→formal style transfer +4. Synthetic — dyslexia simulator augmentation on Wikipedia/books +5. Custom — any user-provided correction pairs + +OPTIMISATION: Everything is pre-computed at init and cached to disk: + - Tokenisation (input_ids, attention_mask, labels) + - Style vectors (spaCy + MLP) + - Disk cache at data/cache/.pt — skips re-computation on re-runs + __getitem__ is a pure dict return — zero computation per batch. +""" + +import json +import os +from pathlib import Path +from typing import List, Dict, Optional +import torch +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer +from ..style.fingerprinter import StyleFingerprinter +from ..preprocessing.dyslexia_simulator import DyslexiaSimulator +from loguru import logger +import random +import hashlib + + +TASK_PREFIX = ( + "Correct the following text for grammar, spelling, and clarity. " + "Maintain the author's original tone and writing style. " + "Elevate vocabulary to academic register. " + "Do NOT change the meaning or add new information. " + "Preserve named entities exactly. " + "Text to correct: " +) + +CACHE_DIR = Path("data/cache") + + +class WritingCorrectionDataset(Dataset): + """PyTorch dataset for writing correction training triplets. 
+ + Fully pre-computed at init with disk caching: + - First run: tokenises + extracts style vectors (~10 min), saves to disk + - Subsequent runs: loads from disk cache (~5 seconds) + - __getitem__ is a pure dict return (zero computation) + """ + + def __init__( + self, + data_path: str, + tokenizer: PreTrainedTokenizer, + fingerprinter: StyleFingerprinter, + max_input_length: int = 256, + max_target_length: int = 256, + augment_with_synthetic: bool = True, + synthetic_ratio: float = 0.3, + ): + self.tokenizer = tokenizer + self.fingerprinter = fingerprinter + self.max_input_length = max_input_length + self.max_target_length = max_target_length + + # Load data + self.examples = self._load(data_path) + logger.info(f"Loaded {len(self.examples)} examples from {data_path}") + + # Augment with synthetic dyslexia data + if augment_with_synthetic and self.examples: + self._add_synthetic(synthetic_ratio) + + logger.info(f"Total dataset size: {len(self.examples)} examples") + + # Compute cache key from data content + config + cache_key = self._compute_cache_key(data_path, augment_with_synthetic, synthetic_ratio) + cache_path = CACHE_DIR / f"{cache_key}.pt" + + # Try loading from disk cache + if cache_path.exists(): + logger.info(f"Loading pre-computed dataset from cache: {cache_path}") + self._precomputed = torch.load(cache_path, map_location="cpu", weights_only=False) + logger.info(f"Loaded {len(self._precomputed)} cached examples") + else: + # Pre-compute everything and save to disk + self._precomputed = self._precompute_all() + CACHE_DIR.mkdir(parents=True, exist_ok=True) + torch.save(self._precomputed, cache_path) + logger.info(f"Saved pre-computed dataset to cache: {cache_path}") + + def _compute_cache_key(self, data_path: str, augment: bool, ratio: float) -> str: + """Generate a cache key based on data file content and processing params.""" + h = hashlib.md5() + # Hash the data file content + try: + h.update(Path(data_path).read_bytes()) + except FileNotFoundError: + h.update(data_path.encode()) + # Hash processing parameters + h.update(f"aug={augment}|ratio={ratio}|maxin={self.max_input_length}|maxtgt={self.max_target_length}".encode()) + return h.hexdigest()[:16] + + def _load(self, path: str) -> List[Dict]: + """Load JSONL data file.""" + examples = [] + try: + with open(path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + if "input" in obj and "target" in obj: + examples.append(obj) + except json.JSONDecodeError: + continue + except FileNotFoundError: + logger.warning(f"Data file not found: {path}") + return examples + + def _add_synthetic(self, ratio: float): + """Augment dataset with synthetic dyslexia examples.""" + simulator = DyslexiaSimulator(error_rate=0.15, seed=42) + num_synthetic = int(len(self.examples) * ratio) + + # Sample target texts to corrupt + source_examples = random.Random(42).choices(self.examples, k=num_synthetic) + + synthetic_count = 0 + for example in source_examples: + target = example["target"] + corrupted, clean = simulator.simulate(target) + + # Only add if corruption actually changed the text + if corrupted != clean: + self.examples.append({ + "input": corrupted, + "target": clean, + "source": "synthetic", + }) + synthetic_count += 1 + + logger.info(f"Added {synthetic_count} synthetic augmentation examples") + + def _precompute_all(self) -> List[Dict[str, torch.Tensor]]: + """Pre-compute tokenisation + style vectors for ALL examples. + This makes __getitem__ a pure dict return with zero computation. 
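All of the cost lives in construction; a sketch with a hypothetical data path, assuming a tokenizer and fingerprinter are already loaded. A second run with identical data and parameters hits the disk cache instead of recomputing:

```python
ds = WritingCorrectionDataset(
    data_path="data/train.jsonl",        # hypothetical JSONL of {"input": ..., "target": ...}
    tokenizer=tokenizer,                 # assumed: a loaded PreTrainedTokenizer
    fingerprinter=StyleFingerprinter(),
)
sample = ds[0]  # pure dict: input_ids, attention_mask, labels, style_vector, ...
```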
+ """ + logger.info("Pre-computing tokenisation and style vectors for all examples...") + precomputed = [] + style_cache = {} # Deduplicate identical target texts + + for i, example in enumerate(self.examples): + input_text = TASK_PREFIX + example["input"] + target_text = example["target"] + + # Tokenise input + input_encoding = self.tokenizer( + input_text, + max_length=self.max_input_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + + # Tokenise target (labels) + target_encoding = self.tokenizer( + target_text, + max_length=self.max_target_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + + # Style vector (cached by content hash) + cache_key = hashlib.md5(target_text.encode()).hexdigest()[:16] + if cache_key not in style_cache: + with torch.no_grad(): + style_cache[cache_key] = self.fingerprinter.extract_vector(target_text) + style_vector = style_cache[cache_key] + + # Labels — set padding tokens to -100 so they're ignored in loss + labels = target_encoding["input_ids"].squeeze() + labels[labels == self.tokenizer.pad_token_id] = -100 + + precomputed.append({ + "input_ids": input_encoding["input_ids"].squeeze(), + "attention_mask": input_encoding["attention_mask"].squeeze(), + "labels": labels, + "style_vector": style_vector, + "input_text": example["input"], + "target_text": target_text, + }) + + if (i + 1) % 2000 == 0: + logger.info(f" Pre-computed: {i + 1}/{len(self.examples)}") + + logger.info(f"Pre-computation complete ({len(style_cache)} unique style vectors)") + return precomputed + + def __len__(self): + return len(self._precomputed) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Pure dict return — zero computation per batch.""" + return self._precomputed[idx] diff --git a/src/training/human_pattern_extractor.py b/src/training/human_pattern_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b8e3b630aecc2a0a048bced33f1677cb1a693a --- /dev/null +++ b/src/training/human_pattern_extractor.py @@ -0,0 +1,542 @@ +""" +Extracts the statistical signature of human writing vs AI writing. +Uses Kaggle datasets to build: + +1. HumanPatternProfile — a statistical distribution of human writing features +2. AIPatternProfile — a statistical distribution of AI writing features +3. HumanPatternClassifier — a lightweight FROZEN classifier used at training time + to score how "human-like" the model's output looks. + +The classifier is FROZEN during main model training. It is pre-trained separately +on the Kaggle datasets, then its output score is used as a reward/penalty signal +in the main training loss. 
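To make the reward/penalty idea concrete, a hedged sketch of how the frozen score could enter the main loss; `human_clf`, `output_features`, and `lambda_human` are hypothetical names, not defined in this diff:

```python
ce_loss = outputs.loss                         # standard seq2seq cross-entropy
with torch.no_grad():                          # the classifier stays frozen
    human_prob = human_clf(output_features)    # hypothetical: P(human-written) in [0, 1]
loss = ce_loss + lambda_human * (1.0 - human_prob.mean())  # penalise AI-looking output
```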
+ +Feature set extracted (17 dimensions): + - Perplexity under GPT-2 (AI text tends to be lower perplexity) + - Burstiness score (human writing has more sentence length variance) + - Sentence starter diversity + - n-gram novelty scores (bigram, trigram, 4-gram) + - AI marker density + - Overused discourse density + - Punctuation patterns (em-dash, ellipsis, comma, semicolon rates) + - Distributional features (word count, sentence count, mean/std sent length, TTR) +""" + +import pandas as pd +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from transformers import GPT2LMHeadModel, GPT2TokenizerFast +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from typing import List, Tuple, Dict, Optional +import spacy +from collections import Counter +import math +from loguru import logger +from concurrent.futures import ProcessPoolExecutor +import multiprocessing as mp + + +# ── AI-Typical Overused Discourse Markers ─────────────────────────────────── +AI_OVERUSED_MARKERS = { + "furthermore", "moreover", "additionally", "consequently", + "in conclusion", "to summarize", "it is worth noting", + "it is important to note", "in today's world", "in today's society", + "in the modern era", "as previously mentioned", "needless to say", + "it goes without saying", "at the end of the day", + "in terms of", "with regard to", "with respect to", + "delve", "leverage", "utilize", "holistic", "paradigm", + "transformative", "groundbreaking", "revolutionary", "game-changing", + "multifaceted", "nuanced", "comprehensive", "robust", "seamless", + "innovative", "synergy", "cutting-edge", "state-of-the-art", +} + +# Words that AI uses far MORE than humans in academic-adjacent writing +AI_FINGERPRINT_WORDS = { + "delve", "underscore", "tapestry", "intricate", "pivotal", + "crucial", "vital", "essential", "significant", "notable", + "commendable", "noteworthy", "straightforward", "straightforwardly", + "elucidate", "expound", "illuminate", "unravel", "harness", + "foster", "facilitate", "leverage", "optimize", "streamline", +} + + +# ── Standalone text-feature functions (picklable for multiprocessing) ─────── +def _compute_text_features(text: str) -> np.ndarray: + """Compute the 16 non-perplexity features from raw text. + Returns a 16-dim float32 array (features 2-17, perplexity slot excluded). + This function is designed to be called in a worker process. + """ + if not text or not text.strip(): + return np.zeros(16, dtype=np.float32) + + words = text.split() + word_count = max(len(words), 1) + + # Cheap sentence splitting (regex-based, avoids loading spaCy per worker) + import re + raw_sents = re.split(r'(?<=[.!?])\s+', text.strip()) + sentences = [s.strip() for s in raw_sents if s.strip()] + sent_lengths = [len(s.split()) for s in sentences] if sentences else [0] + + features = [] + + # 1. Burstiness + if len(sentences) < 2: + features.append(0.0) + else: + lengths = [len(s.split()) for s in sentences] + mean_len = np.mean(lengths) + features.append(float(np.std(lengths) / mean_len) if mean_len > 0 else 0.0) + + # 2. Sentence starter diversity + if not sentences: + features.append(0.0) + else: + starters = [] + for s in sentences: + w = s.strip().split() + if w: + starters.append(w[0].lower()) + features.append(len(set(starters)) / len(starters) if starters else 0.0) + + # 3-5. 
N-gram novelty (bigram, trigram, 4-gram) + words_lower = text.lower().split() + for n in (2, 3, 4): + if len(words_lower) < n: + features.append(1.0) + else: + ngrams = [tuple(words_lower[i:i + n]) for i in range(len(words_lower) - n + 1)] + features.append(len(set(ngrams)) / len(ngrams) if ngrams else 1.0) + + # 6. AI marker density + word_set = set(text.lower().split()) + ai_count = len(word_set & AI_FINGERPRINT_WORDS) + features.append((ai_count / word_count) * 100) + + # 7. Overused discourse density + text_lower = text.lower() + discourse_count = sum(1 for marker in AI_OVERUSED_MARKERS if marker in text_lower) + features.append((discourse_count / word_count) * 100) + + # 8-11. Punctuation patterns + features.append((text.count("—") + text.count("–")) / word_count * 100) # em-dash + features.append(text.count("...") / word_count * 100) # ellipsis + features.append(text.count(",") / word_count * 100) # comma + features.append(text.count(";") / word_count * 100) # semicolon + + # 12. Word count (log-scaled) + features.append(np.log1p(word_count)) + + # 13. Sentence count (log-scaled) + features.append(np.log1p(len(sentences))) + + # 14. Mean sentence length + features.append(np.mean(sent_lengths)) + + # 15. Std sentence length + features.append(np.std(sent_lengths) if len(sent_lengths) > 1 else 0.0) + + # 16. Type-token ratio + unique_words = set(w.lower() for w in words) + features.append(len(unique_words) / word_count) + + return np.array(features, dtype=np.float32) + + +class HumanPatternFeatureExtractor: + """Extracts 17-dimensional feature vector encoding human vs AI writing patterns. + + Optimised for bulk extraction: + - GPT-2 perplexity computed in batches on GPU (if available) + - Text features computed in parallel via multiprocessing + """ + + def __init__(self, spacy_model: str = "en_core_web_sm", device: Optional[str] = None): + # Determine device + if device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = device + + # GPT-2 for perplexity calculation + logger.info("Loading GPT-2 for perplexity calculation...") + self.gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2") + self.gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + self.gpt2_tokenizer.pad_token = self.gpt2_tokenizer.eos_token + self.gpt2_model.eval() + + # Move to best available device + self.gpt2_model = self.gpt2_model.to(self.device) + + # Use half precision on GPU for speed + if self.device == "cuda": + self.gpt2_model = self.gpt2_model.half() + logger.info(f"GPT-2 loaded on {self.device} with fp16") + else: + logger.info(f"GPT-2 loaded on {self.device}") + + logger.info("HumanPatternFeatureExtractor initialised") + + def _perplexity(self, text: str, max_len: int = 256) -> float: + """GPT-2 perplexity for a single text. Lower = more AI-like.""" + try: + encodings = self.gpt2_tokenizer( + text, return_tensors="pt", truncation=True, max_length=max_len + ) + input_ids = encodings["input_ids"].to(self.device) + + if input_ids.size(1) < 2: + return 100.0 # Default for very short text + + with torch.no_grad(): + outputs = self.gpt2_model(input_ids, labels=input_ids) + loss = outputs.loss + + return math.exp(min(loss.float().item(), 10)) # Cap to avoid inf + except Exception: + return 100.0 # Safe default + + def _perplexity_batch(self, texts: List[str], max_len: int = 256, batch_size: int = 8) -> List[float]: + """Compute GPT-2 perplexity for a batch of texts efficiently on GPU. + + Processes texts in mini-batches with padding for maximum throughput. 
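For orientation, the single-text path can be exercised directly (the sample sentence is illustrative; per the module docstring, lower perplexity is weak evidence of AI-generated text):

```python
extractor = HumanPatternFeatureExtractor(device="cpu")  # CPU is fine for a smoke test
ppl = extractor._perplexity("The meeting ran long, so we skipped lunch entirely.")
print(ppl)  # GPT-2 perplexity of the sample text
```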
+ Default batch_size=8 sized for GPUs with ~4GB VRAM (e.g. RTX 3050). + """ + results = [] + + for i in range(0, len(texts), batch_size): + batch_texts = texts[i:i + batch_size] + + # Tokenise with padding + encodings = self.gpt2_tokenizer( + batch_texts, + return_tensors="pt", + truncation=True, + max_length=max_len, + padding=True, + ) + + input_ids = encodings["input_ids"].to(self.device) + attention_mask = encodings["attention_mask"].to(self.device) + + with torch.no_grad(), torch.amp.autocast(device_type=self.device if self.device != "cpu" else "cpu"): + # Forward pass for the whole batch + outputs = self.gpt2_model( + input_ids, + attention_mask=attention_mask, + ) + logits = outputs.logits + + # Compute per-sample perplexity from logits + # Shift logits and labels for causal LM loss + shift_logits = logits[:, :-1, :].contiguous() + shift_labels = input_ids[:, 1:].contiguous() + shift_mask = attention_mask[:, 1:].contiguous() + + # Per-token cross entropy (no reduction) + loss_fct = nn.CrossEntropyLoss(reduction="none") + # Reshape for loss computation + per_token_loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ).view(shift_labels.size()) + + # Mask out padding tokens and compute mean per sample + masked_loss = per_token_loss * shift_mask.float() + token_counts = shift_mask.float().sum(dim=1).clamp(min=1) + per_sample_loss = masked_loss.sum(dim=1) / token_counts + + # Convert to perplexity + for loss_val in per_sample_loss: + ppl = math.exp(min(loss_val.float().item(), 10)) + results.append(ppl) + + # Free GPU memory between batches (critical for low-VRAM GPUs) + del input_ids, attention_mask, outputs, logits, shift_logits, shift_labels + del shift_mask, per_token_loss, masked_loss, token_counts, per_sample_loss + if self.device == "cuda": + torch.cuda.empty_cache() + + return results + + def extract(self, text: str) -> np.ndarray: + """Extract full 17-dimensional feature vector for a single text.""" + if not text or not text.strip(): + return np.zeros(17, dtype=np.float32) + + # Perplexity (feature 1) + ppl = self._perplexity(text) + + # All other features (features 2-17) + text_features = _compute_text_features(text) + + # Combine: [perplexity, ...16 text features] + features = np.empty(17, dtype=np.float32) + features[0] = ppl + features[1:] = text_features + + return features + + def extract_batch( + self, + texts: List[str], + batch_size: Optional[int] = None, + num_workers: int = 0, + progress_every: int = 1000, + ) -> np.ndarray: + """Extract features for many texts efficiently. + + Strategy: + 1. Compute perplexity in batched GPU forward passes + 2. Compute text features in parallel via multiprocessing + 3. Merge into (N, 17) array + + Args: + texts: List of text strings + batch_size: Batch size for GPT-2 perplexity (default 8 for ~4GB VRAM GPUs) + num_workers: Number of processes for text features. 0 = auto-detect. 
+            progress_every: Log progress every N texts
+
+        Returns:
+            np.ndarray of shape (len(texts), 17)
+        """
+        n = len(texts)
+        if batch_size is None:
+            # Auto-size from VRAM: 8 for 4GB, 16 for 8GB, 32 for 16GB+
+            if self.device == "cuda":
+                vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+                batch_size = max(4, min(32, int(vram_gb * 2)))
+            else:
+                batch_size = 4
+        logger.info(f"Extracting features for {n} texts (device={self.device}, batch_size={batch_size})")
+
+        # ── Step 1: Batched perplexity on GPU ──────────────────────────────
+        logger.info("  Computing batched GPT-2 perplexity...")
+        all_ppl = []
+        for start in range(0, n, batch_size):
+            end = min(start + batch_size, n)
+            batch = texts[start:end]
+            ppl_batch = self._perplexity_batch(batch, batch_size=len(batch))
+            all_ppl.extend(ppl_batch)
+
+            if (start // batch_size) % max(1, (progress_every // batch_size)) == 0 and start > 0:
+                logger.info(f"    Perplexity: {start}/{n}")
+
+        logger.info(f"    Perplexity complete: {n}/{n}")
+
+        # ── Step 2: Text features in parallel ──────────────────────────────
+        logger.info("  Computing text features (parallel)...")
+        if num_workers == 0:
+            num_workers = min(mp.cpu_count(), 8)
+
+        # For small datasets (or a single worker) the process-pool start-up
+        # cost outweighs the gain, so fall back to serial extraction
+        if n < 500 or num_workers <= 1:
+            text_features_list = []
+            for i, text in enumerate(texts):
+                text_features_list.append(_compute_text_features(text))
+                if i > 0 and i % progress_every == 0:
+                    logger.info(f"    Text features: {i}/{n}")
+        else:
+            # Use ProcessPoolExecutor for CPU-bound text feature extraction
+            text_features_list = []
+            with ProcessPoolExecutor(max_workers=num_workers) as executor:
+                # Submit in chunks for better progress tracking
+                chunk_size = 2000
+                for chunk_start in range(0, n, chunk_size):
+                    chunk_end = min(chunk_start + chunk_size, n)
+                    chunk = texts[chunk_start:chunk_end]
+                    chunk_results = list(executor.map(_compute_text_features, chunk, chunksize=200))
+                    text_features_list.extend(chunk_results)
+                    if chunk_start > 0:
+                        logger.info(f"    Text features: {chunk_start}/{n}")
+
+        logger.info(f"    Text features complete: {n}/{n}")
+
+        # ── Step 3: Merge ──────────────────────────────────────────────────
+        features = np.empty((n, 17), dtype=np.float32)
+        features[:, 0] = np.array(all_ppl, dtype=np.float32)
+        features[:, 1:] = np.array(text_features_list, dtype=np.float32)
+
+        return features
+
+
+class KaggleHumanPatternDataset(Dataset):
+    """
+    Loads both Kaggle datasets and produces (feature_vector, label) pairs.
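+    Text and label columns are auto-detected. Both source files are assumed
+    to follow the convention 0 = human, 1 = AI; labels are flipped on load
+    so that here: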
+ label = 1 (human) | 0 (AI) + """ + + def __init__( + self, + shanegerami_path: str, + starblasters_path: str, + extractor: HumanPatternFeatureExtractor, + max_samples_per_source: int = 50000, + ): + self.extractor = extractor + self.texts = [] + self.labels = [] + + # Load Shanegerami AI_Human.csv + logger.info(f"Loading Shanegerami dataset from {shanegerami_path}...") + try: + df_shane = pd.read_csv(shanegerami_path, nrows=max_samples_per_source * 2) + # Auto-detect column names + text_col = None + label_col = None + for col in df_shane.columns: + col_lower = col.lower() + if col_lower in ("text", "essay_text", "content", "essay"): + text_col = col + elif col_lower in ("generated", "label", "is_ai", "ai_generated", "class"): + label_col = col + + if text_col is None: + text_col = df_shane.columns[0] + logger.warning(f"Auto-detected text column: {text_col}") + if label_col is None: + label_col = df_shane.columns[-1] + logger.warning(f"Auto-detected label column: {label_col}") + + # Sample balanced dataset + human_mask = df_shane[label_col] == 0 + ai_mask = df_shane[label_col] == 1 + + human_texts = df_shane.loc[human_mask, text_col].dropna().head(max_samples_per_source).tolist() + ai_texts = df_shane.loc[ai_mask, text_col].dropna().head(max_samples_per_source).tolist() + + self.texts.extend(human_texts) + self.labels.extend([1] * len(human_texts)) # 1 = human + self.texts.extend(ai_texts) + self.labels.extend([0] * len(ai_texts)) # 0 = AI + + logger.info(f"Shanegerami: {len(human_texts)} human + {len(ai_texts)} AI samples") + except Exception as e: + logger.warning(f"Failed to load Shanegerami dataset: {e}") + + # Load Starblasters8 data.parquet + logger.info(f"Loading Starblasters8 dataset from {starblasters_path}...") + try: + df_star = pd.read_parquet(starblasters_path) + + # Auto-detect columns + text_col = None + label_col = None + for col in df_star.columns: + col_lower = col.lower() + if col_lower in ("text", "essay_text", "content", "essay"): + text_col = col + elif col_lower in ("generated", "label", "is_ai", "ai_generated", "source"): + label_col = col + + if text_col is None: + text_col = df_star.columns[0] + if label_col is None: + label_col = df_star.columns[-1] + + human_mask = df_star[label_col] == 0 + ai_mask = df_star[label_col] == 1 + + human_texts = df_star.loc[human_mask, text_col].dropna().head(max_samples_per_source).tolist() + ai_texts = df_star.loc[ai_mask, text_col].dropna().head(max_samples_per_source).tolist() + + self.texts.extend(human_texts) + self.labels.extend([1] * len(human_texts)) + self.texts.extend(ai_texts) + self.labels.extend([0] * len(ai_texts)) + + logger.info(f"Starblasters8: {len(human_texts)} human + {len(ai_texts)} AI samples") + except Exception as e: + logger.warning(f"Failed to load Starblasters8 dataset: {e}") + + logger.info(f"Total dataset size: {len(self.texts)} samples") + + # Pre-extract features for all texts (cached for training speed) + self._features = None + self._precomputed = False + + def precompute_features(self): + """Pre-compute all features using optimised batched extraction.""" + if self._precomputed: + return + + logger.info("Pre-computing features for all texts...") + + # Truncate very long texts for speed + truncated_texts = [ + str(text)[:2000] if len(str(text)) > 2000 else str(text) + for text in self.texts + ] + + # Use the fast batched extraction path + features_array = self.extractor.extract_batch( + truncated_texts, + batch_size=None, # Auto-detect based on VRAM + num_workers=0, # Auto-detect CPU count + 
progress_every=2000, + ) + + # Store as list of arrays for compatibility with __getitem__ + self._features = [features_array[i] for i in range(len(features_array))] + self._precomputed = True + logger.info("Feature pre-computation complete") + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: + if self._precomputed and self._features is not None: + features = self._features[idx] + else: + text = str(self.texts[idx])[:2000] + features = self.extractor.extract(text) + + features_tensor = torch.tensor(features, dtype=torch.float32) + + # Handle NaN/Inf values that can occur from edge cases + features_tensor = torch.nan_to_num(features_tensor, nan=0.0, posinf=10.0, neginf=-10.0) + + return features_tensor, self.labels[idx] + + +class HumanPatternClassifier(nn.Module): + """ + Lightweight MLP trained to distinguish human from AI writing. + Input: feature vector from HumanPatternFeatureExtractor + Output: probability that text is human-written (0 to 1) + + PRE-TRAINED on Kaggle datasets, then FROZEN during main training. + """ + + def __init__(self, input_dim: int = 17, hidden_dim: int = 128): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(hidden_dim, hidden_dim // 2), + nn.BatchNorm1d(hidden_dim // 2), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(hidden_dim // 2, 1), + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Returns human-likeness score in [0, 1]. Higher = more human.""" + logits = self.net(features) + return torch.sigmoid(logits).squeeze(-1) + + def score(self, text: str, extractor: HumanPatternFeatureExtractor) -> float: + """Convenience: score a single text string.""" + self.eval() + features = extractor.extract(text) + features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0) + features_tensor = torch.nan_to_num(features_tensor, nan=0.0, posinf=10.0, neginf=-10.0) + with torch.no_grad(): + score = self.forward(features_tensor) + return score.item() diff --git a/src/training/loss_functions.py b/src/training/loss_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..0f8a39db514aeb0fdaeefb186b83fb742da6464a --- /dev/null +++ b/src/training/loss_functions.py @@ -0,0 +1,230 @@ +""" +Combined training loss with Human-Pattern Term: + +L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic + λ₃ · L_human_pattern + +Where: + L_CE = cross-entropy language model loss (standard token prediction) + L_style = style consistency loss (cosine distance between output and target style vectors) + L_semantic = semantic similarity loss (cosine distance between sentence embeddings) + L_human_pattern = 1 - HumanPatternClassifier.score(output_text) + λ₁ = style loss weight (default 0.3) + λ₂ = semantic loss weight (default 0.5) + λ₃ = human pattern weight (default 0.4) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sentence_transformers import SentenceTransformer +from typing import Optional, List, Dict +from loguru import logger + + +class CombinedCorrectionLoss(nn.Module): + """V1 combined loss: L_CE + λ₁·L_style + λ₂·L_semantic.""" + + def __init__( + self, + lambda_style: float = 0.3, + lambda_semantic: float = 0.5, + sem_model_name: str = "all-mpnet-base-v2", + device: str = "cpu", + ): + super().__init__() + self.lambda_style = lambda_style + self.lambda_semantic = lambda_semantic + self.device = device + + # Cross-entropy loss + 
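+        # (label positions set to -100, the HF padding convention, are ignored)
+        # Illustration of the combination computed in forward(), with the
+        # default weights lambda_style=0.3, lambda_semantic=0.5 and, say,
+        # component losses L_CE=2.10, L_style=0.40, L_semantic=0.10:
+        #     total = 2.10 + 0.3*0.40 + 0.5*0.10 = 2.27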
self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100) + + # Frozen sentence transformer for semantic similarity + logger.info(f"Loading sentence transformer for loss: {sem_model_name}") + self.sem_model = SentenceTransformer(sem_model_name, device=device) + self.sem_model.eval() + # Freeze sentence transformer weights + for param in self.sem_model.parameters(): + param.requires_grad = False + + def _style_loss( + self, + output_style_vec: torch.Tensor, + target_style_vec: torch.Tensor, + ) -> torch.Tensor: + """1 - cosine_similarity(output_style, target_style).""" + if output_style_vec.dim() == 1: + output_style_vec = output_style_vec.unsqueeze(0) + if target_style_vec.dim() == 1: + target_style_vec = target_style_vec.unsqueeze(0) + cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1) + return (1.0 - cos_sim).mean() + + def _semantic_loss( + self, + input_texts: List[str], + output_texts: List[str], + ) -> torch.Tensor: + """Penalises meaning change between input and output.""" + with torch.no_grad(): + input_embeddings = self.sem_model.encode(input_texts, convert_to_tensor=True) + output_embeddings = self.sem_model.encode(output_texts, convert_to_tensor=True) + + cos_sim = F.cosine_similarity(input_embeddings, output_embeddings, dim=-1) + # Loss = 1 - similarity (we want high similarity = low loss) + return (1.0 - cos_sim).mean() + + def forward( + self, + logits: torch.Tensor, + labels: torch.Tensor, + output_style_vec: Optional[torch.Tensor] = None, + target_style_vec: Optional[torch.Tensor] = None, + input_texts: Optional[List[str]] = None, + output_texts: Optional[List[str]] = None, + ) -> Dict[str, torch.Tensor]: + """Compute combined loss.""" + losses = {} + + # L_CE: cross-entropy + # logits: [batch, seq_len, vocab_size] + # labels: [batch, seq_len] + if logits.dim() == 3: + ce_logits = logits.view(-1, logits.size(-1)) + ce_labels = labels.view(-1) + else: + ce_logits = logits + ce_labels = labels + l_ce = self.ce_loss(ce_logits, ce_labels) + losses["ce_loss"] = l_ce + + total = l_ce + + # L_style + if output_style_vec is not None and target_style_vec is not None: + l_style = self._style_loss(output_style_vec, target_style_vec) + losses["style_loss"] = l_style + total = total + self.lambda_style * l_style + + # L_semantic + if input_texts is not None and output_texts is not None: + l_semantic = self._semantic_loss(input_texts, output_texts) + losses["semantic_loss"] = l_semantic + total = total + self.lambda_semantic * l_semantic + + losses["total_loss"] = total + return losses + + +class CombinedCorrectionLossV2(nn.Module): + """V2 combined loss with human-pattern term: L_CE + λ₁·L_style + λ₂·L_semantic + λ₃·L_human_pattern.""" + + def __init__( + self, + lambda_style: float = 0.3, + lambda_semantic: float = 0.5, + lambda_human_pattern: float = 0.4, + classifier_path: str = "checkpoints/human_pattern_classifier.pt", + sem_model_name: str = "all-mpnet-base-v2", + device: str = "cpu", + ): + super().__init__() + self.lambda_style = lambda_style + self.lambda_semantic = lambda_semantic + self.lambda_human_pattern = lambda_human_pattern + self.device = device + + # V1 components + self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100) + + # Sentence transformer on CPU to save GPU VRAM for main model + logger.info(f"Loading sentence transformer for loss: {sem_model_name} (on CPU)") + self.sem_model = SentenceTransformer(sem_model_name, device="cpu") + self.sem_model.eval() + + # Load frozen human pattern classifier + from .human_pattern_extractor import 
HumanPatternClassifier, HumanPatternFeatureExtractor + self.hp_classifier = HumanPatternClassifier() + try: + state_dict = torch.load(classifier_path, map_location=device, weights_only=True) + self.hp_classifier.load_state_dict(state_dict) + logger.info(f"Loaded human pattern classifier from {classifier_path}") + except FileNotFoundError: + logger.warning(f"Human pattern classifier not found at {classifier_path}, using random weights") + + self.hp_classifier.eval() + for param in self.hp_classifier.parameters(): + param.requires_grad = False + + # Feature extractor on CPU to save GPU VRAM for main model + self.hp_extractor = HumanPatternFeatureExtractor(device="cpu") + + def _human_pattern_loss(self, output_texts: List[str], compute_device: torch.device = None) -> torch.Tensor: + """Loss = 1 - human_score. Penalise AI-like outputs.""" + scores = [] + for text in output_texts: + score = self.hp_classifier.score(text, self.hp_extractor) + scores.append(score) + device = compute_device if compute_device is not None else self.device + human_scores = torch.tensor(scores, dtype=torch.float32, device=device) + return (1.0 - human_scores).mean() + + def forward( + self, + logits: torch.Tensor, + labels: torch.Tensor, + output_style_vec: Optional[torch.Tensor] = None, + target_style_vec: Optional[torch.Tensor] = None, + input_texts: Optional[List[str]] = None, + output_texts: Optional[List[str]] = None, + ) -> Dict[str, torch.Tensor]: + """Compute combined loss with human pattern term.""" + losses = {} + + # L_CE + if logits.dim() == 3: + ce_logits = logits.view(-1, logits.size(-1)) + ce_labels = labels.view(-1) + else: + ce_logits = logits + ce_labels = labels + l_ce = self.ce_loss(ce_logits, ce_labels) + losses["ce_loss"] = l_ce + total = l_ce + + # L_style + if output_style_vec is not None and target_style_vec is not None: + # Ensure both vectors are on the same device (style vecs may come from CPU fingerprinter) + compute_device = logits.device + output_style_vec = output_style_vec.to(compute_device) + target_style_vec = target_style_vec.to(compute_device) + if output_style_vec.dim() == 1: + output_style_vec = output_style_vec.unsqueeze(0) + if target_style_vec.dim() == 1: + target_style_vec = target_style_vec.unsqueeze(0) + cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1) + l_style = (1.0 - cos_sim).mean() + losses["style_loss"] = l_style + total = total + self.lambda_style * l_style + + # L_semantic + if input_texts is not None and output_texts is not None: + with torch.no_grad(): + input_emb = self.sem_model.encode(input_texts, convert_to_tensor=True) + output_emb = self.sem_model.encode(output_texts, convert_to_tensor=True) + # sem_model is on CPU, move embeddings to compute device + input_emb = input_emb.to(logits.device) + output_emb = output_emb.to(logits.device) + cos_sim = F.cosine_similarity(input_emb, output_emb, dim=-1) + l_semantic = (1.0 - cos_sim).mean() + losses["semantic_loss"] = l_semantic + total = total + self.lambda_semantic * l_semantic + + # L_human_pattern + if output_texts is not None: + l_human = self._human_pattern_loss(output_texts, compute_device=logits.device) + losses["human_pattern_loss"] = l_human + total = total + self.lambda_human_pattern * l_human + + losses["total_loss"] = total + return losses diff --git a/src/training/trainer.py b/src/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c80ee2510fb8d9dc084a9e6edf1e41ac0309bc --- /dev/null +++ b/src/training/trainer.py @@ -0,0 +1,53 @@ +""" 
+Custom HuggingFace Trainer subclass. +Uses the model's built-in cross-entropy loss (computed during forward pass) +instead of recomputing it, saving ~60MB of VRAM. +""" + +from transformers import Trainer +import torch +from loguru import logger + + +class CorrectionTrainer(Trainer): + """Custom trainer — uses model's built-in loss directly.""" + + def __init__(self, loss_fn, fingerprinter, tokenizer, **kwargs): + super().__init__(**kwargs) + self.loss_fn = loss_fn # Kept for API compat, not actually used + self.fingerprinter = fingerprinter + self.correction_tokenizer = tokenizer + + def _strip_custom_fields(self, inputs): + """Remove dataset fields that T5 doesn't accept.""" + inputs.pop("style_vector", None) + inputs.pop("input_text", None) + inputs.pop("target_text", None) + return {k: v for k, v in inputs.items() if k in ("input_ids", "attention_mask", "labels")} + + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + """Use model's built-in CE loss — avoids double-computing logits loss.""" + model_inputs = self._strip_custom_fields(inputs) + + outputs = model(**model_inputs) + # T5 computes CE loss internally when labels are provided — use it directly + # This avoids keeping the full logits tensor (batch × seq × 32128) alive + loss = outputs.loss + + return (loss, outputs) if return_outputs else loss + + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + """Compute eval loss directly — strips custom fields and runs forward. + + The parent's prediction_step doesn't return eval_loss when custom + fields are present, so we handle it ourselves. + """ + model_inputs = self._strip_custom_fields(inputs) + model_inputs = self._prepare_inputs(model_inputs) + + with torch.no_grad(): + outputs = model(**model_inputs) + loss = outputs.loss.detach() + + return (loss, None, None) + diff --git a/src/vocabulary/__init__.py b/src/vocabulary/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/vocabulary/__pycache__/__init__.cpython-314.pyc b/src/vocabulary/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eee974ee3db7491bca1c63611049f7792ec15c3 Binary files /dev/null and b/src/vocabulary/__pycache__/__init__.cpython-314.pyc differ diff --git a/src/vocabulary/__pycache__/awl_loader.cpython-312.pyc b/src/vocabulary/__pycache__/awl_loader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8890e3c1e94d898087ae9691d69f16ae7e7da857 Binary files /dev/null and b/src/vocabulary/__pycache__/awl_loader.cpython-312.pyc differ diff --git a/src/vocabulary/__pycache__/awl_loader.cpython-314.pyc b/src/vocabulary/__pycache__/awl_loader.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b73eac6f3c9de47d6c2c35f7e3af4352d93feb9 Binary files /dev/null and b/src/vocabulary/__pycache__/awl_loader.cpython-314.pyc differ diff --git a/src/vocabulary/__pycache__/lexical_substitution.cpython-312.pyc b/src/vocabulary/__pycache__/lexical_substitution.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..059304183da3f40e841f594b4b1157c047e58cf4 Binary files /dev/null and b/src/vocabulary/__pycache__/lexical_substitution.cpython-312.pyc differ diff --git a/src/vocabulary/__pycache__/lexical_substitution.cpython-314.pyc b/src/vocabulary/__pycache__/lexical_substitution.cpython-314.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9cabe52243530aec5240d905ccbd485fc37c8c74 Binary files /dev/null and b/src/vocabulary/__pycache__/lexical_substitution.cpython-314.pyc differ diff --git a/src/vocabulary/__pycache__/register_filter.cpython-312.pyc b/src/vocabulary/__pycache__/register_filter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46c90e4148e5deebb0d06d8fbfc72922d8684ba1 Binary files /dev/null and b/src/vocabulary/__pycache__/register_filter.cpython-312.pyc differ diff --git a/src/vocabulary/awl_loader.py b/src/vocabulary/awl_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..41bb8695c83746992f448cbcacd2641305994a95 --- /dev/null +++ b/src/vocabulary/awl_loader.py @@ -0,0 +1,77 @@ +""" +Academic Word List (AWL) loader. +Loads the Coxhead AWL and supplementary domain-specific lexicons. +Provides lookup methods for checking if a word is in the academic register. +""" + +from pathlib import Path +from typing import Set, List, Optional +import json +from loguru import logger + + +class AWLLoader: + """Loads and manages Academic Word List data.""" + + def __init__( + self, + primary_path: str = "data/awl/coxhead_awl.txt", + supplementary_paths: Optional[List[str]] = None, + synonyms_path: Optional[str] = "data/awl/academic_synonyms.json", + ): + # Load primary AWL + self._words: Set[str] = self._load_word_list(primary_path) + logger.info(f"Loaded {len(self._words)} primary AWL words from {primary_path}") + + # Load supplementary domain lexicons + if supplementary_paths: + for path in supplementary_paths: + supplement = self._load_word_list(path) + self._words |= supplement + logger.info(f"Added {len(supplement)} words from {path}") + + # Load synonym mappings + self._synonyms: dict = {} + if synonyms_path: + self._synonyms = self._load_synonyms(synonyms_path) + logger.info(f"Loaded {len(self._synonyms)} synonym mappings") + + def _load_word_list(self, path: str) -> Set[str]: + """Load a word list file into a set of lowercase words.""" + words = set() + try: + with open(path) as f: + for line in f: + word = line.strip().lower() + if word and not word.startswith("#"): + words.add(word) + except FileNotFoundError: + logger.warning(f"Word list file not found: {path}") + return words + + def _load_synonyms(self, path: str) -> dict: + """Load academic synonym mappings from JSON.""" + try: + with open(path) as f: + data = json.load(f) + # Normalise keys to lowercase + return {k.lower(): [s.lower() for s in v] for k, v in data.items()} + except FileNotFoundError: + logger.warning(f"Synonyms file not found: {path}") + return {} + except json.JSONDecodeError: + logger.warning(f"Invalid JSON in synonyms file: {path}") + return {} + + def is_academic(self, word: str) -> bool: + """Check if a word (or its lemma) is in the AWL.""" + return word.lower().strip() in self._words + + def get_academic_synonyms(self, word: str) -> List[str]: + """Return academic synonyms for a colloquial word.""" + return self._synonyms.get(word.lower().strip(), []) + + @property + def all_words(self) -> Set[str]: + """Return the full set of academic words.""" + return self._words.copy() diff --git a/src/vocabulary/lexical_substitution.py b/src/vocabulary/lexical_substitution.py new file mode 100644 index 0000000000000000000000000000000000000000..6c1cfca56e49ab36a5c1e7e6d8780b335960c7bc --- /dev/null +++ b/src/vocabulary/lexical_substitution.py @@ -0,0 +1,246 @@ +""" +Post-generation academic vocabulary elevation module. + +Pipeline: +1. 
POS-tag the generated output
+2. Identify content words (NOUN, VERB, ADJ, ADV) NOT in AWL
+3. For each candidate word, generate AWL-aligned substitutions
+   using BERT masked language model (fill-mask)
+4. Apply substitution only if:
+   a. Semantic similarity between original and substitution > threshold
+   b. Substitution is in the AWL
+   c. Substitution does not change sentence meaning
+5. Apply register-level post-processing (nominalisation, hedging, passive)
+
+AWL = Coxhead Academic Word List (570 word families, ~3,000 lemmas)
+"""
+
+import spacy
+import torch
+from transformers import pipeline as hf_pipeline
+from sentence_transformers import SentenceTransformer
+import torch.nn.functional as F
+from typing import List, Dict, Tuple, Optional
+from .awl_loader import AWLLoader
+from loguru import logger
+import re
+
+
+class LexicalElevator:
+    """Elevates vocabulary to academic register using BERT-based substitution."""
+
+    # Words that should NEVER be substituted (structural, functional words)
+    PROTECTED_POS = {"PRON", "DET", "CCONJ", "SCONJ", "ADP", "AUX", "PART", "PUNCT", "NUM"}
+    SEMANTIC_THRESHOLD = 0.82  # Minimum cosine similarity to accept substitution
+
+    def __init__(
+        self,
+        awl_path: str = "data/awl/coxhead_awl.txt",
+        spacy_model: str = "en_core_web_trf",
+        mlm_model: str = "bert-large-uncased",
+        sem_model: str = "all-mpnet-base-v2",
+    ):
+        # Load spaCy
+        try:
+            self.nlp = spacy.load(spacy_model)
+        except OSError:
+            logger.warning(f"spaCy model '{spacy_model}' not found, falling back to 'en_core_web_sm'")
+            self.nlp = spacy.load("en_core_web_sm")
+
+        # Load AWL
+        self.awl = AWLLoader(primary_path=awl_path)
+
+        # Load BERT fill-mask pipeline (CPU-optimised)
+        logger.info(f"Loading MLM model: {mlm_model}")
+        self.fill_mask = hf_pipeline(
+            "fill-mask",
+            model=mlm_model,
+            device=-1,  # Force CPU
+            top_k=15,
+        )
+
+        # Load sentence transformer for semantic similarity
+        logger.info(f"Loading sentence transformer: {sem_model}")
+        self.sem_model = SentenceTransformer(sem_model, device="cpu")
+        logger.info("LexicalElevator initialised")
+
+    def _sem_similarity(self, word_a: str, word_b: str, context: str) -> float:
+        """Compute contextual semantic similarity using sentence embeddings."""
+        # Compare the original sentence against the sentence with word_a
+        # swapped for word_b
+        sent_a = context
+        sent_b = context.replace(word_a, word_b, 1)
+
+        embeddings = self.sem_model.encode([sent_a, sent_b], convert_to_tensor=True)
+        sim = F.cosine_similarity(embeddings[0].unsqueeze(0), embeddings[1].unsqueeze(0))
+        return sim.item()
+
+    def _get_awl_substitutions(self, sentence: str, word: str, pos: str) -> List[str]:
+        """Generate candidate AWL substitutions using BERT fill-mask."""
+        # Replace the word with the [MASK] token
+        masked = sentence.replace(word, self.fill_mask.tokenizer.mask_token, 1)
+
+        try:
+            predictions = self.fill_mask(masked)
+        except Exception as e:
+            logger.debug(f"Fill-mask failed for '{word}': {e}")
+            return []
+
+        # Filter to AWL words only
+        candidates = []
+        for pred in predictions:
+            token = pred["token_str"].strip().lower()
+            if (self.awl.is_academic(token) and
+                    token != word.lower() and
+                    len(token) > 2 and
+                    pred["score"] > 0.01):
+                candidates.append(token)
+
+        return candidates
+
+    # High-register words that should NOT be substituted even if not in Coxhead AWL
+    ALREADY_ACADEMIC = {
+        "resilient", "adaptive", "indifferent", "staggering", "improbably",
+        "simultaneously", "harboring", "hurtling", "unfold", "cosmos",
+        "catastrophic", "ubiquitous", "paradox", "nuanced", "inherent",
"exacerbate", "paradigm", "juxtapose", "dichotomy", "efficacy", + } + + def elevate(self, text: str, protected_spans: List[Tuple[int, int]] = None) -> str: + """Main entry point: elevates vocabulary to academic register.""" + if not text or not text.strip(): + return text + + if protected_spans is None: + protected_spans = [] + + doc = self.nlp(text) + result = text + used_substitutions = set() # Prevent the same word being used as replacement twice + + # Process each sentence independently + for sent in doc.sents: + for token in sent: + # Skip protected POS + if token.pos_ in self.PROTECTED_POS: + continue + + # Skip short words and stop words + if len(token.text) < 4 or token.is_stop: + continue + + # Skip words already in AWL + if self.awl.is_academic(token.text): + continue + + # Skip words that are already academic-register + if token.text.lower() in self.ALREADY_ACADEMIC: + continue + + # Skip if in protected span + if any(start <= token.idx < end for start, end in protected_spans): + continue + + # Only elevate content words + if token.pos_ not in ("NOUN", "VERB", "ADJ", "ADV"): + continue + + # Get AWL candidates + candidates = self._get_awl_substitutions(sent.text, token.text, token.pos_) + + # Find best candidate above similarity threshold + best_candidate = None + best_sim = self.SEMANTIC_THRESHOLD + + for candidate in candidates: + # Skip if this substitution word was already used + if candidate.lower() in used_substitutions: + continue + sim = self._sem_similarity(token.text, candidate, sent.text) + if sim > best_sim: + best_sim = sim + best_candidate = candidate + + if best_candidate: + # Preserve capitalisation + if token.text[0].isupper(): + best_candidate = best_candidate.capitalize() + + # Track this substitution to prevent duplicates + used_substitutions.add(best_candidate.lower()) + + # Replace in result (first occurrence in this context) + result = result.replace(token.text, best_candidate, 1) + logger.debug(f"Elevated: '{token.text}' → '{best_candidate}' (sim={best_sim:.3f})") + + return result + + +class RegisterFilter: + """ + Applies register-level corrections to ensure academic tone: + - Converts contractions to full forms + - Ensures hedging where appropriate + - Flags over-colloquial phrases for review + """ + + CONTRACTIONS = { + "don't": "do not", "can't": "cannot", "won't": "will not", + "it's": "it is", "that's": "that is", "there's": "there is", + "they're": "they are", "we're": "we are", "you're": "you are", + "I'm": "I am", "I've": "I have", "I'll": "I will", + "isn't": "is not", "aren't": "are not", "wasn't": "was not", + "weren't": "were not", "hasn't": "has not", "haven't": "have not", + "couldn't": "could not", "wouldn't": "would not", "shouldn't": "should not", + } + + COLLOQUIAL_TO_ACADEMIC = { + "a lot of": "a substantial number of", + "lots of": "numerous", + "big": "substantial", + "get": "obtain", + "show": "demonstrate", + "use": "utilise", + "find out": "ascertain", + "look at": "examine", + "think about": "consider", + "talk about": "discuss", + "deal with": "address", + "carry out": "conduct", + "point out": "indicate", + "make sure": "ensure", + "come up with": "develop", + "go up": "increase", + "go down": "decrease", + "start": "commence", + "end": "conclude", + "help": "facilitate", + "need": "require", + "try": "attempt", + "want": "seek", + } + + def apply(self, text: str) -> str: + """Apply contraction expansion and colloquial-to-academic substitution.""" + if not text or not text.strip(): + return text + + result = text + + # Step 1: 
Expand contractions (case-insensitively, preserving sentence-initial capitals)
+        for contraction, expansion in self.CONTRACTIONS.items():
+            pattern = re.compile(re.escape(contraction), re.IGNORECASE)
+            # A bare pattern.sub(expansion, ...) would lowercase matches at
+            # sentence starts ("Don't" -> "do not"), so match the case instead
+            result = pattern.sub(
+                lambda m, e=expansion: self._match_case(e, m.group(0)), result
+            )
+
+        # Step 2: Replace colloquial phrases (longer phrases first to avoid partial matches)
+        sorted_colloquials = sorted(
+            self.COLLOQUIAL_TO_ACADEMIC.items(),
+            key=lambda x: len(x[0]),
+            reverse=True,
+        )
+        for colloquial, academic in sorted_colloquials:
+            # Word-boundary matching to avoid replacing within words
+            pattern = re.compile(r'\b' + re.escape(colloquial) + r'\b', re.IGNORECASE)
+            result = pattern.sub(
+                lambda m, a=academic: self._match_case(a, m.group(0)), result
+            )
+
+        return result
+
+    @staticmethod
+    def _match_case(replacement: str, matched: str) -> str:
+        """Capitalise the replacement when the matched text was capitalised."""
+        if matched[:1].isupper():
+            return replacement[:1].upper() + replacement[1:]
+        return replacement
 diff --git a/src/vocabulary/register_filter.py b/src/vocabulary/register_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..d231b38e9a7fe56477a250b0d404d5cf6c974688 --- /dev/null +++ b/src/vocabulary/register_filter.py @@ -0,0 +1,100 @@
+"""
+Register filter module.
+Applies register-level post-processing to ensure output text
+meets academic formality requirements.
+"""
+
+import re
+from typing import List
+from loguru import logger
+
+
+class RegisterFilterAdvanced:
+    """Advanced register filtering with nominalisation and hedging passes."""
+
+    # Verb → Nominal form mappings
+    VERB_TO_NOMINAL = {
+        "analyse": "analysis", "analyze": "analysis",
+        "argue": "argument", "assess": "assessment",
+        "assume": "assumption", "classify": "classification",
+        "communicate": "communication", "conclude": "conclusion",
+        "contribute": "contribution", "create": "creation",
+        "decide": "decision", "define": "definition",
+        "describe": "description", "develop": "development",
+        "discuss": "discussion", "distribute": "distribution",
+        "educate": "education", "emphasise": "emphasis",
+        "emphasize": "emphasis", "establish": "establishment",
+        "evaluate": "evaluation", "examine": "examination",
+        "explain": "explanation", "explore": "exploration",
+        "identify": "identification", "implement": "implementation",
+        "improve": "improvement", "interpret": "interpretation",
+        "introduce": "introduction", "investigate": "investigation",
+        "involve": "involvement", "motivate": "motivation",
+        "observe": "observation", "participate": "participation",
+        "produce": "production", "recommend": "recommendation",
+        "regulate": "regulation", "respond": "response",
+    }
+
+    # Absolute claim markers that should be hedged — only genuinely absolute claims.
+    # Deliberately excludes common verbs like "is", "are", "shows" which would
+    # destroy normal sentences if replaced.
+    ABSOLUTE_MARKERS = [
+        (r'\bproves\s+that\b', "suggests that"),
+        (r'\bclearly\s+demonstrates\b', "arguably demonstrates"),
+        (r'\balways\b', "typically"),
+        (r'\bnever\b', "rarely"),
+        (r'\bcertainly\b', "likely"),
+        (r'\bobviously\b', "evidently"),
+        (r'\bundoubtedly\b', "presumably"),
+        (r'\bwithout\s+a\s+doubt\b', "in all likelihood"),
+        (r'\bit\s+is\s+clear\s+that\b', "it appears that"),
+        (r'\bthere\s+is\s+no\s+doubt\b', "there is strong evidence"),
+    ]
+
+    def __init__(self, min_formality: float = 0.65):
+        self.min_formality = min_formality
+
+    def nominalise(self, text: str) -> str:
+        """Convert verbal phrases to nominal forms where appropriate.
+
+        Only applies to clear cases where nominalisation improves academic register
+        without changing meaning. Applied conservatively.
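+
+        Example (illustrative): "the assessing of policy" becomes
+        "the assessment of policy", while the already-nominal
+        "the assessment of policy" is left unchanged.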
+ """ + if not text or not text.strip(): + return text + + result = text + for verb, nominal in self.VERB_TO_NOMINAL.items(): + # Match "the [verb]ing of" → "the [nominal] of" + gerund = verb.rstrip("e") + "ing" if verb.endswith("e") else verb + "ing" + pattern = re.compile(r'\bthe\s+' + re.escape(gerund) + r'\s+of\b', re.IGNORECASE) + replacement = f"the {nominal} of" + result = pattern.sub(replacement, result) + + return result + + def add_hedging(self, text: str) -> str: + """Add hedging language where claims are too absolute. + + Only applies to the first occurrence of each absolute marker + to avoid over-hedging. + """ + if not text or not text.strip(): + return text + + result = text + for pattern_str, hedge in self.ABSOLUTE_MARKERS: + pattern = re.compile(pattern_str, re.IGNORECASE) + # Only replace first occurrence to avoid over-hedging + result = pattern.sub(hedge, result, count=1) + + return result + + def check_formality(self, text: str) -> float: + """Score text formality on 0-1 scale.""" + if not text or not text.strip(): + return 0.5 + + from src.style.formality_classifier import FormalityClassifier + classifier = FormalityClassifier() + return classifier.score(text)