morpheuslord committed on
Commit 12fd5f2 · verified · 1 Parent(s): 9dd64b9

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. scripts/__pycache__/evaluate.cpython-312.pyc +0 -0
  2. src/__pycache__/__init__.cpython-312.pyc +0 -0
  3. src/__pycache__/__init__.cpython-314.pyc +0 -0
  4. src/api/__init__.py +0 -0
  5. src/api/__pycache__/main.cpython-312.pyc +0 -0
  6. src/api/__pycache__/middleware.cpython-312.pyc +0 -0
  7. src/api/__pycache__/schemas.cpython-312.pyc +0 -0
  8. src/api/middleware.py +67 -0
  9. src/api/schemas.py +21 -0
  10. src/evaluation/__init__.py +0 -0
  11. src/evaluation/__pycache__/__init__.cpython-314.pyc +0 -0
  12. src/evaluation/__pycache__/authorship_verifier.cpython-312.pyc +0 -0
  13. src/evaluation/__pycache__/errant_evaluator.cpython-312.pyc +0 -0
  14. src/evaluation/__pycache__/gleu_scorer.cpython-312.pyc +0 -0
  15. src/evaluation/__pycache__/gleu_scorer.cpython-314.pyc +0 -0
  16. src/evaluation/__pycache__/style_metrics.cpython-312.pyc +0 -0
  17. src/evaluation/__pycache__/style_metrics.cpython-314.pyc +0 -0
  18. src/evaluation/authorship_verifier.py +50 -0
  19. src/evaluation/errant_evaluator.py +82 -0
  20. src/evaluation/gleu_scorer.py +68 -0
  21. src/evaluation/style_metrics.py +81 -0
  22. src/inference/__init__.py +0 -0
  23. src/inference/__pycache__/__init__.cpython-314.pyc +0 -0
  24. src/inference/__pycache__/corrector.cpython-312.pyc +0 -0
  25. src/inference/__pycache__/corrector.cpython-314.pyc +0 -0
  26. src/inference/__pycache__/postprocessor.cpython-312.pyc +0 -0
  27. src/inference/__pycache__/postprocessor.cpython-314.pyc +0 -0
  28. src/inference/corrector.py +283 -0
  29. src/inference/postprocessor.py +118 -0
  30. src/model/__init__.py +0 -0
  31. src/model/__pycache__/__init__.cpython-312.pyc +0 -0
  32. src/model/__pycache__/__init__.cpython-314.pyc +0 -0
  33. src/model/__pycache__/base_model.cpython-312.pyc +0 -0
  34. src/model/__pycache__/base_model.cpython-314.pyc +0 -0
  35. src/model/__pycache__/generation_utils.cpython-312.pyc +0 -0
  36. src/model/__pycache__/generation_utils.cpython-314.pyc +0 -0
  37. src/model/__pycache__/lora_adapter.cpython-312.pyc +0 -0
  38. src/model/__pycache__/style_conditioner.cpython-312.pyc +0 -0
  39. src/model/__pycache__/style_conditioner.cpython-314.pyc +0 -0
  40. src/model/base_model.py +135 -0
  41. src/model/generation_utils.py +106 -0
  42. src/model/lora_adapter.py +54 -0
  43. src/model/style_conditioner.py +74 -0
  44. src/preprocessing/__init__.py +0 -0
  45. src/preprocessing/__pycache__/__init__.cpython-312.pyc +0 -0
  46. src/preprocessing/__pycache__/__init__.cpython-314.pyc +0 -0
  47. src/preprocessing/__pycache__/dependency_parser.cpython-312.pyc +0 -0
  48. src/preprocessing/__pycache__/dyslexia_simulator.cpython-312.pyc +0 -0
  49. src/preprocessing/__pycache__/dyslexia_simulator.cpython-314.pyc +0 -0
  50. src/preprocessing/__pycache__/ner_tagger.cpython-312.pyc +0 -0
scripts/__pycache__/evaluate.cpython-312.pyc ADDED
Binary file (5.24 kB)

src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (169 Bytes)

src/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (171 Bytes)

src/api/__init__.py ADDED
File without changes
src/api/__pycache__/main.cpython-312.pyc ADDED
Binary file (3.59 kB)

src/api/__pycache__/middleware.cpython-312.pyc ADDED
Binary file (3.68 kB)

src/api/__pycache__/schemas.cpython-312.pyc ADDED
Binary file (1.38 kB)

src/api/middleware.py ADDED
@@ -0,0 +1,67 @@
+"""
+API middleware for request logging, rate limiting, and error handling.
+"""
+
+from fastapi import Request
+from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
+from loguru import logger
+import time
+from collections import defaultdict, deque
+
+
+class RequestLoggingMiddleware(BaseHTTPMiddleware):
+    """Logs all incoming requests with timing information."""
+
+    async def dispatch(self, request: Request, call_next):
+        start_time = time.time()
+        path = request.url.path
+        method = request.method
+
+        logger.info(f"→ {method} {path}")
+
+        try:
+            response = await call_next(request)
+        except Exception as e:
+            logger.error(f"✗ {method} {path} - Error: {e}")
+            raise
+
+        elapsed = (time.time() - start_time) * 1000  # ms
+        logger.info(f"← {method} {path} - {response.status_code} ({elapsed:.1f}ms)")
+
+        return response
+
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Simple in-memory rate limiting."""
+
+    def __init__(self, app, max_requests_per_minute: int = 60):
+        super().__init__(app)
+        self.max_requests = max_requests_per_minute
+        self.window = 60  # seconds
+        # Track requests per client IP: {ip: deque([timestamp, ...])}
+        self.requests: dict = defaultdict(deque)
+
+    async def dispatch(self, request: Request, call_next):
+        # Get client IP
+        client_ip = request.client.host if request.client else "unknown"
+        now = time.time()
+
+        # Clean old entries
+        timestamps = self.requests[client_ip]
+        while timestamps and timestamps[0] < now - self.window:
+            timestamps.popleft()
+
+        # Check rate limit
+        if len(timestamps) >= self.max_requests:
+            logger.warning(f"Rate limited: {client_ip} ({len(timestamps)} requests in {self.window}s)")
+            return JSONResponse(
+                status_code=429,
+                content={"detail": "Rate limit exceeded. Please wait before making more requests."},
+            )
+
+        # Record this request
+        timestamps.append(now)
+
+        response = await call_next(request)
+        return response
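
Note: a minimal wiring sketch (not part of this commit) showing how these two middlewares would typically be registered. The FastAPI app object presumably lives in src/api/main.py, which appears here only as a compiled .pyc, so this is an assumption:

from fastapi import FastAPI
from src.api.middleware import RequestLoggingMiddleware, RateLimitMiddleware

app = FastAPI()
# Starlette runs the most recently added middleware outermost, so the
# rate limiter below rejects over-limit requests before they are logged.
app.add_middleware(RequestLoggingMiddleware)
app.add_middleware(RateLimitMiddleware, max_requests_per_minute=60)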
src/api/schemas.py ADDED
@@ -0,0 +1,21 @@
+"""
+Pydantic schemas for API request/response validation.
+"""
+
+from pydantic import BaseModel, Field
+from typing import Optional, Dict
+
+
+class CorrectionRequest(BaseModel):
+    text: str = Field(..., min_length=10, max_length=5000, description="Raw dyslectic text to correct.")
+    master_copy: Optional[str] = Field(None, description="Optional master copy to match style toward.")
+    style_alpha: float = Field(0.6, ge=0.0, le=1.0, description="Weight given to user's own style (0=full master, 1=full user).")
+
+
+class CorrectionResponse(BaseModel):
+    original: str
+    corrected: str
+    style_similarity: float
+    awl_coverage: float
+    readability: Dict[str, float]
+    changes_summary: str
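
Note: a usage sketch (field values invented) showing the validation these schemas enforce:

from src.api.schemas import CorrectionRequest

req = CorrectionRequest(
    text="Ths is a smaple sentance that neads correcting.",
    style_alpha=0.6,
)
# Pydantic raises ValidationError at construction time if text is shorter
# than 10 characters or longer than 5000, or style_alpha falls outside [0.0, 1.0].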
src/evaluation/__init__.py ADDED
File without changes
src/evaluation/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (182 Bytes)

src/evaluation/__pycache__/authorship_verifier.cpython-312.pyc ADDED
Binary file (2.6 kB)

src/evaluation/__pycache__/errant_evaluator.cpython-312.pyc ADDED
Binary file (3.63 kB)

src/evaluation/__pycache__/gleu_scorer.cpython-312.pyc ADDED
Binary file (2.42 kB)

src/evaluation/__pycache__/gleu_scorer.cpython-314.pyc ADDED
Binary file (3.02 kB)

src/evaluation/__pycache__/style_metrics.cpython-312.pyc ADDED
Binary file (4.34 kB)

src/evaluation/__pycache__/style_metrics.cpython-314.pyc ADDED
Binary file (5.25 kB)
 
src/evaluation/authorship_verifier.py ADDED
@@ -0,0 +1,50 @@
+"""
+Authorship verification module.
+Uses a fine-tuned model to verify whether the corrected output
+could plausibly have been written by the same author as the input.
+Target: > 0.80 same-author probability.
+"""
+
+from typing import Tuple
+from loguru import logger
+import torch
+import torch.nn.functional as F
+
+
+class AuthorshipVerifier:
+    """Verifies authorship consistency between input and output text."""
+
+    def __init__(self, model_name: str = "roberta-base"):
+        try:
+            from sentence_transformers import SentenceTransformer
+            self.model = SentenceTransformer(model_name, device="cpu")
+            logger.info(f"AuthorshipVerifier loaded with {model_name}")
+        except Exception as e:
+            logger.warning(f"Failed to load authorship model: {e}")
+            self.model = None
+
+    def verify(self, text_a: str, text_b: str) -> float:
+        """Return probability that both texts were written by the same author.
+
+        Uses sentence embedding similarity as a proxy for authorship.
+        Higher cosine similarity suggests same author.
+        """
+        if self.model is None:
+            return 0.5  # Neutral score if model unavailable
+
+        if not text_a or not text_b:
+            return 0.5
+
+        try:
+            embeddings = self.model.encode([text_a, text_b], convert_to_tensor=True)
+            sim = F.cosine_similarity(
+                embeddings[0].unsqueeze(0),
+                embeddings[1].unsqueeze(0),
+            )
+            # Scale similarity to [0, 1] probability
+            # Cosine similarity is already in [-1, 1], shift to [0, 1]
+            prob = (sim.item() + 1.0) / 2.0
+            return prob
+        except Exception as e:
+            logger.warning(f"Authorship verification failed: {e}")
+            return 0.5
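
Note: a quick usage sketch (texts invented). The returned score is just rescaled cosine similarity, so 1.0 means identical embeddings and 0.5 means orthogonal ones:

from src.evaluation.authorship_verifier import AuthorshipVerifier

verifier = AuthorshipVerifier()
prob = verifier.verify(
    "i think the resaults show a clear trend",    # author's raw draft
    "I think the results show a clear trend.",    # corrected output
)
print(f"same-author probability: {prob:.2f}")     # target: > 0.80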
src/evaluation/errant_evaluator.py ADDED
@@ -0,0 +1,82 @@
+"""
+ERRANT-based grammatical error evaluation.
+Uses the ERRANT toolkit for standardised GEC evaluation with
+precision, recall, and F0.5 scores.
+"""
+
+from typing import List, Dict
+from loguru import logger
+
+
+class ERRANTEvaluator:
+    """Evaluates grammar correction quality using ERRANT annotations."""
+
+    def __init__(self):
+        try:
+            import errant
+            self.annotator = errant.load("en")
+            logger.info("ERRANT annotator loaded")
+        except Exception as e:
+            logger.warning(f"ERRANT failed to load: {e}. Evaluation will use fallback.")
+            self.annotator = None
+
+    def evaluate(
+        self,
+        sources: List[str],
+        predictions: List[str],
+        references: List[str],
+    ) -> Dict[str, float]:
+        """Compute ERRANT precision, recall, F0.5."""
+        if self.annotator is None:
+            logger.warning("ERRANT not available, returning zero scores")
+            return {"precision": 0.0, "recall": 0.0, "f0.5": 0.0}
+
+        tp = 0
+        fp = 0
+        fn = 0
+
+        for src, pred, ref in zip(sources, predictions, references):
+            try:
+                # Parse source and annotate edits
+                orig = self.annotator.parse(src)
+                cor_pred = self.annotator.parse(pred)
+                cor_ref = self.annotator.parse(ref)
+
+                # Get edit annotations
+                pred_edits = self.annotator.annotate(orig, cor_pred)
+                ref_edits = self.annotator.annotate(orig, cor_ref)
+
+                # Convert to comparable sets of (start, end, correction, type)
+                pred_set = set()
+                for e in pred_edits:
+                    pred_set.add((e.o_start, e.o_end, e.c_str, e.type))
+
+                ref_set = set()
+                for e in ref_edits:
+                    ref_set.add((e.o_start, e.o_end, e.c_str, e.type))
+
+                # Count TP, FP, FN
+                tp += len(pred_set & ref_set)
+                fp += len(pred_set - ref_set)
+                fn += len(ref_set - pred_set)
+
+            except Exception as e:
+                logger.debug(f"ERRANT annotation failed for a sample: {e}")
+                continue
+
+        # Compute precision, recall, F0.5
+        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+
+        # F0.5 weighs precision higher than recall (β=0.5)
+        beta = 0.5
+        if precision + recall > 0:
+            f_score = (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)
+        else:
+            f_score = 0.0
+
+        return {
+            "precision": precision,
+            "recall": recall,
+            "f0.5": f_score,
+        }
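
Note: a worked example of the F0.5 arithmetic above, with invented counts. Suppose a batch yields tp=6, fp=2, fn=4: precision = 6/8 = 0.75, recall = 6/10 = 0.60, and F0.5 = 1.25 * (0.75 * 0.60) / (0.25 * 0.75 + 0.60) ≈ 0.71, so precise edits are rewarded over aggressive ones. A call sketch (requires errant to be installed):

from src.evaluation.errant_evaluator import ERRANTEvaluator

evaluator = ERRANTEvaluator()
scores = evaluator.evaluate(
    sources=["She go to school yesterday."],
    predictions=["She went to school yesterday."],
    references=["She went to school yesterday."],
)
print(scores)  # {"precision": ..., "recall": ..., "f0.5": ...}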
src/evaluation/gleu_scorer.py ADDED
@@ -0,0 +1,68 @@
+"""
+GLEU (Generalized Language Evaluation Understanding) score.
+Preferred over BLEU for grammatical error correction tasks.
+Also computes BERTScore for semantic similarity evaluation.
+"""
+
+import sacrebleu
+from bert_score import score as bert_score_fn
+from typing import List, Tuple
+from loguru import logger
+
+
+class GLEUScorer:
+    """Computes GLEU and BERTScore metrics for GEC evaluation."""
+
+    def compute_gleu(
+        self,
+        predictions: List[str],
+        references: List[str],
+    ) -> float:
+        """Corpus-level GLEU score (0-100).
+
+        GLEU is the geometric mean of n-gram precisions and recall,
+        preferred over BLEU for GEC because it equally penalises
+        both under-correction and over-correction.
+        """
+        if not predictions or not references:
+            return 0.0
+
+        # sacrebleu expects references as a list of lists
+        refs = [references]
+
+        # Use BLEU with smoothing as GLEU approximation
+        # sacrebleu doesn't have a native GLEU, so we use smoothed BLEU
+        bleu = sacrebleu.corpus_bleu(
+            predictions,
+            refs,
+            smooth_method="exp",
+            smooth_value=0.1,
+        )
+        return bleu.score
+
+    def compute_bert_score(
+        self,
+        predictions: List[str],
+        references: List[str],
+        lang: str = "en",
+    ) -> Tuple[float, float, float]:
+        """Returns (precision, recall, F1) as averages over the batch."""
+        if not predictions or not references:
+            return (0.0, 0.0, 0.0)
+
+        try:
+            P, R, F1 = bert_score_fn(
+                predictions,
+                references,
+                lang=lang,
+                verbose=False,
+                device="cpu",  # CPU-optimised
+            )
+            return (
+                P.mean().item(),
+                R.mean().item(),
+                F1.mean().item(),
+            )
+        except Exception as e:
+            logger.warning(f"BERTScore computation failed: {e}")
+            return (0.0, 0.0, 0.0)
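
Note: a usage sketch. compute_bert_score downloads a scoring model on first use; an identical prediction/reference pair scores 100.0 under the smoothed-BLEU approximation of GLEU:

from src.evaluation.gleu_scorer import GLEUScorer

scorer = GLEUScorer()
gleu = scorer.compute_gleu(
    predictions=["She went to school yesterday."],
    references=["She went to school yesterday."],
)  # → 100.0
p, r, f1 = scorer.compute_bert_score(
    predictions=["She went to school yesterday."],
    references=["She goes to school every day."],
)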
src/evaluation/style_metrics.py ADDED
@@ -0,0 +1,81 @@
+"""
+Measures style preservation between input and output.
+
+Key metrics:
+- Style Vector Cosine Similarity (target: > 0.85)
+- AWL Coverage Score (target: > 0.25)
+- Authorship Verification Score (target: > 0.80)
+"""
+
+import torch
+import torch.nn.functional as F
+from typing import List, Tuple
+from ..style.fingerprinter import StyleFingerprinter
+from ..vocabulary.awl_loader import AWLLoader
+from loguru import logger
+import numpy as np
+
+
+class StyleEvaluator:
+    """Evaluates style preservation and academic vocabulary coverage."""
+
+    def __init__(self, fingerprinter: StyleFingerprinter, awl: AWLLoader):
+        self.fingerprinter = fingerprinter
+        self.awl = awl
+
+    def style_similarity(self, text_a: str, text_b: str) -> float:
+        """Cosine similarity between style vectors. Target: > 0.85."""
+        vec_a = self.fingerprinter.extract_vector(text_a)
+        vec_b = self.fingerprinter.extract_vector(text_b)
+
+        if vec_a.dim() == 1:
+            vec_a = vec_a.unsqueeze(0)
+        if vec_b.dim() == 1:
+            vec_b = vec_b.unsqueeze(0)
+
+        sim = F.cosine_similarity(vec_a, vec_b, dim=-1)
+        return sim.item()
+
+    def awl_coverage(self, text: str) -> float:
+        """Fraction of content words in AWL. Target: > 0.25."""
+        if not text or not text.strip():
+            return 0.0
+
+        words = text.lower().split()
+        # Filter to content words (longer than 3 chars, alphabetic)
+        content_words = [w for w in words if len(w) > 3 and w.isalpha()]
+
+        if not content_words:
+            return 0.0
+
+        awl_count = sum(1 for w in content_words if self.awl.is_academic(w))
+        return awl_count / len(content_words)
+
+    def evaluate_batch(
+        self,
+        inputs: List[str],
+        outputs: List[str],
+        references: List[str],
+    ) -> dict:
+        """Compute style and AWL metrics for a batch."""
+        style_sims = []
+        awl_coverages = []
+        ref_style_sims = []
+
+        for inp, out, ref in zip(inputs, outputs, references):
+            # Style similarity between input and output (preservation)
+            style_sims.append(self.style_similarity(inp, out))
+
+            # AWL coverage of output
+            awl_coverages.append(self.awl_coverage(out))
+
+            # Style similarity between output and reference
+            ref_style_sims.append(self.style_similarity(out, ref))
+
+        return {
+            "style_similarity_mean": float(np.mean(style_sims)),
+            "style_similarity_std": float(np.std(style_sims)),
+            "awl_coverage_mean": float(np.mean(awl_coverages)),
+            "awl_coverage_std": float(np.std(awl_coverages)),
+            "ref_style_similarity_mean": float(np.mean(ref_style_sims)),
+        }
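
Note: StyleFingerprinter and AWLLoader live in src/style/ and src/vocabulary/, which fall outside this 50-file view; the constructor arguments below are copied from how src/inference/corrector.py calls them, so treat this as a sketch:

from src.style.fingerprinter import StyleFingerprinter
from src.vocabulary.awl_loader import AWLLoader
from src.evaluation.style_metrics import StyleEvaluator

fingerprinter = StyleFingerprinter(
    spacy_model="en_core_web_sm",
    awl_path="data/awl/coxhead_awl.txt",
)
awl = AWLLoader(primary_path="data/awl/coxhead_awl.txt")
evaluator = StyleEvaluator(fingerprinter, awl)

sim = evaluator.style_similarity("my rough draft", "my corrected draft")  # target > 0.85
cov = evaluator.awl_coverage("the corrected academic output")             # target > 0.25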
src/inference/__init__.py ADDED
File without changes
src/inference/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (181 Bytes)

src/inference/__pycache__/corrector.cpython-312.pyc ADDED
Binary file (9.52 kB)

src/inference/__pycache__/corrector.cpython-314.pyc ADDED
Binary file (13.7 kB)

src/inference/__pycache__/postprocessor.cpython-312.pyc ADDED
Binary file (4.72 kB)

src/inference/__pycache__/postprocessor.cpython-314.pyc ADDED
Binary file (5.69 kB)
 
src/inference/corrector.py ADDED
@@ -0,0 +1,283 @@
+"""
+End-to-end inference pipeline.
+Accepts raw dyslectic text (and optionally a master copy),
+returns corrected academic text with metadata.
+"""
+
+from ..preprocessing.pipeline import PreprocessingPipeline
+from ..style.fingerprinter import StyleFingerprinter
+from ..vocabulary.lexical_substitution import LexicalElevator, RegisterFilter
+from ..model.base_model import load_model_and_tokenizer
+from ..model.style_conditioner import StyleConditioner, prepend_style_prefix
+from ..model.generation_utils import generate_correction
+from .postprocessor import PostProcessor
+from ..evaluation.style_metrics import StyleEvaluator
+from ..vocabulary.awl_loader import AWLLoader
+import torch
+from typing import Optional
+from dataclasses import dataclass
+from loguru import logger
+import yaml
+
+
+TASK_PREFIX = (
+    "Correct the following text for grammar, spelling, and clarity. "
+    "Maintain the author's original tone and writing style. "
+    "Elevate vocabulary to academic register. "
+    "Do NOT change the meaning or add new information. "
+    "Preserve named entities exactly. "
+    "Text to correct: "
+)
+
+
+@dataclass
+class CorrectionResult:
+    original: str
+    corrected: str
+    preprocessed: str
+    style_similarity: float
+    awl_coverage: float
+    readability: dict
+    changes_summary: str
+
+
+class AcademicCorrector:
+    """Full inference pipeline: preprocess → fingerprint → generate → elevate → filter."""
+
+    def __init__(self, config: dict):
+        logger.info("Initialising AcademicCorrector...")
+
+        model_cfg = config.get("model", {})
+        gen_cfg = config.get("generation", {})
+        vocab_cfg = config.get("vocabulary", {})
+        style_cfg = config.get("style_conditioner", {})
+
+        # 1. Load model and tokenizer
+        model_key = model_cfg.get("key", "flan-t5-small")
+        checkpoint = model_cfg.get("checkpoint_path", None)
+        use_lora = model_cfg.get("use_lora", False)
+
+        if checkpoint and use_lora:
+            # PEFT adapter checkpoint: load base model + apply adapter
+            import os
+            try:
+                from peft import PeftModel
+                logger.info(f"Loading base model '{model_key}' + PEFT adapter from '{checkpoint}'")
+                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
+                    model_key, quantize=False, use_lora=False
+                )
+                self.model = PeftModel.from_pretrained(self.model, checkpoint)
+                logger.info(f"PEFT adapter loaded from {checkpoint}")
+            except Exception as e:
+                logger.warning(f"PEFT loading failed ({e}), loading base model only")
+                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
+                    model_key, quantize=False, use_lora=False
+                )
+        elif checkpoint:
+            # Full model checkpoint (merged weights)
+            try:
+                from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+                self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+                self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+                self.is_seq2seq = True
+                logger.info(f"Loaded full model from checkpoint: {checkpoint}")
+            except Exception:
+                logger.warning(f"Checkpoint not found, loading base model: {model_key}")
+                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
+                    model_key, quantize=False, use_lora=False
+                )
+        else:
+            self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
+                model_key, quantize=False, use_lora=False
+            )
+
+        self.model.eval()
+        self.generation_config = gen_cfg
+
+        # 2. Preprocessor
+        self.preprocessor = PreprocessingPipeline()
+
+        # 3. Style fingerprinter
+        fp_cfg = config.get("fingerprinter", {})
+        self.fingerprinter = StyleFingerprinter(
+            spacy_model=fp_cfg.get("spacy_model", "en_core_web_sm"),
+            awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
+        )
+
+        # 4. Style conditioner: auto-detect hidden dim from loaded model
+        if hasattr(self.model.config, "d_model"):
+            auto_hidden_dim = self.model.config.d_model
+        elif hasattr(self.model.config, "hidden_size"):
+            auto_hidden_dim = self.model.config.hidden_size
+        else:
+            auto_hidden_dim = 512  # Safe default for T5-Small
+        logger.info(f"Auto-detected model hidden dim: {auto_hidden_dim}")
+
+        self.conditioner = StyleConditioner(
+            style_dim=style_cfg.get("style_dim", 512),
+            model_hidden_dim=style_cfg.get("model_hidden_dim", auto_hidden_dim),
+            n_prefix_tokens=style_cfg.get("n_prefix_tokens", 10),
+        )
+        self.conditioner.eval()
+
+        # 5. Vocabulary elevator
+        try:
+            self.elevator = LexicalElevator(
+                awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
+                spacy_model="en_core_web_sm",
+                mlm_model=vocab_cfg.get("mlm_model", "bert-large-uncased"),
+                sem_model=vocab_cfg.get("sem_model", "all-mpnet-base-v2"),
+            )
+        except Exception as e:
+            logger.warning(f"Lexical elevator init failed: {e}, elevation disabled")
+            self.elevator = None
+
+        # 6. Register filter
+        self.register_filter = RegisterFilter()
+
+        # 7. Post-processor
+        self.postprocessor = PostProcessor()
+
+        # 8. Evaluator
+        awl = AWLLoader(primary_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"))
+        self.evaluator = StyleEvaluator(self.fingerprinter, awl)
+
+        logger.info("AcademicCorrector initialised successfully")
+
+    def correct(
+        self,
+        raw_text: str,
+        master_copy: Optional[str] = None,
+        style_alpha: float = 0.6,
+    ) -> CorrectionResult:
+        """
+        Full correction pipeline:
+        1. Pre-process (spell correct + parse)
+        2. Style fingerprint
+        3. Generate with style conditioning
+        4. Post-process
+        5. Academic vocabulary elevation
+        6. Register filter
+        7. Compute quality metrics
+        """
+        # Step 1: Pre-process
+        logger.info("Step 1: Preprocessing...")
+        doc = self.preprocessor.process(raw_text)
+
+        # Step 2: Style fingerprint
+        logger.info("Step 2: Extracting style fingerprint...")
+        user_style = self.fingerprinter.extract_vector(doc.corrected_text)
+
+        if master_copy:
+            master_style = self.fingerprinter.extract_vector(master_copy)
+            target_style = self.fingerprinter.blend_vectors(user_style, master_style, alpha=style_alpha)
+        else:
+            target_style = user_style
+
+        # Step 3: Generate correction (sentence-chunked)
+        # The model was trained on max_input_length=128 tokens.
+        # Split text into sentence groups that fit within that window,
+        # process each chunk, then reassemble.
+        logger.info("Step 3: Generating correction (chunked)...")
+
+        MAX_INPUT_TOKENS = 128
+        # Measure how many tokens the task prefix uses
+        prefix_tokens = len(self.tokenizer.encode(TASK_PREFIX, add_special_tokens=False))
+        budget = MAX_INPUT_TOKENS - prefix_tokens - 2  # 2 for special tokens
+
+        # Split into sentences using spaCy (already loaded for fingerprinting)
+        sent_doc = self.fingerprinter.nlp(doc.corrected_text)
+        sentences = [sent.text.strip() for sent in sent_doc.sents if sent.text.strip()]
+
+        # Group sentences into chunks that fit the token budget
+        chunks = []
+        current_chunk = []
+        current_tokens = 0
+
+        for sent in sentences:
+            sent_tokens = len(self.tokenizer.encode(sent, add_special_tokens=False))
+            if current_tokens + sent_tokens > budget and current_chunk:
+                chunks.append(" ".join(current_chunk))
+                current_chunk = [sent]
+                current_tokens = sent_tokens
+            else:
+                current_chunk.append(sent)
+                current_tokens += sent_tokens
+
+        if current_chunk:
+            chunks.append(" ".join(current_chunk))
+
+        logger.info(f" Split into {len(chunks)} chunks from {len(sentences)} sentences")
+
+        corrected_chunks = []
+        device = next(self.model.parameters()).device
+
+        for i, chunk in enumerate(chunks):
+            chunk_input = TASK_PREFIX + chunk
+            inputs = self.tokenizer(
+                chunk_input,
+                max_length=MAX_INPUT_TOKENS,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            input_ids = inputs["input_ids"].to(device)
+            attention_mask = inputs["attention_mask"].to(device)
+
+            chunk_output = generate_correction(
+                self.model,
+                self.tokenizer,
+                input_ids,
+                attention_mask,
+                self.generation_config,
+            )
+            corrected_chunks.append(chunk_output)
+            logger.debug(f" Chunk {i+1}/{len(chunks)}: {len(chunk.split())} → {len(chunk_output.split())} words")
+
+        generated = " ".join(corrected_chunks)
+
+        # Step 4: Post-process
+        logger.info("Step 4: Post-processing...")
+        generated = self.postprocessor.clean(generated)
+        generated = self.postprocessor.restore_entities(
+            generated,
+            [e.text for e in doc.entities],
+            doc.protected_spans,
+        )
+
+        # Step 5: Vocabulary elevation
+        logger.info("Step 5: Vocabulary elevation...")
+        if self.elevator:
+            try:
+                generated = self.elevator.elevate(generated, doc.protected_spans)
+            except Exception as e:
+                logger.warning(f"Vocabulary elevation failed: {e}")
+
+        # Step 6: Register filter
+        logger.info("Step 6: Register filtering...")
+        generated = self.register_filter.apply(generated)
+
+        # Final formatting
+        generated = self.postprocessor.format_output(generated)
+
+        # Step 7: Compute quality metrics
+        logger.info("Step 7: Computing metrics...")
+        style_sim = self.evaluator.style_similarity(raw_text, generated)
+        awl_cov = self.evaluator.awl_coverage(generated)
+
+        # Build changes summary
+        changes = []
+        if doc.original_text != doc.corrected_text:
+            changes.append("Spelling/grammar corrections applied")
+        if generated != doc.corrected_text:
+            changes.append("Text restructured and elevated")
+        changes_summary = "; ".join(changes) if changes else "No changes needed"
+
+        return CorrectionResult(
+            original=raw_text,
+            corrected=generated,
+            preprocessed=doc.corrected_text,
+            style_similarity=style_sim,
+            awl_coverage=awl_cov,
+            readability=doc.readability,
+            changes_summary=changes_summary,
+        )
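
Note: an end-to-end sketch. The config keys mirror what __init__ reads (model, generation, vocabulary, style_conditioner, fingerprinter); the YAML filename is hypothetical:

import yaml
from src.inference.corrector import AcademicCorrector

with open("configs/inference.yaml") as f:  # hypothetical path
    config = yaml.safe_load(f)

corrector = AcademicCorrector(config)
result = corrector.correct("Ths is my frist draft abuot the experimant.")
print(result.corrected)
print(result.style_similarity, result.awl_coverage, result.changes_summary)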
src/inference/postprocessor.py ADDED
@@ -0,0 +1,118 @@
+"""
+Post-processing utilities for generated text.
+Handles cleanup, formatting, and final quality checks.
+"""
+
+import re
+from typing import List, Tuple
+from loguru import logger
+
+
+class PostProcessor:
+    """Cleans and formats generated text after model output."""
+
+    # Common generation artifacts to remove
+    ARTIFACTS = [
+        r'<pad>',
+        r'</s>',
+        r'<s>',
+        r'<unk>',
+        r'\[PAD\]',
+        r'\[CLS\]',
+        r'\[SEP\]',
+        r'<\|endoftext\|>',
+    ]
+
+    def __init__(self):
+        # Compile artifact removal regex
+        self._artifact_pattern = re.compile(
+            '|'.join(re.escape(a) if not a.startswith('\\') else a for a in self.ARTIFACTS),
+            re.IGNORECASE
+        )
+
+    def clean(self, text: str) -> str:
+        """Remove generation artifacts and normalise whitespace."""
+        if not text:
+            return ""
+
+        # Remove generation artifacts
+        result = self._artifact_pattern.sub('', text)
+
+        # Replace em dashes and en dashes with commas
+        result = result.replace('—', ',')
+        result = result.replace('–', ',')
+
+        # Normalise whitespace
+        result = re.sub(r'\s+', ' ', result)
+        result = result.strip()
+
+        # Fix common post-generation spacing issues
+        result = re.sub(r'\s+([.,!?;:])', r'\1', result)  # Remove space before punctuation
+        result = re.sub(r'([.,!?;:])([A-Za-z])', r'\1 \2', result)  # Add space after punctuation
+        result = re.sub(r'\(\s+', '(', result)  # Remove space after opening paren
+        result = re.sub(r'\s+\)', ')', result)  # Remove space before closing paren
+
+        # Fix multiple punctuation
+        result = re.sub(r'\.{2,}', '.', result)
+        result = re.sub(r'\?{2,}', '?', result)
+        result = re.sub(r'!{2,}', '!', result)
+
+        return result
+
+    def restore_entities(
+        self,
+        text: str,
+        original_entities: List[str],
+        protected_spans: List[Tuple[int, int]],
+    ) -> str:
+        """Restore named entities that may have been altered during generation.
+
+        Uses case-insensitive matching to find entities whose form was
+        altered in the generated text and restores the original form.
+        """
+        if not original_entities:
+            return text
+
+        result = text
+        for entity in original_entities:
+            # Check if entity is already present in correct form
+            if entity in result:
+                continue
+
+            # Try case-insensitive match
+            pattern = re.compile(re.escape(entity), re.IGNORECASE)
+            if pattern.search(result):
+                result = pattern.sub(entity, result, count=1)
+                logger.debug(f"Restored entity: {entity}")
+
+        return result
+
+    def format_output(self, text: str) -> str:
+        """Apply final formatting (capitalisation, punctuation, spacing)."""
+        if not text:
+            return ""
+
+        result = text.strip()
+
+        # Ensure first letter is capitalised
+        if result and result[0].islower():
+            result = result[0].upper() + result[1:]
+
+        # Ensure text ends with punctuation
+        if result and result[-1] not in '.!?':
+            result += '.'
+
+        # Capitalise after sentence-ending punctuation
+        result = re.sub(
+            r'([.!?]\s+)([a-z])',
+            lambda m: m.group(1) + m.group(2).upper(),
+            result
+        )
+
+        # Fix "i" → "I" when standalone
+        result = re.sub(r'\bi\b', 'I', result)
+
+        # Remove trailing whitespace from lines
+        result = '\n'.join(line.rstrip() for line in result.split('\n'))
+
+        return result
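
Note: a small trace (input invented) of what clean() followed by format_output() does:

from src.inference.postprocessor import PostProcessor

pp = PostProcessor()
raw = "<pad> this is grate — realy .. i think ( honestly )</s>"
print(pp.format_output(pp.clean(raw)))
# → "This is grate, realy. I think (honestly)."
# Artifacts stripped, dashes become commas, spacing and casing repaired;
# actual spelling errors ("grate", "realy") are left for the model to fix.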
src/model/__init__.py ADDED
File without changes
src/model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (175 Bytes)

src/model/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (177 Bytes)

src/model/__pycache__/base_model.cpython-312.pyc ADDED
Binary file (5.88 kB)

src/model/__pycache__/base_model.cpython-314.pyc ADDED
Binary file (6.22 kB)

src/model/__pycache__/generation_utils.cpython-312.pyc ADDED
Binary file (4.25 kB)

src/model/__pycache__/generation_utils.cpython-314.pyc ADDED
Binary file (4.81 kB)

src/model/__pycache__/lora_adapter.cpython-312.pyc ADDED
Binary file (2.9 kB)

src/model/__pycache__/style_conditioner.cpython-312.pyc ADDED
Binary file (3.09 kB)

src/model/__pycache__/style_conditioner.cpython-314.pyc ADDED
Binary file (3.61 kB)
 
src/model/base_model.py ADDED
@@ -0,0 +1,135 @@
+"""
+Loads and wraps the base pretrained model.
+Supported architectures:
+- google/flan-t5-xl (recommended, 3B)
+- google/flan-t5-large (780M, resource-constrained)
+- facebook/bart-large (400M, excellent denoiser)
+- meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality)
+"""
+
+from transformers import (
+    AutoTokenizer, AutoModelForSeq2SeqLM,
+    AutoModelForCausalLM, BitsAndBytesConfig
+)
+from peft import get_peft_model, LoraConfig, TaskType
+import torch
+from loguru import logger
+
+
+ENCODER_DECODER_MODELS = {
+    "flan-t5-xl": "google/flan-t5-xl",
+    "flan-t5-large": "google/flan-t5-large",
+    "flan-t5-base": "google/flan-t5-base",
+    "flan-t5-small": "google/flan-t5-small",
+    "bart-large": "facebook/bart-large",
+}
+
+DECODER_ONLY_MODELS = {
+    "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+}
+
+
+def load_model_and_tokenizer(model_key: str, quantize: bool = False, use_lora: bool = True,
+                             lora_config_dict: dict = None):
+    """
+    Load a pretrained model with optional LoRA and quantization.
+
+    Args:
+        model_key: Key from ENCODER_DECODER_MODELS or DECODER_ONLY_MODELS
+        quantize: Whether to use 4-bit quantization
+        use_lora: Whether to apply LoRA adapters
+        lora_config_dict: Optional dict with LoRA hyperparams (r, lora_alpha, etc.)
+
+    Returns:
+        Tuple of (model, tokenizer, is_seq2seq)
+    """
+    # Determine model type and HuggingFace identifier
+    is_seq2seq = model_key in ENCODER_DECODER_MODELS
+    is_causal = model_key in DECODER_ONLY_MODELS
+
+    if not is_seq2seq and not is_causal:
+        raise ValueError(
+            f"Unknown model key: '{model_key}'. "
+            f"Available: {list(ENCODER_DECODER_MODELS.keys()) + list(DECODER_ONLY_MODELS.keys())}"
+        )
+
+    model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key)
+    logger.info(f"Loading model: {model_name} (seq2seq={is_seq2seq}, quantize={quantize}, lora={use_lora})")
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Configure quantization if requested
+    model_kwargs = {
+        "torch_dtype": torch.float32,  # CPU-optimised: use float32 for stability
+    }
+
+    # Detect device
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda"
+        # Use bfloat16 if Ampere+, else float16
+        if torch.cuda.get_device_capability()[0] >= 8:
+            model_kwargs["torch_dtype"] = torch.bfloat16
+        else:
+            model_kwargs["torch_dtype"] = torch.float16
+
+    if quantize and device == "cuda":
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
+            bnb_4bit_use_double_quant=True,
+        )
+        model_kwargs["quantization_config"] = bnb_config
+        logger.info("Using 4-bit NF4 quantization")
+    elif quantize and device == "cpu":
+        logger.warning("Quantization requested but no GPU available, skipping")
+
+    # Load model
+    if is_seq2seq:
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
+
+    # Move to device if not quantized (quantized models are already on device)
+    if not quantize or device == "cpu":
+        model = model.to(device)
+
+    logger.info(f"Model loaded on {device} with dtype {model_kwargs.get('torch_dtype')}")
+
+    # Apply LoRA if requested
+    if use_lora:
+        lora_cfg = lora_config_dict or {}
+        task_type = TaskType.SEQ_2_SEQ_LM if is_seq2seq else TaskType.CAUSAL_LM
+
+        # Default target modules based on model architecture
+        default_targets = {
+            "flan-t5-xl": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+            "flan-t5-large": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+            "flan-t5-base": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+            "flan-t5-small": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+            "bart-large": ["q_proj", "v_proj", "k_proj", "out_proj"],
+            "llama-3.1-8b": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+        }
+
+        lora_config = LoraConfig(
+            task_type=task_type,
+            r=lora_cfg.get("r", 16),
+            lora_alpha=lora_cfg.get("lora_alpha", 32),
+            lora_dropout=lora_cfg.get("lora_dropout", 0.05),
+            target_modules=lora_cfg.get("target_modules", default_targets.get(model_key, ["q", "v"])),
+            bias="none",
+        )
+
+        model = get_peft_model(model, lora_config)
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        total_params = sum(p.numel() for p in model.parameters())
+        logger.info(
+            f"LoRA applied: {trainable_params:,} trainable params / {total_params:,} total "
+            f"({100 * trainable_params / total_params:.2f}%)"
+        )
+
+    return model, tokenizer, is_seq2seq
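
Note: a loading sketch; flan-t5-small is the lightest key in ENCODER_DECODER_MODELS and runs on CPU when no GPU is detected:

from src.model.base_model import load_model_and_tokenizer

model, tokenizer, is_seq2seq = load_model_and_tokenizer(
    "flan-t5-small",
    quantize=False,          # 4-bit NF4 only applies on CUDA anyway
    use_lora=True,
    lora_config_dict={"r": 16, "lora_alpha": 32},
)
assert is_seq2seq  # T5 variants are encoder-decoder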
src/model/generation_utils.py ADDED
@@ -0,0 +1,106 @@
+"""
+Generation utilities for text correction.
+Handles beam search, constrained decoding, and post-generation cleanup.
+"""
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from typing import Dict, Optional, List
+from loguru import logger
+
+
+def generate_correction(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    generation_config: Dict,
+) -> str:
+    """Generate corrected text from input tokens."""
+    # Build generation kwargs from config
+    gen_kwargs = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "max_new_tokens": generation_config.get("max_new_tokens", 512),
+        "num_beams": generation_config.get("num_beams", 5),
+        "length_penalty": generation_config.get("length_penalty", 1.0),
+        "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
+        "min_length": generation_config.get("min_length", 10),
+        "early_stopping": generation_config.get("early_stopping", True),
+    }
+
+    # Optional sampling parameters
+    if generation_config.get("do_sample", False):
+        gen_kwargs["do_sample"] = True
+        gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
+        gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)
+    else:
+        gen_kwargs["do_sample"] = False
+
+    with torch.no_grad():
+        output_ids = model.generate(**gen_kwargs)
+
+    # Decode, skipping special tokens
+    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    return generated_text.strip()
+
+
+def batch_generate(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    texts: List[str],
+    generation_config: Dict,
+    max_length: int = 512,
+) -> List[str]:
+    """Generate corrections for a batch of texts."""
+    if not texts:
+        return []
+
+    results = []
+    # Process in mini-batches to manage memory on CPU
+    batch_size = generation_config.get("batch_size", 4)
+
+    for i in range(0, len(texts), batch_size):
+        batch_texts = texts[i:i + batch_size]
+
+        # Tokenise batch
+        inputs = tokenizer(
+            batch_texts,
+            max_length=max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        # Move to model device
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Generate
+        gen_kwargs = {
+            "max_new_tokens": generation_config.get("max_new_tokens", 512),
+            "num_beams": generation_config.get("num_beams", 5),
+            "length_penalty": generation_config.get("length_penalty", 1.0),
+            "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
+            "early_stopping": generation_config.get("early_stopping", True),
+        }
+
+        if generation_config.get("do_sample", False):
+            gen_kwargs["do_sample"] = True
+            gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                **gen_kwargs,
+            )
+
+        # Decode each output
+        for output in output_ids:
+            text = tokenizer.decode(output, skip_special_tokens=True)
+            results.append(text.strip())
+
+        logger.debug(f"Generated batch {i // batch_size + 1}: {len(batch_texts)} texts")
+
+    return results
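
Note: a batch-generation sketch, reusing the model and tokenizer from the loading sketch above (config values are illustrative; in the real pipeline, corrector.py prepends TASK_PREFIX to each input):

from src.model.generation_utils import batch_generate

gen_cfg = {"max_new_tokens": 128, "num_beams": 4, "batch_size": 2}
outputs = batch_generate(
    model,
    tokenizer,
    ["she dont like it", "i has a apple"],
    gen_cfg,
)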
src/model/lora_adapter.py ADDED
@@ -0,0 +1,54 @@
+"""
+LoRA adapter configuration and management.
+Wraps PEFT LoRA utilities for applying parameter-efficient
+fine-tuning to the base model.
+"""
+
+from peft import LoraConfig, TaskType, get_peft_model
+from typing import List, Optional
+from loguru import logger
+
+
+def create_lora_config(
+    task_type: TaskType,
+    r: int = 16,
+    lora_alpha: int = 32,
+    target_modules: Optional[List[str]] = None,
+    lora_dropout: float = 0.05,
+) -> LoraConfig:
+    """Create a LoRA configuration for the given task type."""
+    if target_modules is None:
+        target_modules = ["q", "v"]
+
+    config = LoraConfig(
+        task_type=task_type,
+        r=r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        target_modules=target_modules,
+        bias="none",
+        inference_mode=False,
+    )
+    logger.info(f"Created LoRA config: r={r}, alpha={lora_alpha}, dropout={lora_dropout}")
+    return config
+
+
+def apply_lora(model, lora_config: LoraConfig):
+    """Apply LoRA adapters to a model and return the wrapped model."""
+    peft_model = get_peft_model(model, lora_config)
+    trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
+    total = sum(p.numel() for p in peft_model.parameters())
+    logger.info(f"LoRA applied: {trainable:,}/{total:,} trainable params ({100*trainable/total:.2f}%)")
+    return peft_model
+
+
+def merge_lora_weights(model):
+    """Merge LoRA weights into the base model for inference.
+
+    After merging, the model behaves like a regular model with
+    LoRA modifications baked in, removing the adapter overhead.
+    """
+    logger.info("Merging LoRA weights into base model...")
+    merged = model.merge_and_unload()
+    logger.info("LoRA weights merged successfully")
+    return merged
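
Note: the intended lifecycle as a sketch: build a config, wrap the model for training, then merge for deployment.

from peft import TaskType
from src.model.lora_adapter import create_lora_config, apply_lora, merge_lora_weights

# model: a base HF model, e.g. from load_model_and_tokenizer("flan-t5-small", use_lora=False)
config = create_lora_config(TaskType.SEQ_2_SEQ_LM, r=16, lora_alpha=32)
peft_model = apply_lora(model, config)
# ... fine-tune peft_model ...
deployable = merge_lora_weights(peft_model)  # adapters folded into the base weights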
src/model/style_conditioner.py ADDED
@@ -0,0 +1,74 @@
+"""
+Injects the style vector into the model via soft prompt conditioning.
+The style vector is projected to the model's hidden dimension and
+prepended to the input token embeddings as virtual tokens.
+
+This technique is called "prefix tuning" / "style prefix injection".
+It biases the model's attention toward the desired output style
+without modifying the base model weights.
+
+For Flan-T5: injects into encoder input embeddings
+For BART: injects into encoder input embeddings
+For Llama: prepends to the full input context
+"""
+
+import torch
+import torch.nn as nn
+
+
+class StyleConditioner(nn.Module):
+    """
+    Projects a 512-dim style vector to n_prefix_tokens virtual tokens
+    in the model's embedding space.
+    """
+
+    def __init__(
+        self,
+        style_dim: int = 512,
+        model_hidden_dim: int = 512,  # T5-Small=512, Base=768, Large=1024, XL=2048
+        n_prefix_tokens: int = 10,  # Number of virtual prefix tokens
+    ):
+        super().__init__()
+        self.style_dim = style_dim
+        self.model_hidden_dim = model_hidden_dim
+        self.n_prefix_tokens = n_prefix_tokens
+
+        # Project style vector to prefix embeddings:
+        # style_dim → n_prefix_tokens * model_hidden_dim
+        total_output_dim = n_prefix_tokens * model_hidden_dim
+        self.projection = nn.Sequential(
+            nn.Linear(style_dim, total_output_dim),
+            nn.Tanh(),
+        )
+
+    def forward(self, style_vector: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            style_vector: [batch_size, 512]
+        Returns:
+            prefix_embeddings: [batch_size, n_prefix_tokens, model_hidden_dim]
+        """
+        # Project: [batch, 512] → [batch, n_prefix * hidden_dim]
+        projected = self.projection(style_vector)
+
+        # Reshape: [batch, n_prefix * hidden_dim] → [batch, n_prefix, hidden_dim]
+        batch_size = style_vector.size(0)
+        prefix_embeddings = projected.view(batch_size, self.n_prefix_tokens, self.model_hidden_dim)
+
+        return prefix_embeddings
+
+
+def prepend_style_prefix(
+    input_embeddings: torch.Tensor,
+    style_prefix: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Concatenates style prefix to input embeddings along sequence dimension.
+
+    Args:
+        input_embeddings: [batch, seq_len, hidden_dim]
+        style_prefix: [batch, n_prefix, hidden_dim]
+    Returns:
+        [batch, n_prefix + seq_len, hidden_dim]
+    """
+    return torch.cat([style_prefix, input_embeddings], dim=1)
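
Note: a shape-check sketch with random tensors standing in for a real fingerprint and real encoder embeddings:

import torch
from src.model.style_conditioner import StyleConditioner, prepend_style_prefix

conditioner = StyleConditioner(style_dim=512, model_hidden_dim=512, n_prefix_tokens=10)
style_vec = torch.randn(1, 512)          # stand-in for a style fingerprint
prefix = conditioner(style_vec)          # → [1, 10, 512]
token_embeds = torch.randn(1, 20, 512)   # stand-in for encoder input embeddings
conditioned = prepend_style_prefix(token_embeds, prefix)
assert conditioned.shape == (1, 30, 512)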
src/preprocessing/__init__.py ADDED
File without changes
src/preprocessing/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (183 Bytes)

src/preprocessing/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (185 Bytes)

src/preprocessing/__pycache__/dependency_parser.cpython-312.pyc ADDED
Binary file (3.65 kB)

src/preprocessing/__pycache__/dyslexia_simulator.cpython-312.pyc ADDED
Binary file (6.75 kB)

src/preprocessing/__pycache__/dyslexia_simulator.cpython-314.pyc ADDED
Binary file (8.15 kB)

src/preprocessing/__pycache__/ner_tagger.cpython-312.pyc ADDED
Binary file (2.7 kB)