Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
- scripts/__pycache__/evaluate.cpython-312.pyc +0 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/__init__.cpython-314.pyc +0 -0
- src/api/__init__.py +0 -0
- src/api/__pycache__/main.cpython-312.pyc +0 -0
- src/api/__pycache__/middleware.cpython-312.pyc +0 -0
- src/api/__pycache__/schemas.cpython-312.pyc +0 -0
- src/api/middleware.py +67 -0
- src/api/schemas.py +21 -0
- src/evaluation/__init__.py +0 -0
- src/evaluation/__pycache__/__init__.cpython-314.pyc +0 -0
- src/evaluation/__pycache__/authorship_verifier.cpython-312.pyc +0 -0
- src/evaluation/__pycache__/errant_evaluator.cpython-312.pyc +0 -0
- src/evaluation/__pycache__/gleu_scorer.cpython-312.pyc +0 -0
- src/evaluation/__pycache__/gleu_scorer.cpython-314.pyc +0 -0
- src/evaluation/__pycache__/style_metrics.cpython-312.pyc +0 -0
- src/evaluation/__pycache__/style_metrics.cpython-314.pyc +0 -0
- src/evaluation/authorship_verifier.py +50 -0
- src/evaluation/errant_evaluator.py +82 -0
- src/evaluation/gleu_scorer.py +68 -0
- src/evaluation/style_metrics.py +81 -0
- src/inference/__init__.py +0 -0
- src/inference/__pycache__/__init__.cpython-314.pyc +0 -0
- src/inference/__pycache__/corrector.cpython-312.pyc +0 -0
- src/inference/__pycache__/corrector.cpython-314.pyc +0 -0
- src/inference/__pycache__/postprocessor.cpython-312.pyc +0 -0
- src/inference/__pycache__/postprocessor.cpython-314.pyc +0 -0
- src/inference/corrector.py +283 -0
- src/inference/postprocessor.py +118 -0
- src/model/__init__.py +0 -0
- src/model/__pycache__/__init__.cpython-312.pyc +0 -0
- src/model/__pycache__/__init__.cpython-314.pyc +0 -0
- src/model/__pycache__/base_model.cpython-312.pyc +0 -0
- src/model/__pycache__/base_model.cpython-314.pyc +0 -0
- src/model/__pycache__/generation_utils.cpython-312.pyc +0 -0
- src/model/__pycache__/generation_utils.cpython-314.pyc +0 -0
- src/model/__pycache__/lora_adapter.cpython-312.pyc +0 -0
- src/model/__pycache__/style_conditioner.cpython-312.pyc +0 -0
- src/model/__pycache__/style_conditioner.cpython-314.pyc +0 -0
- src/model/base_model.py +135 -0
- src/model/generation_utils.py +106 -0
- src/model/lora_adapter.py +54 -0
- src/model/style_conditioner.py +74 -0
- src/preprocessing/__init__.py +0 -0
- src/preprocessing/__pycache__/__init__.cpython-312.pyc +0 -0
- src/preprocessing/__pycache__/__init__.cpython-314.pyc +0 -0
- src/preprocessing/__pycache__/dependency_parser.cpython-312.pyc +0 -0
- src/preprocessing/__pycache__/dyslexia_simulator.cpython-312.pyc +0 -0
- src/preprocessing/__pycache__/dyslexia_simulator.cpython-314.pyc +0 -0
- src/preprocessing/__pycache__/ner_tagger.cpython-312.pyc +0 -0
scripts/__pycache__/evaluate.cpython-312.pyc
ADDED
Binary file (5.24 kB).

src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (169 Bytes).

src/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (171 Bytes).

src/api/__init__.py
ADDED
File without changes

src/api/__pycache__/main.cpython-312.pyc
ADDED
Binary file (3.59 kB).

src/api/__pycache__/middleware.cpython-312.pyc
ADDED
Binary file (3.68 kB).

src/api/__pycache__/schemas.cpython-312.pyc
ADDED
Binary file (1.38 kB).
src/api/middleware.py
ADDED
@@ -0,0 +1,67 @@
+"""
+API middleware for request logging, rate limiting, and error handling.
+"""
+
+from fastapi import Request
+from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
+from loguru import logger
+import time
+from collections import defaultdict, deque
+
+
+class RequestLoggingMiddleware(BaseHTTPMiddleware):
+    """Logs all incoming requests with timing information."""
+
+    async def dispatch(self, request: Request, call_next):
+        start_time = time.time()
+        path = request.url.path
+        method = request.method
+
+        logger.info(f"→ {method} {path}")
+
+        try:
+            response = await call_next(request)
+        except Exception as e:
+            logger.error(f"✗ {method} {path} - Error: {e}")
+            raise
+
+        elapsed = (time.time() - start_time) * 1000  # ms
+        logger.info(f"← {method} {path} - {response.status_code} ({elapsed:.1f}ms)")
+
+        return response
+
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Simple in-memory rate limiting."""
+
+    def __init__(self, app, max_requests_per_minute: int = 60):
+        super().__init__(app)
+        self.max_requests = max_requests_per_minute
+        self.window = 60  # seconds
+        # Track requests per client IP: {ip: deque([timestamp, ...])}
+        self.requests: dict = defaultdict(deque)
+
+    async def dispatch(self, request: Request, call_next):
+        # Get client IP
+        client_ip = request.client.host if request.client else "unknown"
+        now = time.time()
+
+        # Clean old entries
+        timestamps = self.requests[client_ip]
+        while timestamps and timestamps[0] < now - self.window:
+            timestamps.popleft()
+
+        # Check rate limit
+        if len(timestamps) >= self.max_requests:
+            logger.warning(f"Rate limited: {client_ip} ({len(timestamps)} requests in {self.window}s)")
+            return JSONResponse(
+                status_code=429,
+                content={"detail": "Rate limit exceeded. Please wait before making more requests."},
+            )
+
+        # Record this request
+        timestamps.append(now)
+
+        response = await call_next(request)
+        return response
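A minimal wiring sketch for these two middlewares. The application module (src/api/main.py) appears in this commit only as a compiled .pyc, so the setup below is an assumption rather than the repository's actual main module. Starlette executes the most recently added middleware first, so registering the logger last lets it observe even requests the rate limiter rejects.

from fastapi import FastAPI
from src.api.middleware import RequestLoggingMiddleware, RateLimitMiddleware

app = FastAPI()
app.add_middleware(RateLimitMiddleware, max_requests_per_minute=60)
app.add_middleware(RequestLoggingMiddleware)  # added last, so it wraps the rate limiter

@app.get("/health")
async def health():
    return {"status": "ok"}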
src/api/schemas.py
ADDED
@@ -0,0 +1,21 @@
+"""
+Pydantic schemas for API request/response validation.
+"""
+
+from pydantic import BaseModel, Field
+from typing import Optional, Dict
+
+
+class CorrectionRequest(BaseModel):
+    text: str = Field(..., min_length=10, max_length=5000, description="Raw dyslectic text to correct.")
+    master_copy: Optional[str] = Field(None, description="Optional master copy to match style toward.")
+    style_alpha: float = Field(0.6, ge=0.0, le=1.0, description="Weight given to user's own style (0=full master, 1=full user).")
+
+
+class CorrectionResponse(BaseModel):
+    original: str
+    corrected: str
+    style_similarity: float
+    awl_coverage: float
+    readability: Dict[str, float]
+    changes_summary: str
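A quick validation sketch for the request schema, assuming Pydantic v2 (with a v1 pin, `.dict()` replaces `model_dump()`):

from src.api.schemas import CorrectionRequest

req = CorrectionRequest(
    text="teh resalts of the experimant were clear",
    style_alpha=0.4,
)
print(req.model_dump())  # validated payload as a plain dict
# A text under 10 characters or a style_alpha outside [0, 1] raises a ValidationError.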
src/evaluation/__init__.py
ADDED
File without changes

src/evaluation/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (182 Bytes).

src/evaluation/__pycache__/authorship_verifier.cpython-312.pyc
ADDED
Binary file (2.6 kB).

src/evaluation/__pycache__/errant_evaluator.cpython-312.pyc
ADDED
Binary file (3.63 kB).

src/evaluation/__pycache__/gleu_scorer.cpython-312.pyc
ADDED
Binary file (2.42 kB).

src/evaluation/__pycache__/gleu_scorer.cpython-314.pyc
ADDED
Binary file (3.02 kB).

src/evaluation/__pycache__/style_metrics.cpython-312.pyc
ADDED
Binary file (4.34 kB).

src/evaluation/__pycache__/style_metrics.cpython-314.pyc
ADDED
Binary file (5.25 kB).
src/evaluation/authorship_verifier.py
ADDED
@@ -0,0 +1,50 @@
+"""
+Authorship verification module.
+Uses a fine-tuned model to verify whether the corrected output
+could plausibly have been written by the same author as the input.
+Target: > 0.80 same-author probability.
+"""
+
+from typing import Tuple
+from loguru import logger
+import torch
+import torch.nn.functional as F
+
+
+class AuthorshipVerifier:
+    """Verifies authorship consistency between input and output text."""
+
+    def __init__(self, model_name: str = "roberta-base"):
+        try:
+            from sentence_transformers import SentenceTransformer
+            self.model = SentenceTransformer(model_name, device="cpu")
+            logger.info(f"AuthorshipVerifier loaded with {model_name}")
+        except Exception as e:
+            logger.warning(f"Failed to load authorship model: {e}")
+            self.model = None
+
+    def verify(self, text_a: str, text_b: str) -> float:
+        """Return probability that both texts were written by the same author.
+
+        Uses sentence embedding similarity as a proxy for authorship.
+        Higher cosine similarity suggests same author.
+        """
+        if self.model is None:
+            return 0.5  # Neutral score if model unavailable
+
+        if not text_a or not text_b:
+            return 0.5
+
+        try:
+            embeddings = self.model.encode([text_a, text_b], convert_to_tensor=True)
+            sim = F.cosine_similarity(
+                embeddings[0].unsqueeze(0),
+                embeddings[1].unsqueeze(0),
+            )
+            # Scale similarity to [0, 1] probability
+            # Cosine similarity is already in [-1, 1], shift to [0, 1]
+            prob = (sim.item() + 1.0) / 2.0
+            return prob
+        except Exception as e:
+            logger.warning(f"Authorship verification failed: {e}")
+            return 0.5
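A usage sketch for the verifier. The default "roberta-base" is a plain Hugging Face checkpoint; passing a genuine sentence-transformers model id such as "all-MiniLM-L6-v2" (an assumption here, not a project default) gives embeddings better suited to the cosine proxy.

from src.evaluation.authorship_verifier import AuthorshipVerifier

verifier = AuthorshipVerifier(model_name="all-MiniLM-L6-v2")
prob = verifier.verify(
    "i think the resalts show a clear patern",
    "I think the results show a clear pattern.",
)
print(f"Same-author probability: {prob:.2f}")  # compared against the > 0.80 target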
src/evaluation/errant_evaluator.py
ADDED
@@ -0,0 +1,82 @@
+"""
+ERRANT-based grammatical error evaluation.
+Uses the ERRANT toolkit for standardised GEC evaluation with
+precision, recall, and F0.5 scores.
+"""
+
+from typing import List, Dict
+from loguru import logger
+
+
+class ERRANTEvaluator:
+    """Evaluates grammar correction quality using ERRANT annotations."""
+
+    def __init__(self):
+        try:
+            import errant
+            self.annotator = errant.load("en")
+            logger.info("ERRANT annotator loaded")
+        except Exception as e:
+            logger.warning(f"ERRANT failed to load: {e}. Evaluation will use fallback.")
+            self.annotator = None
+
+    def evaluate(
+        self,
+        sources: List[str],
+        predictions: List[str],
+        references: List[str],
+    ) -> Dict[str, float]:
+        """Compute ERRANT precision, recall, F0.5."""
+        if self.annotator is None:
+            logger.warning("ERRANT not available, returning zero scores")
+            return {"precision": 0.0, "recall": 0.0, "f0.5": 0.0}
+
+        tp = 0
+        fp = 0
+        fn = 0
+
+        for src, pred, ref in zip(sources, predictions, references):
+            try:
+                # Parse source and annotate edits
+                orig = self.annotator.parse(src)
+                cor_pred = self.annotator.parse(pred)
+                cor_ref = self.annotator.parse(ref)
+
+                # Get edit annotations
+                pred_edits = self.annotator.annotate(orig, cor_pred)
+                ref_edits = self.annotator.annotate(orig, cor_ref)
+
+                # Convert to comparable sets of (start, end, correction, type)
+                pred_set = set()
+                for e in pred_edits:
+                    pred_set.add((e.o_start, e.o_end, e.c_str, e.type))
+
+                ref_set = set()
+                for e in ref_edits:
+                    ref_set.add((e.o_start, e.o_end, e.c_str, e.type))
+
+                # Count TP, FP, FN
+                tp += len(pred_set & ref_set)
+                fp += len(pred_set - ref_set)
+                fn += len(ref_set - pred_set)
+
+            except Exception as e:
+                logger.debug(f"ERRANT annotation failed for a sample: {e}")
+                continue
+
+        # Compute precision, recall, F0.5
+        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+
+        # F0.5 weighs precision higher than recall (β=0.5)
+        beta = 0.5
+        if precision + recall > 0:
+            f_score = (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)
+        else:
+            f_score = 0.0
+
+        return {
+            "precision": precision,
+            "recall": recall,
+            "f0.5": f_score,
+        }
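A worked check of the F0.5 arithmetic used above, with β = 0.5 weighting precision more heavily than recall:

tp, fp, fn = 6, 2, 4
precision = tp / (tp + fp)   # 0.75
recall = tp / (tp + fn)      # 0.60
f05 = (1 + 0.5**2) * precision * recall / (0.5**2 * precision + recall)
print(round(f05, 3))         # 0.714, pulled toward precision rather than recall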
src/evaluation/gleu_scorer.py
ADDED
@@ -0,0 +1,68 @@
+"""
+GLEU (Generalized Language Evaluation Understanding) score.
+Preferred over BLEU for grammatical error correction tasks.
+Also computes BERTScore for semantic similarity evaluation.
+"""
+
+import sacrebleu
+from bert_score import score as bert_score_fn
+from typing import List, Tuple
+from loguru import logger
+
+
+class GLEUScorer:
+    """Computes GLEU and BERTScore metrics for GEC evaluation."""
+
+    def compute_gleu(
+        self,
+        predictions: List[str],
+        references: List[str],
+    ) -> float:
+        """Corpus-level GLEU score (0-100).
+
+        GLEU is the geometric mean of n-gram precisions and recall,
+        preferred over BLEU for GEC because it equally penalises
+        both under-correction and over-correction.
+        """
+        if not predictions or not references:
+            return 0.0
+
+        # sacrebleu expects references as a list of lists
+        refs = [references]
+
+        # Use BLEU with smoothing as GLEU approximation
+        # sacrebleu doesn't have a native GLEU, so we use smoothed BLEU
+        bleu = sacrebleu.corpus_bleu(
+            predictions,
+            refs,
+            smooth_method="exp",
+            smooth_value=0.1,
+        )
+        return bleu.score
+
+    def compute_bert_score(
+        self,
+        predictions: List[str],
+        references: List[str],
+        lang: str = "en",
+    ) -> Tuple[float, float, float]:
+        """Returns (precision, recall, F1) as averages over the batch."""
+        if not predictions or not references:
+            return (0.0, 0.0, 0.0)
+
+        try:
+            P, R, F1 = bert_score_fn(
+                predictions,
+                references,
+                lang=lang,
+                verbose=False,
+                device="cpu",  # CPU-optimised
+            )
+            return (
+                P.mean().item(),
+                R.mean().item(),
+                F1.mean().item(),
+            )
+        except Exception as e:
+            logger.warning(f"BERTScore computation failed: {e}")
+            return (0.0, 0.0, 0.0)
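A usage sketch for the scorer. Since compute_gleu is a smoothed-BLEU stand-in rather than true GLEU, its scores are best read comparatively between systems:

from src.evaluation.gleu_scorer import GLEUScorer

scorer = GLEUScorer()
preds = ["The results show a clear pattern."]
refs = ["The results show a clear pattern."]
print(scorer.compute_gleu(preds, refs))   # 100.0 on an exact match
p, r, f1 = scorer.compute_bert_score(preds, refs)
print(f"BERTScore F1: {f1:.3f}")          # ~1.0 for identical strings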
src/evaluation/style_metrics.py
ADDED
@@ -0,0 +1,81 @@
+"""
+Measures style preservation between input and output.
+
+Key metrics:
+- Style Vector Cosine Similarity (target: > 0.85)
+- AWL Coverage Score (target: > 0.25)
+- Authorship Verification Score (target: > 0.80)
+"""
+
+import torch
+import torch.nn.functional as F
+from typing import List, Tuple
+from ..style.fingerprinter import StyleFingerprinter
+from ..vocabulary.awl_loader import AWLLoader
+from loguru import logger
+import numpy as np
+
+
+class StyleEvaluator:
+    """Evaluates style preservation and academic vocabulary coverage."""
+
+    def __init__(self, fingerprinter: StyleFingerprinter, awl: AWLLoader):
+        self.fingerprinter = fingerprinter
+        self.awl = awl
+
+    def style_similarity(self, text_a: str, text_b: str) -> float:
+        """Cosine similarity between style vectors. Target: > 0.85."""
+        vec_a = self.fingerprinter.extract_vector(text_a)
+        vec_b = self.fingerprinter.extract_vector(text_b)
+
+        if vec_a.dim() == 1:
+            vec_a = vec_a.unsqueeze(0)
+        if vec_b.dim() == 1:
+            vec_b = vec_b.unsqueeze(0)
+
+        sim = F.cosine_similarity(vec_a, vec_b, dim=-1)
+        return sim.item()
+
+    def awl_coverage(self, text: str) -> float:
+        """Fraction of content words in AWL. Target: > 0.25."""
+        if not text or not text.strip():
+            return 0.0
+
+        words = text.lower().split()
+        # Filter to content words (longer than 3 chars, alphabetic)
+        content_words = [w for w in words if len(w) > 3 and w.isalpha()]
+
+        if not content_words:
+            return 0.0
+
+        awl_count = sum(1 for w in content_words if self.awl.is_academic(w))
+        return awl_count / len(content_words)
+
+    def evaluate_batch(
+        self,
+        inputs: List[str],
+        outputs: List[str],
+        references: List[str],
+    ) -> dict:
+        """Compute style and AWL metrics for a batch."""
+        style_sims = []
+        awl_coverages = []
+        ref_style_sims = []
+
+        for inp, out, ref in zip(inputs, outputs, references):
+            # Style similarity between input and output (preservation)
+            style_sims.append(self.style_similarity(inp, out))
+
+            # AWL coverage of output
+            awl_coverages.append(self.awl_coverage(out))
+
+            # Style similarity between output and reference
+            ref_style_sims.append(self.style_similarity(out, ref))
+
+        return {
+            "style_similarity_mean": float(np.mean(style_sims)),
+            "style_similarity_std": float(np.std(style_sims)),
+            "awl_coverage_mean": float(np.mean(awl_coverages)),
+            "awl_coverage_std": float(np.std(awl_coverages)),
+            "ref_style_similarity_mean": float(np.mean(ref_style_sims)),
+        }
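A usage sketch for the evaluator. StyleFingerprinter and AWLLoader live in src/style/ and src/vocabulary/, which fall outside this 50-file view; their constructor arguments below are inferred from how corrector.py builds them and should be treated as assumptions.

from src.style.fingerprinter import StyleFingerprinter
from src.vocabulary.awl_loader import AWLLoader
from src.evaluation.style_metrics import StyleEvaluator

fingerprinter = StyleFingerprinter(spacy_model="en_core_web_sm",
                                   awl_path="data/awl/coxhead_awl.txt")
awl = AWLLoader(primary_path="data/awl/coxhead_awl.txt")
evaluator = StyleEvaluator(fingerprinter, awl)

draft = "i think the data shows something intresting"
corrected = "I consider that the data demonstrate a significant trend."
print(evaluator.style_similarity(draft, corrected))  # target > 0.85
print(evaluator.awl_coverage(corrected))             # target > 0.25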
src/inference/__init__.py
ADDED
File without changes

src/inference/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (181 Bytes).

src/inference/__pycache__/corrector.cpython-312.pyc
ADDED
Binary file (9.52 kB).

src/inference/__pycache__/corrector.cpython-314.pyc
ADDED
Binary file (13.7 kB).

src/inference/__pycache__/postprocessor.cpython-312.pyc
ADDED
Binary file (4.72 kB).

src/inference/__pycache__/postprocessor.cpython-314.pyc
ADDED
Binary file (5.69 kB).
src/inference/corrector.py
ADDED
@@ -0,0 +1,283 @@
+"""
+End-to-end inference pipeline.
+Accepts raw dyslectic text (and optionally a master copy),
+returns corrected academic text with metadata.
+"""
+
+from ..preprocessing.pipeline import PreprocessingPipeline
+from ..style.fingerprinter import StyleFingerprinter
+from ..vocabulary.lexical_substitution import LexicalElevator, RegisterFilter
+from ..model.base_model import load_model_and_tokenizer
+from ..model.style_conditioner import StyleConditioner, prepend_style_prefix
+from ..model.generation_utils import generate_correction
+from .postprocessor import PostProcessor
+from ..evaluation.style_metrics import StyleEvaluator
+from ..vocabulary.awl_loader import AWLLoader
+import torch
+from typing import Optional
+from dataclasses import dataclass
+from loguru import logger
+import yaml
+
+
+TASK_PREFIX = (
+    "Correct the following text for grammar, spelling, and clarity. "
+    "Maintain the author's original tone and writing style. "
+    "Elevate vocabulary to academic register. "
+    "Do NOT change the meaning or add new information. "
+    "Preserve named entities exactly. "
+    "Text to correct: "
+)
+
+
+@dataclass
+class CorrectionResult:
+    original: str
+    corrected: str
+    preprocessed: str
+    style_similarity: float
+    awl_coverage: float
+    readability: dict
+    changes_summary: str
+
+
+class AcademicCorrector:
+    """Full inference pipeline: preprocess → fingerprint → generate → elevate → filter."""
+
+    def __init__(self, config: dict):
+        logger.info("Initialising AcademicCorrector...")
+
+        model_cfg = config.get("model", {})
+        gen_cfg = config.get("generation", {})
+        vocab_cfg = config.get("vocabulary", {})
+        style_cfg = config.get("style_conditioner", {})
+
+        # 1. Load model and tokenizer
+        model_key = model_cfg.get("key", "flan-t5-small")
+        checkpoint = model_cfg.get("checkpoint_path", None)
+        use_lora = model_cfg.get("use_lora", False)
+
+        if checkpoint and use_lora:
+            # PEFT adapter checkpoint: load base model + apply adapter
+            import os
+            try:
+                from peft import PeftModel
+                logger.info(f"Loading base model '{model_key}' + PEFT adapter from '{checkpoint}'")
+                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
+                    model_key, quantize=False, use_lora=False
+                )
+                self.model = PeftModel.from_pretrained(self.model, checkpoint)
+                logger.info(f"PEFT adapter loaded from {checkpoint}")
+            except Exception as e:
+                logger.warning(f"PEFT loading failed ({e}), loading base model only")
+                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
+                    model_key, quantize=False, use_lora=False
+                )
+        elif checkpoint:
+            # Full model checkpoint (merged weights)
+            try:
+                from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+                self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+                self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+                self.is_seq2seq = True
+                logger.info(f"Loaded full model from checkpoint: {checkpoint}")
+            except Exception:
+                logger.warning(f"Checkpoint not found, loading base model: {model_key}")
+                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
+                    model_key, quantize=False, use_lora=False
+                )
+        else:
+            self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
+                model_key, quantize=False, use_lora=False
+            )
+
+        self.model.eval()
+        self.generation_config = gen_cfg
+
+        # 2. Preprocessor
+        self.preprocessor = PreprocessingPipeline()
+
+        # 3. Style fingerprinter
+        fp_cfg = config.get("fingerprinter", {})
+        self.fingerprinter = StyleFingerprinter(
+            spacy_model=fp_cfg.get("spacy_model", "en_core_web_sm"),
+            awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
+        )
+
+        # 4. Style conditioner — auto-detect hidden dim from loaded model
+        if hasattr(self.model.config, "d_model"):
+            auto_hidden_dim = self.model.config.d_model
+        elif hasattr(self.model.config, "hidden_size"):
+            auto_hidden_dim = self.model.config.hidden_size
+        else:
+            auto_hidden_dim = 512  # Safe default for T5-Small
+        logger.info(f"Auto-detected model hidden dim: {auto_hidden_dim}")
+
+        self.conditioner = StyleConditioner(
+            style_dim=style_cfg.get("style_dim", 512),
+            model_hidden_dim=style_cfg.get("model_hidden_dim", auto_hidden_dim),
+            n_prefix_tokens=style_cfg.get("n_prefix_tokens", 10),
+        )
+        self.conditioner.eval()
+
+        # 5. Vocabulary elevator
+        try:
+            self.elevator = LexicalElevator(
+                awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
+                spacy_model="en_core_web_sm",
+                mlm_model=vocab_cfg.get("mlm_model", "bert-large-uncased"),
+                sem_model=vocab_cfg.get("sem_model", "all-mpnet-base-v2"),
+            )
+        except Exception as e:
+            logger.warning(f"Lexical elevator init failed: {e}, elevation disabled")
+            self.elevator = None
+
+        # 6. Register filter
+        self.register_filter = RegisterFilter()
+
+        # 7. Post-processor
+        self.postprocessor = PostProcessor()
+
+        # 8. Evaluator
+        awl = AWLLoader(primary_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"))
+        self.evaluator = StyleEvaluator(self.fingerprinter, awl)
+
+        logger.info("AcademicCorrector initialised successfully")
+
+    def correct(
+        self,
+        raw_text: str,
+        master_copy: Optional[str] = None,
+        style_alpha: float = 0.6,
+    ) -> CorrectionResult:
+        """
+        Full correction pipeline:
+        1. Pre-process (spell correct + parse)
+        2. Style fingerprint
+        3. Generate with style conditioning
+        4. Academic vocabulary elevation
+        5. Register filter
+        6. Compute quality metrics
+        """
+        # Step 1: Pre-process
+        logger.info("Step 1: Preprocessing...")
+        doc = self.preprocessor.process(raw_text)
+
+        # Step 2: Style fingerprint
+        logger.info("Step 2: Extracting style fingerprint...")
+        user_style = self.fingerprinter.extract_vector(doc.corrected_text)
+
+        if master_copy:
+            master_style = self.fingerprinter.extract_vector(master_copy)
+            target_style = self.fingerprinter.blend_vectors(user_style, master_style, alpha=style_alpha)
+        else:
+            target_style = user_style
+
+        # Step 3: Generate correction (sentence-chunked)
+        # The model was trained on max_input_length=128 tokens.
+        # Split text into sentence groups that fit within that window,
+        # process each chunk, then reassemble.
+        logger.info("Step 3: Generating correction (chunked)...")
+
+        MAX_INPUT_TOKENS = 128
+        # Measure how many tokens the task prefix uses
+        prefix_tokens = len(self.tokenizer.encode(TASK_PREFIX, add_special_tokens=False))
+        budget = MAX_INPUT_TOKENS - prefix_tokens - 2  # 2 for special tokens
+
+        # Split into sentences using spaCy (already loaded for fingerprinting)
+        sent_doc = self.fingerprinter.nlp(doc.corrected_text)
+        sentences = [sent.text.strip() for sent in sent_doc.sents if sent.text.strip()]
+
+        # Group sentences into chunks that fit the token budget
+        chunks = []
+        current_chunk = []
+        current_tokens = 0
+
+        for sent in sentences:
+            sent_tokens = len(self.tokenizer.encode(sent, add_special_tokens=False))
+            if current_tokens + sent_tokens > budget and current_chunk:
+                chunks.append(" ".join(current_chunk))
+                current_chunk = [sent]
+                current_tokens = sent_tokens
+            else:
+                current_chunk.append(sent)
+                current_tokens += sent_tokens
+
+        if current_chunk:
+            chunks.append(" ".join(current_chunk))
+
+        logger.info(f" Split into {len(chunks)} chunks from {len(sentences)} sentences")
+
+        corrected_chunks = []
+        device = next(self.model.parameters()).device
+
+        for i, chunk in enumerate(chunks):
+            chunk_input = TASK_PREFIX + chunk
+            inputs = self.tokenizer(
+                chunk_input,
+                max_length=MAX_INPUT_TOKENS,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            input_ids = inputs["input_ids"].to(device)
+            attention_mask = inputs["attention_mask"].to(device)
+
+            chunk_output = generate_correction(
+                self.model,
+                self.tokenizer,
+                input_ids,
+                attention_mask,
+                self.generation_config,
+            )
+            corrected_chunks.append(chunk_output)
+            logger.debug(f" Chunk {i+1}/{len(chunks)}: {len(chunk.split())} → {len(chunk_output.split())} words")
+
+        generated = " ".join(corrected_chunks)
+
+        # Step 4: Post-process
+        logger.info("Step 4: Post-processing...")
+        generated = self.postprocessor.clean(generated)
+        generated = self.postprocessor.restore_entities(
+            generated,
+            [e.text for e in doc.entities],
+            doc.protected_spans,
+        )
+
+        # Step 5: Vocabulary elevation
+        logger.info("Step 5: Vocabulary elevation...")
+        if self.elevator:
+            try:
+                generated = self.elevator.elevate(generated, doc.protected_spans)
+            except Exception as e:
+                logger.warning(f"Vocabulary elevation failed: {e}")
+
+        # Step 6: Register filter
+        logger.info("Step 6: Register filtering...")
+        generated = self.register_filter.apply(generated)
+
+        # Final formatting
+        generated = self.postprocessor.format_output(generated)
+
+        # Step 7: Compute quality metrics
+        logger.info("Step 7: Computing metrics...")
+        style_sim = self.evaluator.style_similarity(raw_text, generated)
+        awl_cov = self.evaluator.awl_coverage(generated)
+
+        # Build changes summary
+        changes = []
+        if doc.original_text != doc.corrected_text:
+            changes.append("Spelling/grammar corrections applied")
+        if generated != doc.corrected_text:
+            changes.append("Text restructured and elevated")
+        changes_summary = "; ".join(changes) if changes else "No changes needed"
+
+        return CorrectionResult(
+            original=raw_text,
+            corrected=generated,
+            preprocessed=doc.corrected_text,
+            style_similarity=style_sim,
+            awl_coverage=awl_cov,
+            readability=doc.readability,
+            changes_summary=changes_summary,
+        )
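An end-to-end usage sketch. The YAML layout is reverse-engineered from the config.get() calls in __init__, and the config path is hypothetical:

import yaml
from src.inference.corrector import AcademicCorrector

with open("configs/inference.yaml") as f:  # hypothetical path
    config = yaml.safe_load(f)

corrector = AcademicCorrector(config)
result = corrector.correct(
    "i beleive the experimant worked becuase the numbers whent up",
    master_copy=None,
    style_alpha=0.6,
)
print(result.corrected)
print(result.style_similarity, result.awl_coverage)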
src/inference/postprocessor.py
ADDED
@@ -0,0 +1,118 @@
+"""
+Post-processing utilities for generated text.
+Handles cleanup, formatting, and final quality checks.
+"""
+
+import re
+from typing import List, Tuple
+from loguru import logger
+
+
+class PostProcessor:
+    """Cleans and formats generated text after model output."""
+
+    # Common generation artifacts to remove
+    ARTIFACTS = [
+        r'<pad>',
+        r'</s>',
+        r'<s>',
+        r'<unk>',
+        r'\[PAD\]',
+        r'\[CLS\]',
+        r'\[SEP\]',
+        r'<\|endoftext\|>',
+    ]
+
+    def __init__(self):
+        # Compile artifact removal regex
+        self._artifact_pattern = re.compile(
+            '|'.join(re.escape(a) if not a.startswith('\\') else a for a in self.ARTIFACTS),
+            re.IGNORECASE
+        )
+
+    def clean(self, text: str) -> str:
+        """Remove generation artifacts and normalise whitespace."""
+        if not text:
+            return ""
+
+        # Remove generation artifacts
+        result = self._artifact_pattern.sub('', text)
+
+        # Replace em dashes and en dashes with commas
+        result = result.replace('—', ',')
+        result = result.replace('–', ',')
+
+        # Normalise whitespace
+        result = re.sub(r'\s+', ' ', result)
+        result = result.strip()
+
+        # Fix common post-generation spacing issues
+        result = re.sub(r'\s+([.,!?;:])', r'\1', result)  # Remove space before punctuation
+        result = re.sub(r'([.,!?;:])([A-Za-z])', r'\1 \2', result)  # Add space after punctuation
+        result = re.sub(r'\(\s+', '(', result)  # Remove space after opening paren
+        result = re.sub(r'\s+\)', ')', result)  # Remove space before closing paren
+
+        # Fix multiple punctuation
+        result = re.sub(r'\.{2,}', '.', result)
+        result = re.sub(r'\?{2,}', '?', result)
+        result = re.sub(r'!{2,}', '!', result)
+
+        return result
+
+    def restore_entities(
+        self,
+        text: str,
+        original_entities: List[str],
+        protected_spans: List[Tuple[int, int]],
+    ) -> str:
+        """Restore named entities that may have been altered during generation.
+
+        Uses fuzzy matching to find where entities should be in the generated text
+        and restores the original form.
+        """
+        if not original_entities:
+            return text
+
+        result = text
+        for entity in original_entities:
+            # Check if entity is already present in correct form
+            if entity in result:
+                continue
+
+            # Try case-insensitive match
+            pattern = re.compile(re.escape(entity), re.IGNORECASE)
+            if pattern.search(result):
+                result = pattern.sub(entity, result, count=1)
+                logger.debug(f"Restored entity: {entity}")
+
+        return result
+
+    def format_output(self, text: str) -> str:
+        """Apply final formatting (capitalisation, punctuation, spacing)."""
+        if not text:
+            return ""
+
+        result = text.strip()
+
+        # Ensure first letter is capitalised
+        if result and result[0].islower():
+            result = result[0].upper() + result[1:]
+
+        # Ensure text ends with punctuation
+        if result and result[-1] not in '.!?':
+            result += '.'
+
+        # Capitalise after sentence-ending punctuation
+        result = re.sub(
+            r'([.!?]\s+)([a-z])',
+            lambda m: m.group(1) + m.group(2).upper(),
+            result
+        )
+
+        # Fix "i" → "I" when standalone
+        result = re.sub(r'\bi\b', 'I', result)
+
+        # Remove trailing whitespace from lines
+        result = '\n'.join(line.rstrip() for line in result.split('\n'))
+
+        return result
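A traceable check of the cleanup rules above. Note that clean() handles surface artifacts only; misspellings are left for the model:

from src.inference.postprocessor import PostProcessor

pp = PostProcessor()
text = pp.clean("<pad> teh result — clear !!</s>")
print(text)                    # "teh result, clear!"
print(pp.format_output(text))  # "Teh result, clear!" (capitalised, spelling untouched)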
src/model/__init__.py
ADDED
File without changes

src/model/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (175 Bytes).

src/model/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (177 Bytes).

src/model/__pycache__/base_model.cpython-312.pyc
ADDED
Binary file (5.88 kB).

src/model/__pycache__/base_model.cpython-314.pyc
ADDED
Binary file (6.22 kB).

src/model/__pycache__/generation_utils.cpython-312.pyc
ADDED
Binary file (4.25 kB).

src/model/__pycache__/generation_utils.cpython-314.pyc
ADDED
Binary file (4.81 kB).

src/model/__pycache__/lora_adapter.cpython-312.pyc
ADDED
Binary file (2.9 kB).

src/model/__pycache__/style_conditioner.cpython-312.pyc
ADDED
Binary file (3.09 kB).

src/model/__pycache__/style_conditioner.cpython-314.pyc
ADDED
Binary file (3.61 kB).
src/model/base_model.py
ADDED
@@ -0,0 +1,135 @@
+"""
+Loads and wraps the base pretrained model.
+Supported architectures:
+- google/flan-t5-xl (recommended, 3B)
+- google/flan-t5-large (780M, resource-constrained)
+- facebook/bart-large (400M, excellent denoiser)
+- meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality)
+"""
+
+from transformers import (
+    AutoTokenizer, AutoModelForSeq2SeqLM,
+    AutoModelForCausalLM, BitsAndBytesConfig
+)
+from peft import get_peft_model, LoraConfig, TaskType
+import torch
+from loguru import logger
+
+
+ENCODER_DECODER_MODELS = {
+    "flan-t5-xl": "google/flan-t5-xl",
+    "flan-t5-large": "google/flan-t5-large",
+    "flan-t5-base": "google/flan-t5-base",
+    "flan-t5-small": "google/flan-t5-small",
+    "bart-large": "facebook/bart-large",
+}
+
+DECODER_ONLY_MODELS = {
+    "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+}
+
+
+def load_model_and_tokenizer(model_key: str, quantize: bool = False, use_lora: bool = True,
+                             lora_config_dict: dict = None):
+    """
+    Load a pretrained model with optional LoRA and quantization.
+
+    Args:
+        model_key: Key from ENCODER_DECODER_MODELS or DECODER_ONLY_MODELS
+        quantize: Whether to use 4-bit quantization
+        use_lora: Whether to apply LoRA adapters
+        lora_config_dict: Optional dict with LoRA hyperparams (r, lora_alpha, etc.)
+
+    Returns:
+        Tuple of (model, tokenizer, is_seq2seq)
+    """
+    # Determine model type and HuggingFace identifier
+    is_seq2seq = model_key in ENCODER_DECODER_MODELS
+    is_causal = model_key in DECODER_ONLY_MODELS
+
+    if not is_seq2seq and not is_causal:
+        raise ValueError(
+            f"Unknown model key: '{model_key}'. "
+            f"Available: {list(ENCODER_DECODER_MODELS.keys()) + list(DECODER_ONLY_MODELS.keys())}"
+        )
+
+    model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key)
+    logger.info(f"Loading model: {model_name} (seq2seq={is_seq2seq}, quantize={quantize}, lora={use_lora})")
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Configure quantization if requested
+    model_kwargs = {
+        "torch_dtype": torch.float32,  # CPU-optimised: use float32 for stability
+    }
+
+    # Detect device
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda"
+        # Use bfloat16 if Ampere+, else float16
+        if torch.cuda.get_device_capability()[0] >= 8:
+            model_kwargs["torch_dtype"] = torch.bfloat16
+        else:
+            model_kwargs["torch_dtype"] = torch.float16
+
+    if quantize and device == "cuda":
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
+            bnb_4bit_use_double_quant=True,
+        )
+        model_kwargs["quantization_config"] = bnb_config
+        logger.info("Using 4-bit NF4 quantization")
+    elif quantize and device == "cpu":
+        logger.warning("Quantization requested but no GPU available, skipping")
+
+    # Load model
+    if is_seq2seq:
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
+
+    # Move to device if not quantized (quantized models are already on device)
+    if not quantize or device == "cpu":
+        model = model.to(device)
+
+    logger.info(f"Model loaded on {device} with dtype {model_kwargs.get('torch_dtype')}")
+
+    # Apply LoRA if requested
+    if use_lora:
+        lora_cfg = lora_config_dict or {}
+        task_type = TaskType.SEQ_2_SEQ_LM if is_seq2seq else TaskType.CAUSAL_LM
+
+        # Default target modules based on model architecture
+        default_targets = {
+            "flan-t5-xl": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+            "flan-t5-large": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+            "flan-t5-base": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+            "flan-t5-small": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+            "bart-large": ["q_proj", "v_proj", "k_proj", "out_proj"],
+            "llama-3.1-8b": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+        }
+
+        lora_config = LoraConfig(
+            task_type=task_type,
+            r=lora_cfg.get("r", 16),
+            lora_alpha=lora_cfg.get("lora_alpha", 32),
+            lora_dropout=lora_cfg.get("lora_dropout", 0.05),
+            target_modules=lora_cfg.get("target_modules", default_targets.get(model_key, ["q", "v"])),
+            bias="none",
+        )
+
+        model = get_peft_model(model, lora_config)
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        total_params = sum(p.numel() for p in model.parameters())
+        logger.info(
+            f"LoRA applied: {trainable_params:,} trainable params / {total_params:,} total "
+            f"({100 * trainable_params / total_params:.2f}%)"
+        )
+
+    return model, tokenizer, is_seq2seq
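A loader sketch exercising the CPU path above (float32, no quantization) with a small LoRA override:

from src.model.base_model import load_model_and_tokenizer

model, tokenizer, is_seq2seq = load_model_and_tokenizer(
    "flan-t5-small",
    quantize=False,
    use_lora=True,
    lora_config_dict={"r": 8, "lora_alpha": 16},
)
assert is_seq2seq  # every flan-t5-* key maps to an encoder-decoder model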
src/model/generation_utils.py
ADDED
@@ -0,0 +1,106 @@
+"""
+Generation utilities for text correction.
+Handles beam search, constrained decoding, and post-generation cleanup.
+"""
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from typing import Dict, Optional, List
+from loguru import logger
+
+
+def generate_correction(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    generation_config: Dict,
+) -> str:
+    """Generate corrected text from input tokens."""
+    # Build generation kwargs from config
+    gen_kwargs = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "max_new_tokens": generation_config.get("max_new_tokens", 512),
+        "num_beams": generation_config.get("num_beams", 5),
+        "length_penalty": generation_config.get("length_penalty", 1.0),
+        "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
+        "min_length": generation_config.get("min_length", 10),
+        "early_stopping": generation_config.get("early_stopping", True),
+    }
+
+    # Optional sampling parameters
+    if generation_config.get("do_sample", False):
+        gen_kwargs["do_sample"] = True
+        gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
+        gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)
+    else:
+        gen_kwargs["do_sample"] = False
+
+    with torch.no_grad():
+        output_ids = model.generate(**gen_kwargs)
+
+    # Decode, skipping special tokens
+    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    return generated_text.strip()
+
+
+def batch_generate(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    texts: List[str],
+    generation_config: Dict,
+    max_length: int = 512,
+) -> List[str]:
+    """Generate corrections for a batch of texts."""
+    if not texts:
+        return []
+
+    results = []
+    # Process in mini-batches to manage memory on CPU
+    batch_size = generation_config.get("batch_size", 4)
+
+    for i in range(0, len(texts), batch_size):
+        batch_texts = texts[i:i + batch_size]
+
+        # Tokenise batch
+        inputs = tokenizer(
+            batch_texts,
+            max_length=max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        # Move to model device
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Generate
+        gen_kwargs = {
+            "max_new_tokens": generation_config.get("max_new_tokens", 512),
+            "num_beams": generation_config.get("num_beams", 5),
+            "length_penalty": generation_config.get("length_penalty", 1.0),
+            "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
+            "early_stopping": generation_config.get("early_stopping", True),
+        }
+
+        if generation_config.get("do_sample", False):
+            gen_kwargs["do_sample"] = True
+            gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                **gen_kwargs,
+            )
+
+        # Decode each output
+        for output in output_ids:
+            text = tokenizer.decode(output, skip_special_tokens=True)
+            results.append(text.strip())
+
+        logger.debug(f"Generated batch {i // batch_size + 1}: {len(batch_texts)} texts")
+
+    return results
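A sketch of beam-search correction over a small batch, using the loader from base_model.py; the informal "Correct:" instruction stands in for the full TASK_PREFIX that corrector.py prepends:

from src.model.base_model import load_model_and_tokenizer
from src.model.generation_utils import batch_generate

model, tokenizer, _ = load_model_and_tokenizer("flan-t5-small", use_lora=False)
model.eval()
gen_cfg = {"num_beams": 5, "max_new_tokens": 128, "batch_size": 2}
outputs = batch_generate(model, tokenizer,
                         ["Correct: fix teh speling", "Correct: the resalts was good"],
                         gen_cfg)
print(outputs)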
src/model/lora_adapter.py
ADDED
@@ -0,0 +1,54 @@
+"""
+LoRA adapter configuration and management.
+Wraps PEFT LoRA utilities for applying parameter-efficient
+fine-tuning to the base model.
+"""
+
+from peft import LoraConfig, TaskType, get_peft_model
+from typing import List, Optional
+from loguru import logger
+
+
+def create_lora_config(
+    task_type: TaskType,
+    r: int = 16,
+    lora_alpha: int = 32,
+    target_modules: Optional[List[str]] = None,
+    lora_dropout: float = 0.05,
+) -> LoraConfig:
+    """Create a LoRA configuration for the given task type."""
+    if target_modules is None:
+        target_modules = ["q", "v"]
+
+    config = LoraConfig(
+        task_type=task_type,
+        r=r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        target_modules=target_modules,
+        bias="none",
+        inference_mode=False,
+    )
+    logger.info(f"Created LoRA config: r={r}, alpha={lora_alpha}, dropout={lora_dropout}")
+    return config
+
+
+def apply_lora(model, lora_config: LoraConfig):
+    """Apply LoRA adapters to a model and return the wrapped model."""
+    peft_model = get_peft_model(model, lora_config)
+    trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
+    total = sum(p.numel() for p in peft_model.parameters())
+    logger.info(f"LoRA applied: {trainable:,}/{total:,} trainable params ({100*trainable/total:.2f}%)")
+    return peft_model
+
+
+def merge_lora_weights(model):
+    """Merge LoRA weights into the base model for inference.
+
+    After merging, the model behaves like a regular model with
+    LoRA modifications baked in, removing the adapter overhead.
+    """
+    logger.info("Merging LoRA weights into base model...")
+    merged = model.merge_and_unload()
+    logger.info("LoRA weights merged successfully")
+    return merged
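A sketch of the intended adapter lifecycle (configure, apply, fine-tune, merge); the training loop is elided:

from peft import TaskType
from transformers import AutoModelForSeq2SeqLM
from src.model.lora_adapter import create_lora_config, apply_lora, merge_lora_weights

base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
cfg = create_lora_config(TaskType.SEQ_2_SEQ_LM, r=16, target_modules=["q", "v"])
model = apply_lora(base, cfg)
# ... fine-tune `model` here ...
model = merge_lora_weights(model)  # adapter baked in; no LoRA overhead at inference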
src/model/style_conditioner.py
ADDED
@@ -0,0 +1,74 @@
+"""
+Injects the style vector into the model via soft prompt conditioning.
+The style vector is projected to the model's hidden dimension and
+prepended to the input token embeddings as virtual tokens.
+
+This technique is called "prefix tuning" / "style prefix injection".
+It biases the model's attention toward the desired output style
+without modifying the base model weights.
+
+For Flan-T5: injects into encoder input embeddings
+For BART: injects into encoder input embeddings
+For Llama: prepends to the full input context
+"""
+
+import torch
+import torch.nn as nn
+
+
+class StyleConditioner(nn.Module):
+    """
+    Projects a 512-dim style vector to n_prefix_tokens virtual tokens
+    in the model's embedding space.
+    """
+
+    def __init__(
+        self,
+        style_dim: int = 512,
+        model_hidden_dim: int = 512,  # T5-Small=512, Base=768, Large=1024, XL=2048
+        n_prefix_tokens: int = 10,  # Number of virtual prefix tokens
+    ):
+        super().__init__()
+        self.style_dim = style_dim
+        self.model_hidden_dim = model_hidden_dim
+        self.n_prefix_tokens = n_prefix_tokens
+
+        # Project style vector to prefix embeddings
+        # style_dim → n_prefix_tokens * model_hidden_dim
+        total_output_dim = n_prefix_tokens * model_hidden_dim
+        self.projection = nn.Sequential(
+            nn.Linear(style_dim, total_output_dim),
+            nn.Tanh(),
+        )
+
+    def forward(self, style_vector: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            style_vector: [batch_size, 512]
+        Returns:
+            prefix_embeddings: [batch_size, n_prefix_tokens, model_hidden_dim]
+        """
+        # Project: [batch, 512] → [batch, n_prefix * hidden_dim]
+        projected = self.projection(style_vector)
+
+        # Reshape: [batch, n_prefix * hidden_dim] → [batch, n_prefix, hidden_dim]
+        batch_size = style_vector.size(0)
+        prefix_embeddings = projected.view(batch_size, self.n_prefix_tokens, self.model_hidden_dim)
+
+        return prefix_embeddings
+
+
+def prepend_style_prefix(
+    input_embeddings: torch.Tensor,
+    style_prefix: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Concatenates style prefix to input embeddings along sequence dimension.
+
+    Args:
+        input_embeddings: [batch, seq_len, hidden_dim]
+        style_prefix: [batch, n_prefix, hidden_dim]
+    Returns:
+        [batch, n_prefix + seq_len, hidden_dim]
+    """
+    return torch.cat([style_prefix, input_embeddings], dim=1)
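A shape check for the conditioner above, at T5-Small dimensions:

import torch
from src.model.style_conditioner import StyleConditioner, prepend_style_prefix

cond = StyleConditioner(style_dim=512, model_hidden_dim=512, n_prefix_tokens=10)
style = torch.randn(2, 512)          # [batch, style_dim]
prefix = cond(style)                 # [2, 10, 512]
tokens = torch.randn(2, 30, 512)     # stand-in for encoder input embeddings
print(prepend_style_prefix(tokens, prefix).shape)  # torch.Size([2, 40, 512])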
src/preprocessing/__init__.py
ADDED
File without changes

src/preprocessing/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (183 Bytes).

src/preprocessing/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (185 Bytes).

src/preprocessing/__pycache__/dependency_parser.cpython-312.pyc
ADDED
Binary file (3.65 kB).

src/preprocessing/__pycache__/dyslexia_simulator.cpython-312.pyc
ADDED
Binary file (6.75 kB).

src/preprocessing/__pycache__/dyslexia_simulator.cpython-314.pyc
ADDED
Binary file (8.15 kB).

src/preprocessing/__pycache__/ner_tagger.cpython-312.pyc
ADDED
Binary file (2.7 kB).