KalanaPabasara committed
Commit f6f45d5 · 0 Parent(s)

SinCode v3 — ByT5 seq2seq + XLM-RoBERTa MLM reranker
.gitignore ADDED
@@ -0,0 +1,28 @@
+ # Virtual environment
+ .venv/
+
+ # Model weights — hosted on Hugging Face Hub
+ seq2seq/byt5-singlish-sinhala/
+ seq2seq/tokenized_cache/
+
+ # Large training data
+ seq2seq/wsd_pairs.csv
+
+ # Backup weights
+ *.safetensors.bak
+ *_backup.safetensors
+ final_pre_correction_backup.safetensors
+
+ # Python cache
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+
+ # Evaluation outputs (optional — remove if you want these tracked)
+ misc/v3_results_110.csv
+
+ # Misc
+ *.log
+ .DS_Store
app.py ADDED
@@ -0,0 +1,75 @@
+ """
+ SinCode v3 — Streamlit demo UI.
+
+ Architecture: ByT5-small (seq2seq candidate generation) +
+ XLM-RoBERTa (MLM contextual reranking)
+
+ Two transliteration modes:
+   • Code-Mixed — ByT5 + MLM; retains English words where contextually apt
+   • Full Sinhala — mBart50 sentence-level; transliterates everything to Sinhala
+ """
+
+ import streamlit as st
+ from sincode_model import BeamSearchDecoder, SentenceTransliterator
+
+ st.set_page_config(page_title="සිංCode v3", page_icon="🇱🇰", layout="centered")
+
+ st.title("සිංCode v3")
+ st.caption("ByT5 seq2seq + XLM-RoBERTa MLM reranking")
+
+
+ @st.cache_resource(show_spinner="Loading models (ByT5 + XLM-RoBERTa)…")
+ def load_decoder() -> BeamSearchDecoder:
+     return BeamSearchDecoder()
+
+
+ @st.cache_resource(show_spinner="Loading mBart50 model…")
+ def load_transliterator() -> SentenceTransliterator:
+     return SentenceTransliterator()
+
+
+ mode = st.radio(
+     "Transliteration mode",
+     options=["Code-Mixed Output", "Full Sinhala Output"],
+     horizontal=True,
+     help=(
+         "**Code-Mixed**: keeps English technical/borrowed words where natural "
+         "(e.g. *buffer*, *bit rate*). "
+         "**Full Sinhala**: transliterates every word to Sinhala script "
+         "(e.g. *business* → ව්‍යාපාරය)."
+     ),
+ )
+
+ sentence = st.text_input(
+     "Enter Singlish sentence",
+     placeholder="e.g. mema videowe bit rate eka godak wadi nisa buffer wenawa",
+ )
+
+ show_trace = st.checkbox(
+     "Show step-by-step trace",
+     value=False,
+     disabled=(mode == "Full Sinhala Output"),
+     help="Trace is only available in Code-Mixed mode.",
+ )
+
+ if st.button("Transliterate", type="primary") and sentence.strip():
+     if mode == "Full Sinhala Output":
+         with st.spinner("Transliterating (mBart50)…"):
+             transliterator = load_transliterator()
+             result = transliterator.transliterate(sentence.strip())
+
+         st.markdown("### Result")
+         st.success(result)
+
+     else:
+         with st.spinner("Transliterating…"):
+             decoder = load_decoder()
+             result, trace_logs = decoder.decode(sentence.strip())
+
+         st.markdown("### Result")
+         st.success(result)
+
+         if show_trace:
+             st.markdown("### Trace")
+             for log in trace_logs:
+                 st.markdown(log)
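The Streamlit layer is a thin wrapper over two entry points. A minimal programmatic sketch of the same calls, assuming the repo root is the working directory so that `sincode_model` (imported by app.py but not shown in this commit) resolves, and that the Hub checkpoints can be downloaded:

```python
# Sketch only: mirrors what app.py does, without Streamlit.
from sincode_model import BeamSearchDecoder, SentenceTransliterator

decoder = BeamSearchDecoder()                 # Code-Mixed: ByT5 + MLM reranker
result, trace = decoder.decode("results blnna one")
print(result)       # reference in dataset_110.csv: "results බලන්න ඕනෙ"
for step in trace:
    print(step)     # per-word candidate scores, as in the UI trace

full = SentenceTransliterator()               # Full Sinhala: mBart50
print(full.transliterate("api heta yamu"))
```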
architecture.html ADDED
@@ -0,0 +1,111 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+   <meta charset="UTF-8" />
+   <title>SinCode v3 — Architecture</title>
+   <script src="https://cdn.jsdelivr.net/npm/mermaid@11/dist/mermaid.min.js"></script>
+   <style>
+     body {
+       font-family: sans-serif;
+       background: #f8f9fa;
+       display: flex;
+       flex-direction: column;
+       align-items: center;
+       padding: 2rem;
+     }
+     h1 { color: #2c3e50; margin-bottom: 0.25rem; }
+     p { color: #666; margin-top: 0; margin-bottom: 2rem; }
+     .mermaid {
+       background: white;
+       border-radius: 12px;
+       padding: 2rem;
+       box-shadow: 0 2px 12px rgba(0,0,0,0.08);
+       max-width: 1200px;
+       width: 100%;
+     }
+   </style>
+ </head>
+ <body>
+   <h1>SinCode v3 — System Architecture</h1>
+   <p>ByT5-small · XLM-RoBERTa · mBart50-large</p>
+
+   <div class="mermaid">
+     flowchart TD
+       UI["🖥️ Streamlit UI\napp.py"]
+       MODE{Mode?}
+
+       UI --> MODE
+
+       subgraph MODE_FULL["Full Sinhala Mode"]
+         direction TB
+         ST["SentenceTransliterator\nseq2seq/mbart_infer.py"]
+         MBART["mBart50-large\nKalana001/mbart50-large-singlish-sinhala\nHF Hub · 2.4 GB"]
+         FIX["Compose Fix Map\nseq2seq/compose_fix_map.json\nZWJ / Virama corrections"]
+         ST --> MBART
+         MBART -->|"raw Sinhala output"| FIX
+       end
+
+       subgraph MODE_MIXED["Code-Mixed Mode"]
+         direction TB
+
+         subgraph PHASE1["Phase 1 · Word Classification"]
+           direction LR
+           P1A["Sinhala script?\n(U+0D80–0DFF)"]
+           P1B["English vocab?\nenglish_20k.txt"]
+           P1C["Singlish\n(everything else)"]
+         end
+
+         subgraph PHASE2["Phase 2 · Candidate Generation (single ByT5 batch)"]
+           direction LR
+           BYT5["ByT5-small\nKalana001/byt5-small-singlish-sinhala\nHF Hub · 1.2 GB\nbeam=5 → top-5 candidates"]
+           SIN_PASS["Single candidate\n(word as-is)"]
+           ENG_CAND["English word\n+ ByT5 Sinhala alternatives"]
+           SIN_CAND["Top-5 ByT5\ncandidates"]
+         end
+
+         subgraph PHASE3["Phase 3 · Two-Pass MLM Reranking"]
+           direction LR
+           GREEDY["Pass 1 – Greedy\nBuild draft sentence\n(stale right context)"]
+           RESCORE["Pass 2 – Rescore\nActual decoded output\nas right context"]
+           MLM["XLM-RoBERTa\nKalana001/xlm-roberta-base-finetuned-sinhala\nHF Hub\nMulti-mask log-probability"]
+           SOFTMAX["Softmax normalise\npick argmax"]
+         end
+
+         PHASE1 --> PHASE2
+         P1A -->|Sinhala| SIN_PASS
+         P1B -->|English| ENG_CAND
+         P1C -->|Singlish| SIN_CAND
+         BYT5 --> ENG_CAND
+         BYT5 --> SIN_CAND
+         PHASE2 --> PHASE3
+         GREEDY --> MLM
+         MLM --> SOFTMAX
+         SOFTMAX --> RESCORE
+         RESCORE --> MLM
+       end
+
+       MODE -->|"Full Sinhala Output"| MODE_FULL
+       MODE -->|"Code-Mixed Output"| MODE_MIXED
+
+       MODE_FULL --> OUT["✅ Sinhala Output"]
+       MODE_MIXED --> OUT
+
+       subgraph MODELS["Models on Hugging Face Hub (Kalana001)"]
+         HF1["byt5-small-singlish-sinhala\n1.2 GB · ByT5-small"]
+         HF2["xlm-roberta-base-finetuned-sinhala\nXLM-RoBERTa"]
+         HF3["mbart50-large-singlish-sinhala\n2.4 GB · mBart50-large"]
+       end
+
+       style MODE_FULL fill:#e8f4fd,stroke:#4a9eda
+       style MODE_MIXED fill:#fdf3e8,stroke:#e8974a
+       style PHASE1 fill:#fff9e6,stroke:#cca800
+       style PHASE2 fill:#e8fff0,stroke:#2ecc71
+       style PHASE3 fill:#f4e8ff,stroke:#9b59b6
+       style MODELS fill:#eaf4ee,stroke:#27ae60
+   </div>
+
+   <script>
+     mermaid.initialize({ startOnLoad: true, theme: 'default', flowchart: { curve: 'basis' } });
+   </script>
+ </body>
+ </html>
architecture.mmd ADDED
@@ -0,0 +1,72 @@
+ flowchart TD
+     UI["🖥️ Streamlit UI\napp.py"]
+     MODE{Mode?}
+
+     UI --> MODE
+
+     subgraph MODE_FULL["Full Sinhala Mode"]
+       direction TB
+       ST["SentenceTransliterator\nseq2seq/mbart_infer.py"]
+       MBART["mBart50-large\nKalana001/mbart50-large-singlish-sinhala\nHF Hub · 2.4 GB"]
+       FIX["Compose Fix Map\nseq2seq/compose_fix_map.json\nZWJ / Virama corrections"]
+       ST --> MBART
+       MBART -->|"raw Sinhala output"| FIX
+     end
+
+     subgraph MODE_MIXED["Code-Mixed Mode"]
+       direction TB
+
+       subgraph PHASE1["Phase 1 · Word Classification"]
+         direction LR
+         P1A["Sinhala script?\n(U+0D80–0DFF)"]
+         P1B["English vocab?\nenglish_20k.txt"]
+         P1C["Singlish\n(everything else)"]
+       end
+
+       subgraph PHASE2["Phase 2 · Candidate Generation (single ByT5 batch)"]
+         direction LR
+         BYT5["ByT5-small\nKalana001/byt5-small-singlish-sinhala\nHF Hub · 1.2 GB\nbeam=5 → top-5 candidates"]
+         SIN_PASS["Single candidate\n(word as-is)"]
+         ENG_CAND["English word\n+ ByT5 Sinhala alternatives"]
+         SIN_CAND["Top-5 ByT5\ncandidates"]
+       end
+
+       subgraph PHASE3["Phase 3 · Two-Pass MLM Reranking"]
+         direction LR
+         GREEDY["Pass 1 – Greedy\nBuild draft sentence\n(stale right context)"]
+         RESCORE["Pass 2 – Rescore\nActual decoded output\nas right context"]
+         MLM["XLM-RoBERTa\nKalana001/xlm-roberta-base-finetuned-sinhala\nHF Hub\nMulti-mask log-probability"]
+         SOFTMAX["Softmax normalise\npick argmax"]
+       end
+
+       PHASE1 --> PHASE2
+       P1A -->|Sinhala| SIN_PASS
+       P1B -->|English| ENG_CAND
+       P1C -->|Singlish| SIN_CAND
+       BYT5 --> ENG_CAND
+       BYT5 --> SIN_CAND
+       PHASE2 --> PHASE3
+       GREEDY --> MLM
+       MLM --> SOFTMAX
+       SOFTMAX --> RESCORE
+       RESCORE --> MLM
+     end
+
+     MODE -->|"Full Sinhala Output"| MODE_FULL
+     MODE -->|"Code-Mixed Output"| MODE_MIXED
+
+     MODE_FULL --> OUT["✅ Sinhala Output"]
+     MODE_MIXED --> OUT
+
+     subgraph MODELS["Models on Hugging Face Hub (Kalana001)"]
+       HF1["byt5-small-singlish-sinhala\n1.2 GB · ByT5-small"]
+       HF2["xlm-roberta-base-finetuned-sinhala\nXLM-RoBERTa"]
+       HF3["mbart50-large-singlish-sinhala\n2.4 GB · mBart50-large"]
+     end
+
+     style MODE_FULL fill:#e8f4fd,stroke:#4a9eda
+     style MODE_MIXED fill:#fdf3e8,stroke:#e8974a
+     style PHASE1 fill:#fff9e6,stroke:#cca800
+     style PHASE2 fill:#e8fff0,stroke:#2ecc71
+     style PHASE3 fill:#f4e8ff,stroke:#9b59b6
+     style MODELS fill:#eaf4ee,stroke:#27ae60
core/__init__.py ADDED
File without changes
core/constants.py ADDED
@@ -0,0 +1,35 @@
+ """
+ Configuration constants for SinCode v3.
+
+ Key difference from v2: no rule engine, no dictionary.
+ Candidate generation is fully handled by the ByT5 seq2seq model.
+ """
+
+ import re
+
+ # ─── MLM Model Path ──────────────────────────────────────────────────────────
+ # XLM-RoBERTa fine-tuned on Sinhala — reranks ByT5 candidates by context
+ DEFAULT_MLM_MODEL = "Kalana001/xlm-roberta-base-finetuned-sinhala"
+
+ # ─── ByT5 Transliterator Model Path ──────────────────────────────────────────
+ # Fine-tuned on 1M Singlish→Sinhala pairs — hosted on Hugging Face Hub
+ DEFAULT_BYT5_MODEL = "Kalana001/byt5-small-singlish-sinhala"
+
+ # ─── mBart50 Transliterator Model Path ───────────────────────────────────────
+ # Full-sentence Singlish→Sinhala (no English retained) — Hugging Face Hub
+ DEFAULT_MBART_MODEL = "Kalana001/mbart50-large-singlish-sinhala"
+
+ # ─── Corpus ───────────────────────────────────────────────────────────────────
+ ENGLISH_CORPUS_URL = (
+     "https://raw.githubusercontent.com/first20hours/google-10000-english/master/20k.txt"
+ )
+
+ # ─── Scoring Weights ─────────────────────────────────────────────────────────
+ # Pure MLM — no manual weights needed
+
+ # ─── Decoding Parameters ─────────────────────────────────────────────────────
+ MAX_CANDIDATES: int = 5   # ByT5 beam=5 → 5 candidates per word
+ MIN_ENGLISH_LEN: int = 3  # Min word length for English detection
+
+ # ─── Regex ───────────────────────────────────────────────────────────────────
+ PUNCT_PATTERN = re.compile(r"^(\W*)(.*?)(\W*)$")
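For intuition, `PUNCT_PATTERN` splits a raw token into leading punctuation, core, and trailing punctuation, so that only the core is transliterated (a toy check, not part of the module):

```python
import re

PUNCT_PATTERN = re.compile(r"^(\W*)(.*?)(\W*)$")

# The lazy middle group leaves surrounding non-word characters to the
# outer groups, so they can be re-attached around the ByT5 output.
prefix, core, suffix = PUNCT_PATTERN.match('"buffer!"').groups()
print(repr(prefix), repr(core), repr(suffix))   # '"' 'buffer' '!"'
```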
core/decoder.py ADDED
@@ -0,0 +1,248 @@
+ """
+ SinCode v3 — ByT5 Seq2Seq + XLM-RoBERTa MLM Reranker.
+
+ Pipeline (per word):
+   Sinhala script   → MLM scores in context (single candidate)
+   English vocab    → ByT5 generates Sinhala alternatives + English kept; MLM picks
+   Everything else  → ByT5 generates top-5 candidates; MLM picks best
+ """
+
+ import math
+ import re
+ import torch
+ import logging
+ from typing import List, Tuple, Optional
+
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+ from core.constants import (
+     DEFAULT_MLM_MODEL, DEFAULT_BYT5_MODEL,
+     MAX_CANDIDATES, MIN_ENGLISH_LEN,
+     PUNCT_PATTERN,
+ )
+ from core.english import ENGLISH_VOCAB
+ from seq2seq.infer import Transliterator
+
+ logger = logging.getLogger(__name__)
+
+ _SINHALA_RE = re.compile(r"[\u0D80-\u0DFF]")
+
+
+ class ScoredCandidate:
+     __slots__ = ("text", "mlm_score")
+
+     def __init__(self, text: str, mlm_score: float):
+         self.text = text
+         self.mlm_score = mlm_score
+
+
+ def _is_sinhala(text: str) -> bool:
+     return bool(_SINHALA_RE.search(text))
+
+
+ class BeamSearchDecoder:
+     """
+     SinCode v3 contextual decoder.
+
+     Replaces the rule engine + dictionary + hardcoded maps with a single
+     ByT5-small seq2seq model fine-tuned on 1,000,000 Singlish→Sinhala pairs.
+     XLM-RoBERTa reranks the top-5 beam candidates by masked-LM probability.
+     """
+
+     def __init__(
+         self,
+         mlm_model_name: str = DEFAULT_MLM_MODEL,
+         byt5_model_path: str = DEFAULT_BYT5_MODEL,
+         device: Optional[str] = None,
+     ):
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+         logger.info("Loading MLM reranker: %s", mlm_model_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(mlm_model_name)
+         self.model = AutoModelForMaskedLM.from_pretrained(mlm_model_name)
+         self.model.to(self.device)
+         self.model.eval()
+
+         logger.info("Loading ByT5 transliterator: %s", byt5_model_path)
+         self.transliterator = Transliterator(model_path=byt5_model_path, device=self.device)
+
+     # ── Normalization ─────────────────────────────────────────────────────────
+
+     @staticmethod
+     def _softmax_normalize(raw_scores: List[float]) -> List[float]:
+         if not raw_scores:
+             return []
+         if len(raw_scores) == 1:
+             return [1.0]
+         max_s = max(raw_scores)
+         exps = [math.exp(s - max_s) for s in raw_scores]
+         total = sum(exps)
+         return [e / total for e in exps]
+
+     # ── MLM batch scoring ─────────────────────────────────────────────────────
+
+     def _batch_mlm_score(
+         self,
+         left_contexts: List[str],
+         right_contexts: List[str],
+         candidates: List[str],
+     ) -> List[float]:
+         """Score each candidate with XLM-RoBERTa multi-mask log-probability."""
+         if not candidates:
+             return []
+
+         mask = self.tokenizer.mask_token
+         mask_token_id = self.tokenizer.mask_token_id
+
+         cand_token_ids: List[List[int]] = []
+         for c in candidates:
+             ids = self.tokenizer.encode(c, add_special_tokens=False)
+             cand_token_ids.append(ids if ids else [self.tokenizer.unk_token_id])
+
+         batch_texts: List[str] = []
+         for i in range(len(candidates)):
+             n_masks = len(cand_token_ids[i])
+             mask_str = " ".join([mask] * n_masks)
+             parts = [p for p in [left_contexts[i], mask_str, right_contexts[i]] if p]
+             batch_texts.append(" ".join(parts))
+
+         inputs = self.tokenizer(
+             batch_texts,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+         ).to(self.device)
+
+         with torch.no_grad():
+             logits = self.model(**inputs).logits
+
+         scores: List[float] = []
+         for i, target_ids in enumerate(cand_token_ids):
+             token_ids = inputs.input_ids[i]
+             mask_positions = (token_ids == mask_token_id).nonzero(as_tuple=True)[0]
+
+             if mask_positions.numel() == 0 or not target_ids:
+                 scores.append(-100.0)
+                 continue
+
+             n = min(len(target_ids), mask_positions.numel())
+             total = 0.0
+             for j in range(n):
+                 pos = mask_positions[j].item()
+                 log_probs = torch.log_softmax(logits[i, pos, :], dim=0)
+                 total += log_probs[target_ids[j]].item()
+
+             scores.append(total / n)
+
+         return scores
+
+     # ── Public decode ─────────────────────────────────────────────────────────
+
+     def decode(self, sentence: str) -> Tuple[str, List[str]]:
+         """
+         Decode a Singlish sentence word-by-word using ByT5 + XLM-RoBERTa MLM.
+         Returns (transliterated_sentence, trace_logs).
+         """
+         words = sentence.split()
+         if not words:
+             return "", []
+
+         # ── Phase 1: batch ByT5 candidate generation ──────────────────────────
+         # Collect only the words that need ByT5 (non-Sinhala), run in one pass
+         cores: List[str] = []
+         core_meta: List[tuple] = []  # (index_into_words, prefix, core, suffix, core_lower)
+
+         for i, raw in enumerate(words):
+             match = PUNCT_PATTERN.match(raw)
+             prefix, core, suffix = match.groups() if match else ("", raw, "")
+             if not _is_sinhala(core):
+                 cores.append(core)
+                 core_meta.append((i, prefix, core, suffix, core.lower()))
+
+         # Single ByT5 forward pass for all non-Sinhala words
+         byt5_results: List[List[str]] = (
+             self.transliterator.batch_candidates(cores, k=MAX_CANDIDATES)
+             if cores else []
+         )
+
+         byt5_map: dict = {}  # word index → (prefix, suffix, core_lower, ByT5 candidates)
+         for (i, prefix, core, suffix, core_lower), cands in zip(core_meta, byt5_results):
+             byt5_map[i] = (prefix, suffix, core_lower, cands or [core])
+
+         word_infos: List[dict] = []
+         for i, raw in enumerate(words):
+             match = PUNCT_PATTERN.match(raw)
+             _, core, _ = match.groups() if match else ("", raw, "")
+
+             if _is_sinhala(core):
+                 word_infos.append({"kind": "sinhala", "candidates": [raw]})
+                 continue
+
+             prefix, suffix, core_lower, byt5_cands = byt5_map[i]
+             sinhala_cands = [prefix + c + suffix for c in byt5_cands]
+
+             if core_lower in ENGLISH_VOCAB and len(core_lower) >= MIN_ENGLISH_LEN:
+                 candidates = [raw] + [c for c in sinhala_cands if c != raw]
+                 word_infos.append({"kind": "english", "candidates": candidates[:MAX_CANDIDATES + 1]})
+             else:
+                 word_infos.append({"kind": "singlish", "candidates": sinhala_cands})
+
+         # ── Phase 2: greedy left-to-right pass (builds dynamic left context) ──
+         # Right context is seeded from first ByT5 candidate (pre-decode estimate)
+         stable_right = [info["candidates"][0] for info in word_infos]
+         selected_words: List[str] = []
+
+         for t, info in enumerate(word_infos):
+             candidates = info["candidates"]
+             left_ctx = " ".join(selected_words)
+             right_ctx = " ".join(stable_right[t + 1:])
+             raw_mlm = self._batch_mlm_score(
+                 [left_ctx] * len(candidates),
+                 [right_ctx] * len(candidates),
+                 candidates,
+             )
+             norm_mlm = self._softmax_normalize(raw_mlm)
+             best = max(zip(candidates, norm_mlm), key=lambda x: x[1])
+             selected_words.append(best[0])
+
+         # ── Phase 3: re-score with full decoded sentence as context ───────────
+         # Right context is now the actual decoded output, not the pre-decode estimate
+         trace_logs: List[str] = []
+         final_words: List[str] = []
+
+         for t, info in enumerate(word_infos):
+             raw_word = words[t]
+             kind = info["kind"]
+             candidates = info["candidates"]
+
+             left_ctx = " ".join(final_words)
+             right_ctx = " ".join(selected_words[t + 1:])
+
+             raw_mlm = self._batch_mlm_score(
+                 [left_ctx] * len(candidates),
+                 [right_ctx] * len(candidates),
+                 candidates,
+             )
+             norm_mlm = self._softmax_normalize(raw_mlm)
+
+             scored = sorted(
+                 [ScoredCandidate(text=c, mlm_score=norm_mlm[i]) for i, c in enumerate(candidates)],
+                 key=lambda x: x.mlm_score,
+                 reverse=True,
+             )
+             best = scored[0]
+             final_words.append(best.text)
+
+             if kind == "sinhala":
+                 trace_logs.append(
+                     f"**Step {t+1}: `{raw_word}`** → `{best.text}` "
+                     f"(Sinhala, MLM={best.mlm_score:.3f})\n"
+                 )
+             else:
+                 trace_logs.append(
+                     f"**Step {t+1}: `{raw_word}`** → `{best.text}` "
+                     f"(MLM={best.mlm_score:.3f})\n"
+                     + "\n".join(f"  - `{s.text}` {s.mlm_score:.3f}" for s in scored)
+                 )
+
+         return " ".join(final_words), trace_logs
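To make the reranking arithmetic concrete, here is the normalisation from `_softmax_normalize` applied to invented multi-mask log-probabilities for three candidates of a single word (a standalone toy, no model required):

```python
import math

def softmax_normalize(raw_scores):
    # Same logic as BeamSearchDecoder._softmax_normalize: shift by the
    # maximum for numerical stability, exponentiate, then normalise.
    if not raw_scores:
        return []
    if len(raw_scores) == 1:
        return [1.0]
    max_s = max(raw_scores)
    exps = [math.exp(s - max_s) for s in raw_scores]
    total = sum(exps)
    return [e / total for e in exps]

# Invented average log-probabilities for three candidates of one word.
print(softmax_normalize([-2.1, -3.4, -5.0]))
# ≈ [0.753, 0.205, 0.041] → the first candidate wins the argmax
```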
core/english.py ADDED
@@ -0,0 +1,73 @@
+ """
+ English vocabulary loader for SinCode v3.
+ Used for English passthrough detection in the decoder.
+ Loads purely from the 20k corpus file — no hardcoded word lists.
+ """
+
+ import os
+ import logging
+ import requests
+ from typing import Set
+
+ from core.constants import ENGLISH_CORPUS_URL, MIN_ENGLISH_LEN
+
+ logger = logging.getLogger(__name__)
+
+
+ def _resolve_english_cache_path() -> str:
+     override = os.getenv("SINCODE_ENGLISH_CACHE")
+     if override:
+         return override
+
+     candidates = [
+         os.path.join(os.getenv("HF_HOME", ""), "english_20k.txt") if os.getenv("HF_HOME") else "",
+         os.path.join(os.getcwd(), "english_20k.txt"),
+         os.path.join(os.getenv("TMPDIR", os.getenv("TEMP", "/tmp")), "english_20k.txt"),
+     ]
+
+     for path in candidates:
+         if not path:
+             continue
+         parent = os.path.dirname(path) or "."
+         try:
+             os.makedirs(parent, exist_ok=True)
+             with open(path, "a", encoding="utf-8"):
+                 pass
+             return path
+         except OSError:
+             continue
+
+     return "english_20k.txt"
+
+
+ ENGLISH_CORPUS_CACHE = _resolve_english_cache_path()
+
+
+ def load_english_vocab() -> Set[str]:
+     vocab: Set[str] = set()
+
+     if not os.path.exists(ENGLISH_CORPUS_CACHE) or os.path.getsize(ENGLISH_CORPUS_CACHE) == 0:
+         try:
+             logger.info("Downloading English corpus...")
+             response = requests.get(ENGLISH_CORPUS_URL, timeout=10)
+             response.raise_for_status()
+             with open(ENGLISH_CORPUS_CACHE, "wb") as f:
+                 f.write(response.content)
+         except (requests.RequestException, OSError) as exc:
+             logger.warning("Could not download English corpus: %s", exc)
+             return vocab
+
+     try:
+         with open(ENGLISH_CORPUS_CACHE, "r", encoding="utf-8") as f:
+             vocab.update(
+                 w for line in f
+                 if (w := line.strip().lower()) and len(w) >= MIN_ENGLISH_LEN
+             )
+     except OSError as exc:
+         logger.warning("Could not read English corpus file: %s", exc)
+
+     logger.info("English vocabulary loaded: %d words", len(vocab))
+     return vocab
+
+
+ ENGLISH_VOCAB: Set[str] = load_english_vocab()
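Membership checks against the loaded set behave as expected (a sketch; it assumes the 20k download succeeded, and whether any given word is present depends on that corpus):

```python
from core.english import ENGLISH_VOCAB

# Words shorter than MIN_ENGLISH_LEN (3) are filtered at load time, so
# common two-letter English tokens like "ok" never trigger passthrough.
print("buffer" in ENGLISH_VOCAB)   # True if "buffer" is in the 20k corpus
print("ok" in ENGLISH_VOCAB)       # False — filtered by length
```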
core/mappings.py ADDED
@@ -0,0 +1,8 @@
+ """
+ core/mappings.py — deprecated.
+
+ All manual Singlish→Sinhala mappings have been removed.
+ Correction pairs are in seq2seq/finetune_corrections.py and baked into
+ the ByT5 model weights via targeted correction fine-tuning.
+ Candidate generation is handled end-to-end by the ByT5 seq2seq model.
+ """
english_20k.txt ADDED
The diff for this file is too large to render. See raw diff
 
misc/dataset_110.csv ADDED
@@ -0,0 +1,111 @@
+ id,input,reference,split,has_code_mix,has_ambiguity,domain,notes
+ 1,api kalin katha kala,අපි කලින් කතා කළා,test,0,0,general,pure singlish
+ 2,eka honda wage thiyanawa,ඒක හොඳ වගේ තියෙනවා,test,0,1,general,wage=seems
+ 3,meheta thadata wessa,මෙහෙට තදට වැස්සා,test,0,1,general,thadata=very
+ 4,oya kiwwata mama giye,ඔයා කිව්වට මම ගියේ,test,0,0,general,contextual past
+ 5,mama danne na eka gena,මම දන්නෙ නෑ ඒක ගැන,test,0,1,general,eka pronoun
+ 6,oya awa wage na,ඔයා ආවා වගේ නෑ,test,0,1,general,wage=seems
+ 7,ekat ynna bri,ඒකට යන්න බැරි,test,0,0,general,ad-hoc bri=bari
+ 8,mama inne gedaradi,මම ඉන්නෙ ගෙදරදී,test,0,0,general,pure singlish
+ 9,eka heta balamu,ඒක හෙට බලමු,test,0,0,general,eka pronoun
+ 10,klya madi api passe yamu,කාලය මදි අපි පස්සෙ යමු,test,0,0,general,ad-hoc klya=kalaya
+ 11,assignment eka ada submit karanna one,assignment එක අද submit කරන්න ඕනෙ,test,1,0,education,eka after English noun
+ 12,exam hall eka nisa mama baya una,exam hall එක නිසා මම බය උනා,test,1,1,education,nisa=because
+ 13,results blnna one,results බලන්න ඕනෙ,test,1,0,education,ad-hoc blnna=balanna
+ 14,study group ekak hadamu,study group එකක් හදමු,test,1,0,education,ekak after English noun
+ 15,viva ekta prepared wage na,viva එකට prepared වගේ නෑ,test,1,1,education,wage=seems
+ 16,mta project ek submit krnna one,මට project එක submit කරන්න ඕනෙ,test,1,0,education,ad-hoc mta krnna
+ 17,hta parikshanaya thiyanawa,හෙට පරික්‍ෂණය තියෙනවා,test,0,0,education,ad-hoc hta=heta
+ 18,mama potha kiyawala iwara kala,මම පොත කියවලා ඉවර කළා,test,0,0,education,pure singlish
+ 19,prkku nisa api kalin giya,පරක්කු නිසා අපි කලින් ගියා,test,0,1,education,nisa=because
+ 20,prashnaya hondai wage penenawa,ප්‍රශ්නය හොඳයි වගේ පේනවා,test,0,1,education,wage=seems
+ 21,deployments nisa site down wuna,deployments නිසා site down උනා,test,1,1,work,nisa=because
+ 22,PR eka merge karanna one,PR එක merge කරන්න ඕනෙ,test,1,0,work,eka after English noun
+ 23,backlog eka update kala,backlog එක update කළා,test,1,0,work,eka after English noun
+ 24,server down nisa work karanna ba,server down නිසා work කරන්න බෑ,test,1,1,work,nisa=because
+ 25,meeting eka tomorrow damu,meeting එක tomorrow දාමු,test,1,0,work,code-mix preserved
+ 26,feedback nisa redo karanna una,feedback නිසා redo කරන්න උනා,test,1,1,work,nisa=because
+ 27,ape wada ada iwara wenawa,අපේ වැඩ අද ඉවර වෙනවා,test,0,0,work,pure singlish
+ 28,kalamanakaru hitpu nisa api katha kala,කලමනාකරු හිටපු නිසා අපි කතා කළා,test,0,1,work,nisa=because; known failure (complex OOV)
+ 29,me wada hondai wage penawa,මේ වැඩ හොඳයි වගේ පේනවා,test,0,1,work,wage=seems
+ 30,wada tika ada iwara karamu,වැඩ ටික අද ඉවර කරමු,test,0,0,work,pure singlish
+ 31,story eke poll ekak damma,story එකේ poll එකක් දැම්මා,test,1,0,social,eke and ekak forms
+ 32,oyata DM ekak yawwa,ඔයාට DM එකක් යැව්වා,test,1,0,social,ekak after English noun
+ 33,comment eka delete kala nisa mama danne na,comment එක delete කළා නිසා මම දන්නෙ නෑ,test,1,1,social,"nisa=because; known failure (කළා/කල, දන්නෙ/දන්නේ)"
+ 34,selfie ekak gannako,selfie එකක් ගන්නකෝ,test,1,0,social,ekak after English noun
+ 35,post eka private nisa share karanna epa,post එක private නිසා share කරන්න එපා,test,1,1,social,nisa=because
+ 36,oyta message krnna one,ඔයාට message කරන්න ඕනෙ,test,1,0,social,ad-hoc oyta krnna on=one
+ 37,api passe katha karamu,අපි පස්සෙ කතා කරමු,test,0,0,social,pure singlish
+ 38,eya laga pinthurayk thiyanawa,ඒයා ළඟ පින්තූරයක් තියෙනවා,test,0,0,social,ad-hoc pinthurayk
+ 39,oya awa wage mata hithenawa,ඔයා ආවා වගේ මට හිතෙනවා,test,0,1,social,wage=seems
+ 40,api passe hambawemu,අපි පස්සෙ හම්බවෙමු,test,0,0,social,pure singlish
+ 41,phone eka charge karanna one,phone එක charge කරන්න ඕනෙ,test,1,0,general,NEW: general code-mix (gap fix)
+ 42,bus eka late una,bus එක late උනා,test,1,0,general,NEW: general code-mix
+ 43,mama online inne,මම online ඉන්නෙ,test,1,0,general,NEW: English mid-sentence
+ 44,time nathi nisa heta yamu,time නැති නිසා හෙට යමු,test,1,1,general,NEW: English+nisa in general
+ 45,oya call eka ganna,ඔයා call එක ගන්න,test,1,0,general,NEW: general code-mix eka pattern
+ 46,api game yanawa heta,අපි ගමේ යනවා හෙට,test,0,1,general,NEW: game=ගමේ(village) ambig with English 'game'
+ 47,man heta enne na,මන් හෙට එන්නෙ නෑ,test,0,1,general,NEW: man=මං(I) ambig with English 'man'
+ 48,eka hari lassanai,ඒක හරි ලස්සනයි,test,0,1,general,NEW: hari=very (not OK/correct)
+ 49,oya kiwwa hari,ඔයා කිව්වා හරි,test,0,1,general,NEW: hari=correct (not very)
+ 50,kalaya ithuru krganna one,කලය ඉතුරු කරගන්න ඕනෙ,test,0,1,general,NEW: one=ඕනෙ(need) ambig with English 'one'
+ 51,date eka fix karanna one,date එක fix කරන්න ඕනෙ,test,1,1,general,NEW: date=English preserve; one=ඕනෙ
+ 52,rata yanna one,රට යන්න ඕනෙ,test,0,0,general,"NEW: rata=country, pure singlish"
+ 53,game eke leaderboard eka balanna,game එකේ leaderboard එක බලන්න,test,1,1,social,NEW: game=English(video game) not ගමේ
+ 54,api thamai hodama,අපි තමයි හොඳම,test,0,1,general,NEW: thamai=emphatic we; hodama=best; looks English but Singlish
+ 55,mama heta udee enawa oya enakota message ekk dnna,මම හෙට උදේ එනවා ඔයා එනකොට message එකක් දාන්න,test,0,0,general,NEW: 8-word pure singlish
+ 56,ape gedara langa thiyana kadeta yanna one,අපේ ගෙදර ළඟ තියෙන කඩේට යන්න ඕනෙ,test,0,0,general,NEW: 7-word with ළඟ
+ 57,mama assignment eka karala submit karanawa ada raa,මම assignment එක කරලා submit කරනවා අද රෑ,test,1,0,education,NEW: 8-word code-mix long
+ 58,oya enne naththe mokada kiyla mama danne na,ඔයා එන්නෙ නැත්තෙ මොකද කියලා මම දන්නෙ නෑ,test,0,0,general,NEW: 9-word complex clause
+ 59,client ekka call karala feedback eka ahanna one,client එක්ක call කරලා feedback එක අහන්න ඕනෙ,test,1,0,work,NEW: 8-word heavy code-mix
+ 60,mama gedara gihilla kewata passe call karannm,මම ගෙදර ගිහිල්ලා කෑවට පස්සෙ call කරන්නම්,test,1,0,general,NEW: 8-word code-mix + temporal
+ 61,laptop eke software update karanna one,laptop එකේ software update කරන්න ඕනෙ,test,1,0,work,NEW: 3 English words consecutive
+ 62,office eke wifi password eka mokakda,office එකේ wifi password එක මොකක්ද,test,1,0,work,NEW: 3 English words; question
+ 63,online order eka track karanna ba,online order එක track කරන්න බෑ,test,1,0,general,NEW: 3 English words
+ 64,email eke attachment eka download karanna,email එකේ attachment එක download කරන්න,test,1,0,work,NEW: 3 English words + double eka
+ 65,Instagram story eke filter eka hadanna,Instagram story එකේ filter එක හදන්න,test,1,0,social,NEW: 4 English words; social media
+ 66,oyge wada iwra krd,ඔයාගෙ වැඩ ඉවර කරාද,test,0,0,general,NEW: extreme vowel omission
+ 67,mge phone ek hack una,මගේ phone එක hack උනා,test,1,0,general,"NEW: heavy ad-hoc mmge=mage, hrk=hack"
+ 68,handawata ynna wenwa,හැන්දෑවට යන්න වෙනවා,test,0,0,general,"NEW: ad-hoc hndta=handeta, wenwa=wenawa"
+ 69,prashnya krnna oni,ප්‍රශ්‍නය කරන්න ඕනි,test,0,0,education,NEW: replaced extreme ad-hoc with more readable form
+ 70,apita gdra ynna oni,අපිට ගෙදර යන්න ඕනි,test,0,0,general,NEW: ad-hoc gdra=gedara
+ 71,mama oyata kiwwa,මම ඔයාට කිව්වා,test,0,0,general,"NEW: common words only (mama, oyata)"
+ 72,oya hari hondai,ඔයා හරි හොඳයි,test,0,1,general,NEW: hari=very; common words
+ 73,api heta yamu,අපි හෙට යමු,test,0,0,general,NEW: common words bypass test
+ 74,app eka crash wenawa phone eke,app එක crash වෙනවා phone එකේ,test,1,0,technology,NEW: tech domain
+ 75,code eka push karanna github ekata,code එක push කරන්න github එකට,test,1,0,technology,NEW: dev workflow code-mix
+ 76,database eka slow nisa query eka optimize karanna one,database එක slow නිසා query එක optimize කරන්න ඕනෙ,test,1,1,technology,NEW: heavy tech code-mix + nisa; long
+ 77,bug eka fix kala merge karanna,bug එක fix කළා merge කරන්න,test,1,0,technology,NEW: sequential actions code-mix
+ 78,internet eka slow wage thiyanawa,internet එක slow වගේ තියෙනවා,test,1,1,technology,NEW: tech + wage ambiguity
+ 79,kema hodai ada,කෑම හොඳයි අද,test,0,0,daily_life,NEW: daily life; short
+ 80,mama bus eke enawa,මම bus එකේ එනවා,test,1,0,daily_life,NEW: transport code-mix
+ 81,ganu depala ekka market giya,ගෑනු දෙපල එක්ක market ගියා,test,1,0,daily_life,NEW: colloquial + code-mix
+ 82,watura bonna one,වතුර බොන්න ඕනෙ,test,0,0,daily_life,NEW: health advice singlish
+ 83,shop eke sugar nati nisa mama giye na,shop එකේ sugar නැති නිසා මම ගියේ නෑ,test,1,1,daily_life,NEW: daily code-mix + nisa; negative
+ 84,hri hari,හරි හරි,test,0,0,general,NEW: 2-word repetition; common expression + ad-hoc hri=hari
+ 85,mta ep,මට එපා,test,0,0,general,NEW: ad-hoc mta=mata ep=epa
+ 86,ok hari,ok හරි,test,1,0,general,NEW: 2-word code-mix
+ 87,ape game hari dewal wenne,අපේ ගමේ හරි දේවල් වෙන්නේ,test,0,1,general,"NEW: game=village, hari=nice; looks English"
+ 88,mta dan one na,මට දැන් ඕනෙ නෑ,test,0,1,general,NEW: man+one look English but Singlish
+ 89,eka hari hondai wage dnuna nisa mama giya,ඒක හරි හොඳයි වගේ දැනුනා නිසා මම ගියා,test,0,1,general,NEW: hari+wage+nisa triple ambiguity; ref corrected to හොඳයි
+ 90,game eke mission hari amarui,game එකේ mission හරි අමාරුයි,test,0,1,general,NEW: game=video game hari=very amarui=difficult; looks English but Singlish
+ 91,mama heta yanawa,මම හෙට යනවා,test,0,0,general,NEW: future tense
+ 92,ey iye aawa,එයා ඊයේ ආවා,test,0,0,general,NEW: past tense
+ 93,api dan yanawa,අපි දැන් යනවා,test,0,0,general,NEW: present tense
+ 94,video eka balanna one,video එක බලන්න ඕනෙ,test,1,0,social,NEW: eka definite article
+ 95,video ekak hadamu,video එකක් හදමු,test,1,0,social,NEW: ekak indefinite
+ 96,video eke comment eka balanna,video එකේ comment එක බලන්න,test,1,0,social,NEW: eke possessive + double eka
+ 97,video ekata like ekak danna,video එකට like එකක් දාන්න,test,1,0,social,NEW: ekata dative case
+ 98,lecture eka record karala share karanna,lecture එක record කරලා share කරන්න,test,1,0,education,NEW: sequential code-mix actions
+ 99,research paper eka liyanna one heta wge,research paper එක ලියන්න ඕනෙ හෙට වගේ,test,1,0,education,NEW: long + temporal; 8 words
+ 100,exam eka hari amarui,exam එක හරි අමාරුයි,test,1,1,education,NEW: hari=very; difficulty context
+ 101,sprint eka plan karamu Monday,sprint එක plan කරමු Monday,test,1,0,work,NEW: day name preserved
+ 102,ape team eka deadline ekata kala,අපේ team එක deadline එකට කළා,test,1,0,work,NEW: possessive + double English
+ 103,standup eke mokada kiwwe,standup එකේ මොකද කිව්වෙ,test,1,0,work,NEW: question form code-mix
+ 104,reel eka viral una,reel එක viral උනා,test,1,0,social,NEW: social media terminology
+ 105,group chat eke mokada wenne,group chat එකේ මොකද වෙන්නෙ,test,1,0,social,NEW: compound English + question
+ 106,oyge profile picture eka lassanai,ඔයාගෙ profile picture එක ලස්සනයි,test,1,0,social,NEW: compound English noun + eka; ref corrected to ඔයාගෙ
+ 107,mama enne na heta,මම එන්නෙ නෑ හෙට,test,0,0,general,NEW: negation at end
+ 108,eka karanna epa,ඒක කරන්න එපා,test,0,0,general,NEW: prohibition form
+ 109,kawruwath enne na,කවුරුවත් එන්නෙ නෑ,test,0,0,general,NEW: nobody negation
+ 110,oya koheda ynne,ඔයා කොහේද යන්නේ,test,0,0,general,NEW: question form where
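For orientation, the composition of the 110 cases can be inspected with pandas, which is already in requirements.txt (a quick sketch run from the repo root):

```python
import pandas as pd

df = pd.read_csv("misc/dataset_110.csv")
print(len(df))                              # 110 test cases
print(df["has_code_mix"].value_counts())    # code-mixed vs pure Singlish
print(df.groupby("domain").size())          # per-domain counts
```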
misc/dataset_40.csv ADDED
@@ -0,0 +1,41 @@
+ id,input,reference,split,has_code_mix,has_ambiguity,domain,notes
+ 1,api kalin katha kala,අපි කලින් කතා කළා,train,0,0,general,pure singlish
+ 2,eka honda wage thiyanawa,ඒක හොඳ වගේ තියෙනවා,train,0,1,general,wage=seems
+ 3,pola nisa gedara thiyanawa,පොල නිසා ගෙදර තියෙනවා,train,0,1,general,nisa=because
+ 4,oya kiwwata mama giye,ඔයා කිව්වට මම ගියේ,train,0,0,general,contextual past
+ 5,mama danne na eka gena,මම දන්නෙ නෑ ඒක ගැන,train,0,1,general,eka pronoun
+ 6,oya awa wage na,ඔයා ආවා වගේ නෑ,train,0,1,general,wage=seems
+ 7,ekat ynna bri,ඒකට යන්න බැරි,train,0,0,general,ad hoc bri=bari
+ 8,mama inne gedaradi,මම ඉන්නෙ ගෙදරදී,train,0,0,general,pure singlish
+ 9,eka heta balamu,ඒක හෙට බලමු,train,0,0,general,eka pronoun
+ 10,klya madi api passe yamu,කාලය මදි අපි පස්සෙ යමු,train,0,0,general,ad hoc klya=kalaya
+ 11,assignment eka ada submit karanna one,assignment එක අද submit කරන්න ඕනෙ,train,1,0,education,eka after English noun
+ 12,exam hall eka nisa mama baya una,exam hall එක නිසා මම බය උනා,train,1,1,education,nisa=because
+ 13,results blnna one,results බලන්න ඕනෙ,train,1,0,education,ad hoc blnna=balanna
+ 14,study group ekak hadamu,study group එකක් හදමු,train,1,0,education,ekak after English noun
+ 15,viva ekta prepared wage na,viva එකට prepared වගේ නෑ,train,1,1,education,wage=seems
+ 16,mta project ek submit krnna one,මට project එක submit කරන්න ඕනෙ,train,1,0,education,ad hoc mta krnna
+ 17,hta parikshanaya thiyanawa,හෙට පරික්‍ෂණය තියෙනවා,train,0,0,education,ad hoc hta=heta
+ 18,mama poth kiyawala iwara kala,මම පොත කියවලා ඉවර කළා,train,0,0,education,pure singlish
+ 19,guruwaraya nisa api kalin giya,ගුරුවරයා නිසා අපි කලින් ගියා,train,0,1,education,nisa=because
+ 20,prashnaya honda wage penenawa,ප්‍රශ්නය හොඳ වගේ පේනවා,train,0,1,education,wage=seems
+ 21,deploy nisa site down wuna,deploy නිසා site down උනා,train,1,1,work,nisa=because
+ 22,PR eka merge karanna one,PR එක merge කරන්න ඕනෙ,train,1,0,work,eka after English noun
+ 23,backlog eka update kala,backlog එක update කළා,train,1,0,work,eka after English noun
+ 24,server down nisa work karanna ba,server down නිසා work කරන්න බෑ,train,1,1,work,nisa=because
+ 25,meeting eka tomorrow damu,meeting එක tomorrow දාමු,train,1,0,work,code mix preserved
+ 26,feedback nisa redo karanna una,feedback නිසා redo කරන්න උනා,train,1,1,work,nisa=because
+ 27,ape wada ada iwara wenawa,අපේ වැඩ අද ඉවර වෙනවා,train,0,0,work,pure singlish
+ 28,kalamanakaru apu nisa api katha kala,කලමණාකරු ආපු නිසා අපි කතා කලා,train,0,1,work,nisa=because
+ 29,me wada honda wage penenawa,මේ වැඩ හොඳ වගේ පේනවා,train,0,1,work,wage=seems
+ 30,wada tika ada iwara karamu,වැඩ ටික අද ඉවර කරමු,train,0,0,work,pure singlish
+ 31,story eke poll ekak damma,story එකේ poll එකක් දැම්මා,train,1,0,social,eke and ekak forms
+ 32,oyata DM ekak yewwa,ඔයාට DM එකක් යැව්වා,train,1,0,social,ekak after English noun
+ 33,comment eka delete kala nisa mama danne na,comment එක delete කල නිසා මම දන්නේ නෑ,train,1,1,social,nisa=because
+ 34,selfie ekak gannako,selfie එකක් ගන්නකෝ,train,1,0,social,ekak after English noun
+ 35,post eka private nisa share karanna epa,post එක private නිසා share කරන්න එපා,train,1,1,social,nisa=because
+ 36,oyta message krnna on,ඔයාට message කරන්න ඕනෙ,train,1,0,social,ad hoc oyta krnna
+ 37,oya passe katha karamu,ඔයා පස්සෙ කතා කරමු,train,0,0,social,pure singlish
+ 38,eya laga pinthurayk thiyanawa,ඒයා ළඟ පින්තූරයක් තියෙනවා,train,0,0,social,ad hoc pinthurayk
+ 39,oya awa wage mata hithenawa,ඔයා ආවා වගේ මට හිතෙනවා,train,0,1,social,wage=seems
+ 40,api passe hambawemu,අපි පස්සෙ හම්බවෙමු,train,0,0,social,pure singlish
misc/evaluate.py ADDED
@@ -0,0 +1,446 @@
+ """
+ SinCode v3 — Evaluation Script
+
+ Supports two evaluation modes selected via --mode:
+
+   system     Full v3 pipeline (ByT5 + two-pass MLM). Default.
+   ablation   Side-by-side comparison of two configurations:
+                (A) ByT5 top-1 only — no MLM reranking
+                (B) ByT5 + MLM — full Code-Mixed pipeline
+              Proves the contribution of the XLM-RoBERTa reranker.
+
+ Note: mBart50 is intentionally excluded from evaluation here because the
+ reference dataset uses code-mixed targets (English words preserved). mBart50
+ produces full-Sinhala output by design, making a metric comparison against
+ code-mixed references invalid. Evaluate mBart50 separately with a dataset
+ whose references are fully in Sinhala script.
+
+ Usage:
+   python misc/evaluate.py --dataset misc/dataset_110.csv
+   python misc/evaluate.py --dataset misc/dataset_110.csv --mode ablation
+   python misc/evaluate.py --dataset misc/dataset_110.csv --mode ablation --out misc/results.csv
+
+ CSV columns required: id, input, reference
+ Optional columns (used for grouping): category, domain, has_code_mix, has_ambiguity
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import csv
+ import json
+ import logging
+ import math
+ import os
+ import re
+ import sys
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+
+ # ── Path setup ────────────────────────────────────────────────────────────────
+
+ ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ if ROOT not in sys.path:
+     sys.path.insert(0, ROOT)
+
+ logging.basicConfig(level=logging.WARNING)
+
+ # ── Metrics ───────────────────────────────────────────────────────────────────
+
+ def _levenshtein(a: str, b: str) -> int:
+     if not a: return len(b)
+     if not b: return len(a)
+     prev = list(range(len(b) + 1))
+     for i, ca in enumerate(a, 1):
+         curr = [i] + [0] * len(b)
+         for j, cb in enumerate(b, 1):
+             cost = 0 if ca == cb else 1
+             curr[j] = min(prev[j] + 1, curr[j-1] + 1, prev[j-1] + cost)
+         prev = curr
+     return prev[-1]
+
+
+ def _levenshtein_tokens(a: list, b: list) -> int:
+     if not a: return len(b)
+     if not b: return len(a)
+     prev = list(range(len(b) + 1))
+     for i, ta in enumerate(a, 1):
+         curr = [i] + [0] * len(b)
+         for j, tb in enumerate(b, 1):
+             cost = 0 if ta == tb else 1
+             curr[j] = min(prev[j] + 1, curr[j-1] + 1, prev[j-1] + cost)
+         prev = curr
+     return prev[-1]
+
+
+ def cer(pred: str, ref: str) -> float:
+     if not ref: return 0.0 if not pred else 1.0
+     return _levenshtein(pred, ref) / max(len(ref), 1)
+
+
+ def wer(pred: str, ref: str) -> float:
+     pt, rt = pred.split(), ref.split()
+     if not rt: return 0.0 if not pt else 1.0
+     return _levenshtein_tokens(pt, rt) / max(len(rt), 1)
+
+
+ def token_accuracy(pred: str, ref: str) -> float:
+     pt, rt = pred.split(), ref.split()
+     if not rt: return 0.0 if pt else 1.0
+     return sum(p == r for p, r in zip(pt, rt)) / max(len(rt), 1)
+
+
+ def bleu(pred: str, ref: str, max_n: int = 4) -> float:
+     from collections import Counter
+     pt, rt = pred.split(), ref.split()
+     if not pt or not rt: return 0.0
+     n_max = min(max_n, len(pt), len(rt))
+     if n_max == 0: return 0.0
+     brevity = min(1.0, len(pt) / len(rt))
+     log_avg = 0.0
+     for n in range(1, n_max + 1):
+         pc = Counter(tuple(pt[i:i+n]) for i in range(len(pt)-n+1))
+         rc = Counter(tuple(rt[i:i+n]) for i in range(len(rt)-n+1))
+         clipped = sum(min(c, rc[ng]) for ng, c in pc.items())
+         total = max(sum(pc.values()), 1)
+         prec = clipped / total
+         if prec == 0: return 0.0
+         log_avg += math.log(prec) / n_max
+     return brevity * math.exp(log_avg)
+
+
+ def exact_match(pred: str, ref: str) -> float:
+     return 1.0 if pred.strip() == ref.strip() else 0.0
+
+
+ # ── Data model ────────────────────────────────────────────────────────────────
+
+ @dataclass
+ class TestCase:
+     id: int
+     input: str
+     reference: str
+     domain: str = "general"
+     has_code_mix: bool = False
+     has_ambiguity: bool = False
+
+
+ @dataclass
+ class Result:
+     test_case: TestCase
+     system: str
+     prediction: str
+     cer_score: float
+     wer_score: float
+     token_acc: float
+     bleu_score: float
+     exact: float
+
+
+ def _score(tc: TestCase, pred: str, system: str) -> Result:
+     return Result(
+         test_case=tc,
+         system=system,
+         prediction=pred,
+         cer_score=cer(pred, tc.reference),
+         wer_score=wer(pred, tc.reference),
+         token_acc=token_accuracy(pred, tc.reference),
+         bleu_score=bleu(pred, tc.reference),
+         exact=exact_match(pred, tc.reference),
+     )
+
+
+ # ── Test set loader ───────────────────────────────────────────────────────────
+
+ def load_dataset(csv_path: str) -> List[TestCase]:
+     cases = []
+     with open(csv_path, "r", encoding="utf-8", newline="") as f:
+         reader = csv.DictReader(f)
+         fields = set(reader.fieldnames or [])
+         if not {"input", "reference"}.issubset(fields):
+             raise ValueError(f"CSV must have 'input' and 'reference' columns. Found: {fields}")
+         for row in reader:
+             inp = (row.get("input") or "").strip().replace("\n", " ")
+             ref = (row.get("reference") or "").strip().replace("\n", " ")
+             if not inp or not ref:
+                 continue
+             cases.append(TestCase(
+                 id=int(row.get("id") or 0),
+                 input=inp,
+                 reference=ref,
+                 domain=(row.get("domain") or row.get("category") or "general").strip(),
+                 has_code_mix=bool(int(row.get("has_code_mix") or 0)),
+                 has_ambiguity=bool(int(row.get("has_ambiguity") or 0)),
+             ))
+     return cases
+
+
+ # ── Model loaders ─────────────────────────────────────────────────────────────
+
+ def _load_v3_decoder():
+     from sincode_model import BeamSearchDecoder
+     print("  Loading ByT5 + XLM-RoBERTa (Code-Mixed pipeline)...")
+     return BeamSearchDecoder()
+
+
+ def _byt5_top1_predict(decoder, sentence: str) -> str:
+     """ByT5 top-1 only — pick first beam candidate, skip MLM reranking."""
+     from core.constants import PUNCT_PATTERN
+     from core.decoder import _is_sinhala
+
+     words = sentence.split()
+     output = []
+     cores = [re.sub(r"^\W*|\W*$", "", w) for w in words]
+     non_sinhala = [c for c in cores if not _is_sinhala(c) and c]
+
+     if not non_sinhala:
+         return sentence
+
+     byt5_results = decoder.transliterator.batch_candidates(non_sinhala, k=1)
+     byt5_iter = iter(byt5_results)
+
+     for raw, core in zip(words, cores):
+         m = PUNCT_PATTERN.match(raw)
+         prefix, _, suffix = m.groups() if m else ("", raw, "")
+         if _is_sinhala(core) or not core:
+             output.append(raw)
+         else:
+             cands = next(byt5_iter, [core])
+             output.append(prefix + (cands[0] if cands else core) + suffix)
+     return " ".join(output)
+
+
+ # ── Reporting ─────────────────────────────────────────────────────────────────
+
+ def _avg(vals: List[float]) -> float:
+     return sum(vals) / len(vals) if vals else 0.0
+
+
+ def _print_table(label: str, results: List[Result]):
+     print(f"\n{'='*74}")
+     print(f" {label} (n={len(results)})")
+     print(f"{'='*74}")
+     print(f" {'ID':<5} {'Domain':<14} {'CM':>3} {'Am':>3} {'CER':>6} {'WER':>6} {'TokAcc':>7} {'BLEU':>6} {'EM':>4}")
+     print(f" {'-'*66}")
+     for r in results:
+         tc = r.test_case
+         print(
+             f" {tc.id:<5} {tc.domain[:13]:<14} {'Y' if tc.has_code_mix else 'N':>3} "
+             f"{'Y' if tc.has_ambiguity else 'N':>3} "
+             f"{r.cer_score:>6.3f} {r.wer_score:>6.3f} {r.token_acc:>7.3f} "
+             f"{r.bleu_score:>6.3f} {r.exact:>4.0f}"
+         )
+     print(f" {'-'*66}")
+     print(
+         f" {'AVERAGE':<26} "
+         f"{_avg([r.cer_score for r in results]):>6.3f} "
+         f"{_avg([r.wer_score for r in results]):>6.3f} "
+         f"{_avg([r.token_acc for r in results]):>7.3f} "
+         f"{_avg([r.bleu_score for r in results]):>6.3f} "
+         f"{_avg([r.exact for r in results]):>4.2f}"
+     )
+
+     # Per-domain breakdown
+     by_domain: Dict[str, List[Result]] = defaultdict(list)
+     for r in results:
+         by_domain[r.test_case.domain].append(r)
+     if len(by_domain) > 1:
+         print("\n Per-domain averages (CER / WER / TokAcc):")
+         for dom, rs in sorted(by_domain.items()):
+             print(
+                 f"   {dom:<18} n={len(rs):<4} "
+                 f"CER={_avg([r.cer_score for r in rs]):.3f} "
+                 f"WER={_avg([r.wer_score for r in rs]):.3f} "
+                 f"TokAcc={_avg([r.token_acc for r in rs]):.3f}"
+             )
+
+     # Code-mixed vs pure Singlish
+     cm_r = [r for r in results if r.test_case.has_code_mix]
+     pure_r = [r for r in results if not r.test_case.has_code_mix]
+     if cm_r and pure_r:
+         print(
+             f"\n Code-mixed    (n={len(cm_r):<3}): "
+             f"CER={_avg([r.cer_score for r in cm_r]):.3f} "
+             f"WER={_avg([r.wer_score for r in cm_r]):.3f}"
+         )
+         print(
+             f" Pure Singlish (n={len(pure_r):<3}): "
+             f"CER={_avg([r.cer_score for r in pure_r]):.3f} "
+             f"WER={_avg([r.wer_score for r in pure_r]):.3f}"
+         )
+
+
+ def _print_ablation(a_res: List[Result], b_res: List[Result]):
+     print(f"\n{'='*74}")
+     print(" ABLATION STUDY — MLM Reranking Contribution")
+     print(" (A) ByT5 top-1 only | (B) ByT5 + XLM-RoBERTa MLM reranking")
+     print(f"{'='*74}")
+     print(f" {'Metric':<22} {'(A) ByT5-top1':>14} {'(B) ByT5+MLM':>13} {'Δ (B−A)':>10}")
+     print(f" {'-'*64}")
+
+     metrics = [
+         ("CER (↓ better)", [r.cer_score for r in a_res], [r.cer_score for r in b_res], True),
+         ("WER (↓ better)", [r.wer_score for r in a_res], [r.wer_score for r in b_res], True),
+         ("Token Acc (↑)", [r.token_acc for r in a_res], [r.token_acc for r in b_res], False),
+         ("BLEU (↑ better)", [r.bleu_score for r in a_res], [r.bleu_score for r in b_res], False),
+         ("Exact Match (↑)", [r.exact for r in a_res], [r.exact for r in b_res], False),
+     ]
+
+     for label, a_vals, b_vals, lower_is_better in metrics:
+         a_avg, b_avg = _avg(a_vals), _avg(b_vals)
+         delta = b_avg - a_avg
+         improved = (delta < 0) if lower_is_better else (delta > 0)
+         print(
+             f" {label:<22} {a_avg:>14.4f} {b_avg:>13.4f} "
+             f" {'✓' if improved else '✗'}{delta:>+8.4f}"
+         )
+
+     print("\n ✓ B vs A isolates the contribution of XLM-RoBERTa MLM reranking.")
+     print(" ✓ If B > A: the two-pass reranker justifies its computational cost.")
+
+     # Subcategory breakdown
+     for sublabel, filter_fn in [
+         ("Code-mixed only", lambda r: r.test_case.has_code_mix),
+         ("Ambiguous only", lambda r: r.test_case.has_ambiguity),
+         ("Pure Singlish", lambda r: not r.test_case.has_code_mix),
+     ]:
+         a_sub = [r for r in a_res if filter_fn(r)]
+         b_sub = [r for r in b_res if filter_fn(r)]
+         if not a_sub:
+             continue
+         print(f"\n {sublabel} (n={len(a_sub)}):")
+         print(f" {'':20} {'(A)':>10} {'(B)':>10} {'Δ':>10}")
+         for ml, getter, low in [("CER", lambda r: r.cer_score, True), ("WER", lambda r: r.wer_score, True), ("TokAcc", lambda r: r.token_acc, False)]:
+             av, bv = _avg([getter(r) for r in a_sub]), _avg([getter(r) for r in b_sub])
+             d = bv - av
+             imp = (d < 0) if low else (d > 0)
+             print(
+                 f" {ml:<20} {av:>10.4f} {bv:>10.4f} "
+                 f" {'✓' if imp else '✗'}{d:>+7.4f}"
+             )
+
+
+ def _load_baseline(path: str) -> dict:
+     with open(path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+
+ def _print_v2_comparison(b_res: List[Result], baseline: dict):
+     n = len(b_res)
+     v3 = {
+         "exact_match": _avg([r.exact for r in b_res]),
+         "cer": _avg([r.cer_score for r in b_res]),
+         "wer": _avg([r.wer_score for r in b_res]),
+         "bleu": _avg([r.bleu_score for r in b_res]),
+         "token_acc": _avg([r.token_acc for r in b_res]),
+     }
+     v2_label = baseline.get("system", "v2 baseline")
+
+     print(f"\n{'='*74}")
+     print(f" SinCode v2 vs SinCode v3 — Head-to-Head (n={n})")
+     print(f" v2: {v2_label}")
+     print(" v3: ByT5-small seq2seq + XLM-RoBERTa MLM reranking")
+     print(f"{'='*74}")
+     print(f" {'Metric':<22} {'v2 (baseline)':>14} {'v3 (ours)':>10} {'Δ (v3−v2)':>12}")
+     print(f" {'-'*62}")
+
+     metrics = [
+         ("Exact Match (↑)", "exact_match", False),
+         ("CER (↓ better)", "cer", True),
+         ("WER (↓ better)", "wer", True),
+         ("BLEU (↑ better)", "bleu", False),
+         ("Token Acc (↑)", "token_acc", False),
+     ]
+     for label, key, lower_is_better in metrics:
+         v2v = baseline.get(key, 0.0)
+         v3v = v3[key]
+         delta = v3v - v2v
+         improved = (delta < 0) if lower_is_better else (delta > 0)
+         print(
+             f" {label:<22} {v2v:>14.4f} {v3v:>10.4f} "
+             f" {'✓' if improved else '✗'} {delta:>+9.4f}"
+         )
+
+     if baseline.get("notes"):
+         print(f"\n Note: {baseline['notes']}")
+
+
+ def _save_csv(results_by_system: Dict[str, List[Result]], out_path: str):
+     rows = []
+     for system, results in results_by_system.items():
+         for r in results:
+             rows.append({
+                 "system": system,
+                 "id": r.test_case.id,
+                 "domain": r.test_case.domain,
+                 "has_code_mix": int(r.test_case.has_code_mix),
+                 "has_ambiguity": int(r.test_case.has_ambiguity),
+                 "input": r.test_case.input,
+                 "reference": r.test_case.reference,
+                 "prediction": r.prediction,
+                 "cer": f"{r.cer_score:.4f}",
+                 "wer": f"{r.wer_score:.4f}",
+                 "token_acc": f"{r.token_acc:.4f}",
+                 "bleu": f"{r.bleu_score:.4f}",
+                 "exact_match": f"{r.exact:.0f}",
+             })
+     if not rows:
+         print("\n No results to save.")
+         return
+     with open(out_path, "w", encoding="utf-8", newline="") as f:
+         w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
+         w.writeheader()
+         w.writerows(rows)
+     print(f"\n Results saved -> {out_path}")
+
+
+ # ── Main ──────────────────────────────────────────────────────────────────────
+
+ def main():
+     parser = argparse.ArgumentParser(description="SinCode v3 evaluation")
+     parser.add_argument("--dataset", required=True,
+                         help="Path to evaluation CSV (dataset_110.csv or dataset_40.csv)")
+     parser.add_argument("--mode", default="system",
+                         choices=["system", "ablation"],
+                         help="Evaluation mode (default: system)")
+     parser.add_argument("--out", default=None,
+                         help="Optional path to save results CSV")
+     parser.add_argument("--baseline", default=None,
+                         help="Path to v2 baseline JSON (e.g. misc/v2_baseline.json) for head-to-head comparison")
+     args = parser.parse_args()
+
+     print(f"\nLoading dataset: {args.dataset}")
+     test_cases = load_dataset(args.dataset)
+     print(f"  {len(test_cases)} test cases loaded.")
+
+     results_by_system: Dict[str, List[Result]] = {}
+     a_results: List[Result] = []
+     b_results: List[Result] = []
+
+     decoder = _load_v3_decoder()
+
+     if args.mode == "ablation":
+         print("\nRunning (A) ByT5 top-1 only...")
+         a_results = [_score(tc, _byt5_top1_predict(decoder, tc.input), "byt5_top1") for tc in test_cases]
+         results_by_system["byt5_top1"] = a_results
+
+         print("\nRunning (B) ByT5 + MLM reranking...")
+         b_results = [_score(tc, decoder.decode(tc.input)[0], "byt5_mlm") for tc in test_cases]
+         results_by_system["byt5_mlm"] = b_results
+     else:
+         # system mode: run the full pipeline only
+         print("\nRunning v3 pipeline (ByT5 + MLM reranking)...")
+         b_results = [_score(tc, decoder.decode(tc.input)[0], "byt5_mlm") for tc in test_cases]
+         results_by_system["byt5_mlm"] = b_results
+
+     if args.mode == "system":
+         _print_table("v3 Code-Mixed Pipeline (ByT5 + XLM-RoBERTa MLM)", b_results)
+     elif args.mode == "ablation":
+         _print_table("(A) ByT5 top-1 only", a_results)
+         _print_table("(B) ByT5 + MLM reranking", b_results)
+         _print_ablation(a_results, b_results)
+
+     if args.baseline:
+         baseline = _load_baseline(args.baseline)
+         _print_v2_comparison(b_results, baseline)
+
+     if args.out:
+         _save_csv(results_by_system, args.out)
+
+
+ if __name__ == "__main__":
+     main()
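The metric helpers are plain functions and can be sanity-checked in isolation. Toy values worked by hand; `misc.evaluate` resolves as a namespace package when run from the repo root:

```python
from misc.evaluate import cer, wer, token_accuracy, exact_match

pred = "ඒක හොඳ වගේ"            # prediction missing the final word
ref = "ඒක හොඳ වගේ තියෙනවා"

print(round(cer(pred, ref), 3))   # 8 missing chars / 18 ref chars ≈ 0.444
print(wer(pred, ref))             # 1 missing word / 4 ref words = 0.25
print(token_accuracy(pred, ref))  # 3 of 4 reference tokens match → 0.75
print(exact_match(pred, ref))     # 0.0
```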
misc/upload_mlm_to_hf.py ADDED
@@ -0,0 +1,50 @@
+ r"""
+ Upload the fine-tuned XLM-RoBERTa MLM model to HuggingFace Hub.
+ Run from: C:\Y5_Docs\FYP\SinCode\SinCode_v3
+ Usage: python misc/upload_mlm_to_hf.py --token YOUR_HF_WRITE_TOKEN
+ """
+
+ import argparse
+ from pathlib import Path
+
+ from huggingface_hub import HfApi
+
+ MODEL_LOCAL_PATH = Path(
+     r"C:\Y5_Docs\FYP\SinCode\SinCode_v2-20260315T161648Z-1-001"
+     r"\SinCode_v2\SinCode\SinCode\xlm-roberta-sinhala-v5-strict-full\final"
+ )
+ REPO_ID = "Kalana001/xlm-roberta-base-finetuned-sinhala"
+
+ FILES_TO_UPLOAD = [
+     "config.json",
+     "model.safetensors",
+     "tokenizer.json",
+     "tokenizer_config.json",
+ ]
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--token", required=True, help="HuggingFace write-access token")
+     args = parser.parse_args()
+
+     api = HfApi(token=args.token)
+
+     print(f"Uploading to: {REPO_ID}")
+     for filename in FILES_TO_UPLOAD:
+         local_file = MODEL_LOCAL_PATH / filename
+         if not local_file.exists():
+             print(f"  SKIP (not found): {filename}")
+             continue
+         size_mb = round(local_file.stat().st_size / 1024 / 1024, 1)
+         print(f"  Uploading {filename} ({size_mb} MB)...")
+         api.upload_file(
+             path_or_fileobj=str(local_file),
+             path_in_repo=filename,
+             repo_id=REPO_ID,
+             repo_type="model",
+         )
+         print(f"  Done: {filename}")
+
+     print(f"\nAll files uploaded to https://huggingface.co/{REPO_ID}")
+
+
+ if __name__ == "__main__":
+     main()
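As a side note, the per-file loop above could be collapsed into a single call with `huggingface_hub`'s `upload_folder`; a minimal sketch, where the `allow_patterns` globs are an illustrative stand-in for the explicit FILES_TO_UPLOAD list:

    from huggingface_hub import HfApi

    REPO_ID = "Kalana001/xlm-roberta-base-finetuned-sinhala"
    MODEL_LOCAL_PATH = "path/to/xlm-roberta-sinhala-v5-strict-full/final"  # as above

    api = HfApi(token="YOUR_HF_WRITE_TOKEN")
    api.create_repo(REPO_ID, repo_type="model", exist_ok=True)  # no-op if it exists
    api.upload_folder(
        folder_path=MODEL_LOCAL_PATH,
        repo_id=REPO_ID,
        repo_type="model",
        allow_patterns=["config.json", "*.safetensors", "tokenizer*"],
    )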
misc/v2_baseline.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "system": "SinCode v2 (rule-based + dictionary)",
+   "samples": 110,
+   "exact_match": 0.8364,
+   "cer": 0.0122,
+   "wer": 0.0407,
+   "bleu": 0.8861,
+   "token_acc": 0.9593,
+   "notes": "Measured on dataset_110.csv test split. Avg time per sentence: 0.03s (3.34s total)."
+ }
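The evaluation script's `--baseline` flag points at this file; `_load_baseline` is defined earlier in that script and not visible in this hunk, but it presumably amounts to a plain JSON load along these lines:

    import json

    def _load_baseline(path: str) -> dict:
        # All metric fields ("cer", "wer", "bleu", ...) are already floats.
        with open(path, encoding="utf-8") as f:
            return json.load(f)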
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers>=4.40.0
+ torch>=2.2.0
+ sentencepiece
+ datasets
+ streamlit
+ pandas
seq2seq/__init__.py ADDED
@@ -0,0 +1 @@
+ # seq2seq — ByT5 and mBart50 inference wrappers
seq2seq/compose_fix_map.json ADDED
@@ -0,0 +1,46 @@
+ {
+   "ක්\\sර": "ක්‍ර",
+   "ක්\\sරි": "ක්‍රි",
+   "ක්\\sරී": "ක්‍රී",
+   "ක්\\sරා": "ක්‍රා",
+   "ද්\\sර": "ද්‍ර",
+   "ද්\\sරි": "ද්‍රි",
+   "ද්\\sරී": "ද්‍රී",
+   "ද්\\sරා": "ද්‍රා",
+   "ට්\\sර": "ට්‍ර",
+   "ට්\\sරි": "ට්‍රි",
+   "ට්\\sරී": "ට්‍රී",
+   "ට්\\sරා": "ට්‍රා",
+   "ත්\\sර": "ත්‍ර",
+   "ත්\\sරි": "ත්‍රි",
+   "ත්\\sරී": "ත්‍රී",
+   "ත්\\sරා": "ත්‍රා",
+   "ප්\\sර": "ප්‍ර",
+   "ප්\\sරි": "ප්‍රි",
+   "ප්\\sරී": "ප්‍රී",
+   "ප්\\sරා": "ප්‍රා",
+   "බ්\\sර": "බ්‍ර",
+   "බ්\\sරි": "බ්‍රි",
+   "බ්\\sරී": "බ්‍රී",
+   "බ්\\sරා": "බ්‍රා",
+   "ග්\\sර": "ග්‍ර",
+   "ග්\\sරි": "ග්‍රි",
+   "ග්\\sරී": "ග්‍රී",
+   "ග්\\sරා": "ග්‍රා",
+   "ෂ්\\sර": "ෂ්‍ර",
+   "ෂ්\\sරි": "ෂ්‍රි",
+   "ෂ්\\sරී": "ෂ්‍රී",
+   "ෂ්\\sරා": "ෂ්‍රා",
+   "ශ්\\sර": "ශ්‍ර",
+   "ශ්\\sරි": "ශ්‍රි",
+   "ශ්\\sරී": "ශ්‍රී",
+   "ශ්\\sරා": "ශ්‍රා",
+   "ව්\\sය": "ව්‍ය",
+   "ව්\\sයා": "ව්‍යා",
+   "ද්\\sය": "ද්‍ය",
+   "ද්\\sයා": "ද්‍යා",
+   "න්\\sය": "න්‍ය",
+   "ධ්\\sය": "ධ්‍ය",
+   "ධ්\\sයා": "ධ්‍යා",
+   "ද්\\sයු": "ද්‍යු"
+ }
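The keys here are regex patterns, not literals — JSON's `\\s` decodes to the regex whitespace class `\s` — so each entry re-joins a conjunct the model emitted with a stray space. A small self-contained demo of how seq2seq/mbart_infer.py's `_apply_fixes` consumes this file:

    import json
    import re
    from pathlib import Path

    fix_map = json.loads(Path("seq2seq/compose_fix_map.json").read_text(encoding="utf-8"))

    text = "ක් රිකට්"  # rakaransaya conjunct broken by a stray space
    for pattern, replacement in fix_map.items():
        text = re.sub(pattern, replacement, text)
    print(text)  # ක්‍රිකට් — joined with U+200D (zero-width joiner)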
seq2seq/finetune_corrections.py ADDED
@@ -0,0 +1,270 @@
+ """
+ seq2seq/finetune_corrections.py
+
+ Targeted correction fine-tune for the already-trained ByT5 model.
+
+ Problem:  ByT5 struggles with short/ambiguous tokens like "na"→නෑ, "ba"→බෑ,
+           extreme abbreviations like "mn"→මං, and colloquial negations.
+
+ Solution: Inject high-confidence correction pairs (from core/mappings.py),
+           heavily repeated, mixed with a random sample of the original
+           training data to prevent catastrophic forgetting.
+
+ The output is saved to byt5-singlish-sinhala/final/ (overwrites in place).
+ Run from the project root:
+     python seq2seq/finetune_corrections.py
+ """
+
+ from __future__ import annotations
+
+ import random
+ import sys
+ from pathlib import Path
+
+ ROOT = Path(__file__).parent.parent
+ if str(ROOT) not in sys.path:
+     sys.path.insert(0, str(ROOT))
+
+ import torch
+ from datasets import Dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSeq2SeqLM,
+     Seq2SeqTrainer,
+     Seq2SeqTrainingArguments,
+     default_data_collator,
+ )
+
+ # ── Config ────────────────────────────────────────────────────────────────────
+
+ MODEL_PATH = ROOT / "seq2seq" / "byt5-singlish-sinhala" / "final"
+ DATA_PATH = ROOT / "seq2seq" / "wsd_pairs.csv"
+ OUTPUT_DIR = ROOT / "seq2seq" / "byt5-singlish-sinhala" / "final"  # overwrite in place
+
+ REPEAT = 500         # how many times each correction pair is repeated
+ BG_SAMPLES = 50_000  # random background pairs from wsd_pairs.csv to prevent forgetting
+ MAX_INPUT_LEN = 64
+ MAX_TARGET_LEN = 64
+ BATCH_SIZE = 32
+ LR = 5e-5            # low LR — gentle correction, not retraining
+ EPOCHS = 1
+ SEED = 42
+
+ # ── Correction pairs (sourced from core/mappings.py) ─────────────────────────
+ # Only include pairs where ByT5 is known to be unreliable.
+ # English-safe tokens (pr, dm, ai…) are excluded — they never reach ByT5.
+
+ CORRECTIONS = [
+     # negation — most critical
+     ("na", "නෑ"),
+     ("naa", "නෑ"),
+     ("ba", "බෑ"),
+     ("bari", "බැරි"),
+     ("bri", "බැරි"),
+     ("nathi", "නැති"),
+     ("nati", "නැති"),
+     ("naththe", "නැත්තෙ"),
+     ("epa", "එපා"),
+     ("ep", "එපා"),
+     # pronouns / first person
+     ("mn", "මං"),
+     ("mama", "මම"),
+     ("mage", "මගේ"),
+     ("mge", "මගේ"),
+     ("oya", "ඔයා"),
+     ("oyaa", "ඔයා"),
+     ("api", "අපි"),
+     ("mata", "මට"),
+     ("mta", "මට"),
+     ("oyata", "ඔයාට"),
+     ("oyta", "ඔයාට"),
+     ("oyage", "ඔයාගේ"),
+     ("oyge", "ඔයාගෙ"),
+     ("ape", "අපේ"),
+     # common particles
+     ("one", "ඕනෙ"),
+     ("oney", "ඕනේ"),
+     ("on", "ඕනෙ"),
+     ("oni", "ඕනි"),
+     ("hari", "හරි"),
+     ("hri", "හරි"),
+     ("wage", "වගේ"),
+     ("nisa", "නිසා"),
+     ("dan", "දැන්"),
+     ("gena", "ගැන"),
+     # time
+     ("heta", "හෙට"),
+     ("hta", "හෙට"),
+     ("ada", "අද"),
+     ("iye", "ඊයේ"),
+     ("kalin", "කලින්"),
+     ("passe", "පස්සෙ"),
+     # abbreviations
+     # ("mn", "මං") removed here — already listed once under pronouns above
+     ("ek", "එක"),
+     ("ekta", "එකට"),
+     ("eke", "එකේ"),
+     ("me", "මේ"),
+     # common words
+     ("honda", "හොඳ"),
+     ("hodai", "හොඳයි"),
+     ("gedara", "ගෙදර"),
+     ("wada", "වැඩ"),
+     ("kema", "කෑම"),
+     ("kama", "කෑම"),
+     ("inne", "ඉන්නෙ"),
+     ("inna", "ඉන්න"),
+     ("madi", "මදි"),
+     ("iwara", "ඉවර"),
+     ("iwra", "ඉවර"),
+     # verbal
+     ("awa", "ආවා"),
+     ("aawa", "ආවා"),
+     ("giya", "ගියා"),
+     ("una", "උනා"),
+     ("wuna", "උනා"),
+     ("kiwa", "කිව්වා"),
+     ("kiwwa", "කිව්වා"),
+     ("yewwa", "යැව්වා"),
+     ("yawwa", "යැව්වා"),
+     ("damma", "දැම්මා"),
+     ("karanna", "කරන්න"),
+     ("krnna", "කරන්න"),
+     ("balanna", "බලන්න"),
+     ("blnna", "බලන්න"),
+     ("hadanna", "හදන්න"),
+     ("karamu", "කරමු"),
+     ("balamu", "බලමු"),
+     ("yamu", "යමු"),
+     ("hadamu", "හදමු"),
+     ("damu", "දාමු"),
+     ("wenawa", "වෙනවා"),
+     ("wenwa", "වෙනවා"),
+     ("thiyanawa", "තියෙනවා"),
+     ("enawa", "එනවා"),
+     ("yanawa", "යනවා"),
+ ]
+
+
+ # ── Dataset builder ───────────────────────────────────────────────────────────
+
+ def build_dataset(tokenizer) -> Dataset:
+     import csv
+
+     pairs: list[dict] = []
+
+     # 1. Correction pairs repeated REPEAT times
+     for romanized, sinhala in CORRECTIONS:
+         for _ in range(REPEAT):
+             pairs.append({"romanized": romanized, "sinhala": sinhala})
+
+     correction_count = len(pairs)
+     print(f"  Correction pairs: {len(CORRECTIONS)} × {REPEAT} = {correction_count:,}")
+
+     # 2. Background sample from original training data
+     bg: list[dict] = []
+     with open(DATA_PATH, encoding="utf-8", newline="") as f:
+         reader = csv.DictReader(f)
+         for row in reader:
+             r = (row.get("romanized") or "").strip()
+             s = (row.get("sinhala") or "").strip()
+             if r and s:
+                 bg.append({"romanized": r, "sinhala": s})
+
+     random.seed(SEED)
+     random.shuffle(bg)
+     bg = bg[:BG_SAMPLES]
+     pairs.extend(bg)
+     print(f"  Background pairs: {len(bg):,}")
+     print(f"  Total dataset   : {len(pairs):,}")
+
+     random.shuffle(pairs)
+
+     ds = Dataset.from_list(pairs)
+
+     def tokenize(batch):
+         inputs = tokenizer(
+             batch["romanized"],
+             max_length=MAX_INPUT_LEN,
+             truncation=True,
+             padding="max_length",
+         )
+         targets = tokenizer(
+             batch["sinhala"],
+             max_length=MAX_TARGET_LEN,
+             truncation=True,
+             padding="max_length",
+         )
+         inputs["labels"] = [
+             [(t if t != tokenizer.pad_token_id else -100) for t in ids]
+             for ids in targets["input_ids"]
+         ]
+         return inputs
+
+     ds = ds.map(tokenize, batched=True, batch_size=5_000,
+                 remove_columns=["romanized", "sinhala"], desc="Tokenizing")
+     ds.set_format("torch")
+     return ds
+
+
+ # ── Main ──────────────────────────────────────────────────────────────────────
+
+ def main():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"\nDevice : {device}")
+     if device == "cpu":
+         print("WARNING: running on CPU — this will take ~30-60 min.")
+
+     print(f"Loading model from {MODEL_PATH} ...")
+     tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
+     model = AutoModelForSeq2SeqLM.from_pretrained(str(MODEL_PATH))
+
+     print("\nBuilding correction dataset ...")
+     ds = build_dataset(tokenizer)
+
+     split = ds.train_test_split(test_size=0.02, seed=SEED)
+     train_ds = split["train"]
+     eval_ds = split["test"]
+     print(f"  train={len(train_ds):,}  eval={len(eval_ds):,}")
+
+     warmup = max(100, len(train_ds) // (BATCH_SIZE * 20))
+
+     args = Seq2SeqTrainingArguments(
+         output_dir=str(OUTPUT_DIR),
+         num_train_epochs=EPOCHS,
+         per_device_train_batch_size=BATCH_SIZE,
+         per_device_eval_batch_size=BATCH_SIZE,
+         learning_rate=LR,
+         warmup_steps=warmup,
+         weight_decay=0.01,
+         predict_with_generate=False,  # faster eval — we only care about loss
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         load_best_model_at_end=True,
+         metric_for_best_model="eval_loss",
+         logging_steps=100,
+         dataloader_num_workers=0,
+         seed=SEED,
+         bf16=torch.cuda.is_bf16_supported(),
+         fp16=not torch.cuda.is_bf16_supported() and torch.cuda.is_available(),
+     )
+
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=args,
+         train_dataset=train_ds,
+         eval_dataset=eval_ds,
+         data_collator=default_data_collator,
+     )
+
+     print("\nStarting correction fine-tune ...")
+     trainer.train()
+
+     print(f"\nSaving corrected model to {OUTPUT_DIR} ...")
+     model.save_pretrained(str(OUTPUT_DIR))
+     tokenizer.save_pretrained(str(OUTPUT_DIR))
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     main()
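Once this fine-tune finishes, a quick sanity check over the injected pairs is cheap; a minimal sketch using this commit's own seq2seq/infer.py wrapper:

    from seq2seq.infer import Transliterator

    t = Transliterator()  # loads byt5-singlish-sinhala/final by default
    for roman, expected in [("na", "නෑ"), ("mn", "මං"), ("ba", "බෑ")]:
        top = t.candidates(roman, k=1)[0]
        print(f"{'OK  ' if top == expected else 'FAIL'} {roman!r} -> {top} (want {expected})")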
seq2seq/infer.py ADDED
@@ -0,0 +1,83 @@
+ """
+ Inference helper — given a romanized word, return top-K Sinhala candidates
+ using beam search on the fine-tuned ByT5 model.
+
+ Usage:
+     from seq2seq.infer import Transliterator
+     t = Transliterator()
+     print(t.candidates("videowe", k=5))
+     # ['වීඩියොවේ', 'වීඩියොවී', 'වීඩියොව', ...]
+ """
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Optional
+
+ import torch
+ from transformers import ByT5Tokenizer, T5ForConditionalGeneration
+
+ DEFAULT_MODEL_PATH = Path(__file__).parent / "byt5-singlish-sinhala" / "final"
+
+
+ class Transliterator:
+     def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: Optional[str] = None):
+         # Keep as string — Path() would convert '/' to '\' on Windows, breaking HF Hub IDs
+         model_path = str(model_path)
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.tokenizer = ByT5Tokenizer.from_pretrained(model_path)
+         self.model = T5ForConditionalGeneration.from_pretrained(model_path)
+         self.model.to(self.device)
+         self.model.eval()
+
+     def candidates(self, word: str, k: int = 5) -> list[str]:
+         """Return top-k Sinhala transliteration candidates for a single word."""
+         return self.batch_candidates([word], k=k)[0]
+
+     def batch_candidates(self, words: list[str], k: int = 5) -> list[list[str]]:
+         """
+         Return top-k Sinhala candidates for each word in a single forward pass.
+         Much faster than calling candidates() per word on a long sentence.
+         """
+         lowered = [w.lower() for w in words]
+         inputs = self.tokenizer(
+             lowered,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=64,
+         ).to(self.device)
+
+         n = len(words)
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 num_beams=max(k, 5),
+                 num_return_sequences=k,
+                 max_new_tokens=64,
+                 early_stopping=True,
+             )
+
+         # outputs shape: (n * k, seq_len) — k sequences per input, grouped
+         results: list[list[str]] = []
+         for i in range(n):
+             seen: set[str] = set()
+             cands: list[str] = []
+             for seq in outputs[i * k : (i + 1) * k]:
+                 text = self.tokenizer.decode(seq, skip_special_tokens=True).strip()
+                 if text and text not in seen:
+                     seen.add(text)
+                     cands.append(text)
+             results.append(cands)
+
+         return results
+
+
+ if __name__ == "__main__":
+     import sys
+
+     words = sys.argv[1:] if len(sys.argv) > 1 else ["wadi"]
+     t = Transliterator()
+     for word in words:
+         print(f"Candidates for '{word}':")
+         for c in t.candidates(word):
+             print(f"  {c}")
+         print()
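If the MLM reranker ever needs ByT5's own beam confidences next to the candidate strings, `generate()` can return them. A hedged sketch — not part of this commit — assuming `t` is a loaded `Transliterator`:

    import torch

    enc = t.tokenizer(["wadi"], return_tensors="pt").to(t.device)
    with torch.no_grad():
        out = t.model.generate(
            **enc,
            num_beams=5,
            num_return_sequences=5,
            max_new_tokens=64,
            return_dict_in_generate=True,
            output_scores=True,  # populates out.sequences_scores under beam search
        )
    for seq, score in zip(out.sequences, out.sequences_scores):
        # sequences_scores are final beam scores (log-probs); higher = more confident
        print(f"{score.item():7.3f}  {t.tokenizer.decode(seq, skip_special_tokens=True)}")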
seq2seq/mbart_infer.py ADDED
@@ -0,0 +1,126 @@
+ """
+ mBart50-based Sentence Transliterator for SinCode v3.
+
+ Full-sentence Singlish → Sinhala transliteration.
+ Unlike the ByT5 word-by-word pipeline, mBart50 operates on the whole input
+ sentence and produces fully Sinhalized output — no English words are retained.
+
+ Use-case: "mn heta business ekak start karanawa"
+           → "මන් හෙට ව්‍යාපාරයක් පටන් ගන්නවා"
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import re
+ from pathlib import Path
+ from typing import Optional
+
+ import torch
+ from transformers import MBart50Tokenizer, MBartForConditionalGeneration
+
+ from core.constants import DEFAULT_MBART_MODEL
+
+ logger = logging.getLogger(__name__)
+
+ # ── Fix-map (ZWJ / Virama composition) ───────────────────────────────────────
+
+ _FIX_MAP_PATH = Path(__file__).parent / "compose_fix_map.json"
+
+ _fix_map_cache: dict[str, str] | None = None
+
+
+ def _load_fix_map() -> dict[str, str]:
+     global _fix_map_cache
+     if _fix_map_cache is None:
+         with open(_FIX_MAP_PATH, "r", encoding="utf-8") as f:
+             _fix_map_cache = json.load(f)
+     return _fix_map_cache
+
+
+ # ── Input cleaning ────────────────────────────────────────────────────────────
+
+ # Scripts that are not Sinhala, Latin, numbers, or symbols — filtered out
+ _UNSUPPORTED_SCRIPT = re.compile(
+     r"[\u0B80-\u0BFF"  # Tamil
+     r"\u0900-\u097F"   # Devanagari
+     r"\u4E00-\u9FFF"   # CJK Unified Ideographs
+     r"\u3040-\u309F"   # Hiragana
+     r"\u30A0-\u30FF"   # Katakana
+     r"\u0E00-\u0E7F"   # Thai
+     r"\u0600-\u06FF"   # Arabic
+     r"\u0590-\u05FF"   # Hebrew
+     r"\uAC00-\uD7AF]"  # Hangul
+ )
+
+
+ def _clean(text: str) -> str | None:
+     """Remove words in unsupported scripts; return None if nothing remains."""
+     words = text.strip().split()
+     filtered = [w for w in words if not _UNSUPPORTED_SCRIPT.search(w)]
+     return " ".join(filtered) if filtered else None
+
+
+ def _apply_fixes(text: str) -> str:
+     """Apply ZWJ/virama composition fixes to mBart50 output."""
+     for pattern, replacement in _load_fix_map().items():
+         text = re.sub(pattern, replacement, text)
+     return text
+
+
+ # ── Transliterator ────────────────────────────────────────────────────────────
+
+ class SentenceTransliterator:
+     """
+     Full-sentence Singlish → Sinhala transliterator (mBart50).
+
+     Loads from Hugging Face Hub on first instantiation.
+     Thread-safe for inference (no mutable state after __init__).
+     """
+
+     def __init__(
+         self,
+         model_name: str = DEFAULT_MBART_MODEL,
+         device: Optional[str] = None,
+     ):
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+         logger.info("Loading mBart50 transliterator: %s", model_name)
+         self.tokenizer = MBart50Tokenizer.from_pretrained(model_name)
+         # Set once here, not per call — keeps the tokenizer immutable after
+         # __init__, as the class docstring promises.
+         self.tokenizer.src_lang = "si_LK"
+         self.model = MBartForConditionalGeneration.from_pretrained(model_name)
+         self.model.to(self.device)
+         self.model.eval()
+
+     def transliterate(self, text: str) -> str:
+         """
+         Transliterate a Singlish sentence to fully-Sinhalized output.
+
+         Args:
+             text: Input Singlish sentence (Romanized Sinhala / English mix).
+
+         Returns:
+             Sinhala-script output. Returns original text if input is empty
+             or consists entirely of unsupported-script characters.
+         """
+         cleaned = _clean(text)
+         if not cleaned:
+             return text
+
+         inputs = self.tokenizer(
+             cleaned,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=128,
+         ).to(self.device)
+
+         with torch.no_grad():
+             tokens = self.model.generate(
+                 **inputs,
+                 forced_bos_token_id=self.tokenizer.lang_code_to_id["si_LK"],
+             )
+
+         output = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
+         return _apply_fixes(output)
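Usage matches the module docstring; a minimal sketch:

    from seq2seq.mbart_infer import SentenceTransliterator

    st = SentenceTransliterator()  # pulls DEFAULT_MBART_MODEL from the Hub on first use
    print(st.transliterate("mn heta business ekak start karanawa"))
    # fully Sinhalized output, conjuncts repaired by _apply_fixes()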
seq2seq/prepare_data.py ADDED
@@ -0,0 +1,94 @@
+ """
+ Parse WSD.txt into a CSV training dataset for ByT5 fine-tuning.
+
+ Input format (WSD.txt):
+     Word: <romanized>, Sinhala Words: ['<s1>', '<s2>', ...]
+
+ Output (wsd_pairs.csv):
+     romanized,sinhala
+     wadi,වැඩි
+     wadi,වාඩි
+     ...
+
+ One row per (romanized, sinhala) pair. Duplicate sinhala entries per
+ word are kept since ByT5 learns from all valid transliterations.
+ """
+
+ import ast
+ import csv
+ import re
+ import sys
+ from pathlib import Path
+
+ WSD_PATH = Path(r"C:\Y5_Docs\FYP\WSD.txt")
+ OUT_PATH = Path(__file__).parent / "wsd_pairs.csv"
+
+ LINE_RE = re.compile(r"^Word:\s*(.+?),\s*Sinhala Words:\s*(\[.+\])\s*$")
+
+ MIN_ROMAN_LEN = 2   # skip single-char romanized entries
+ MAX_ROMAN_LEN = 40  # skip obviously malformed long entries
+
+
+ def parse_wsd(wsd_path: Path) -> list[tuple[str, str]]:
+     pairs: list[tuple[str, str]] = []
+     skipped = 0
+
+     with wsd_path.open(encoding="utf-8") as f:
+         for lineno, line in enumerate(f, 1):
+             line = line.strip()
+             if not line:
+                 continue
+
+             m = LINE_RE.match(line)
+             if not m:
+                 skipped += 1
+                 continue
+
+             roman = m.group(1).strip().lower()
+             if not (MIN_ROMAN_LEN <= len(roman) <= MAX_ROMAN_LEN):
+                 skipped += 1
+                 continue
+
+             try:
+                 sinhala_list = ast.literal_eval(m.group(2))
+             except (ValueError, SyntaxError):
+                 skipped += 1
+                 continue
+
+             for sinhala in sinhala_list:
+                 sinhala = sinhala.strip()
+                 if sinhala:
+                     pairs.append((roman, sinhala))
+
+             if lineno % 100_000 == 0:
+                 print(f"  processed {lineno:,} lines, {len(pairs):,} pairs so far…")
+
+     print(f"  skipped {skipped:,} malformed lines")
+     return pairs
+
+
+ def write_csv(pairs: list[tuple[str, str]], out_path: Path) -> None:
+     out_path.parent.mkdir(parents=True, exist_ok=True)
+     with out_path.open("w", encoding="utf-8", newline="") as f:
+         writer = csv.writer(f)
+         writer.writerow(["romanized", "sinhala"])
+         writer.writerows(pairs)
+
+
+ def main() -> None:
+     print(f"Parsing {WSD_PATH} …")
+     pairs = parse_wsd(WSD_PATH)
+     print(f"\nTotal pairs: {len(pairs):,}")
+
+     print(f"Writing to {OUT_PATH} …")
+     write_csv(pairs, OUT_PATH)
+     print("Done.")
+
+     # Quick sanity check
+     print("\nSample rows:")
+     for roman, sinhala in pairs[:5]:
+         print(f"  {roman!r:20s} → {sinhala}")
+
+
+ if __name__ == "__main__":
+     main()
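For reference, one WSD.txt line worked through the LINE_RE + `ast.literal_eval` pipeline (the input line is a made-up example in the documented format):

    import ast
    import re

    LINE_RE = re.compile(r"^Word:\s*(.+?),\s*Sinhala Words:\s*(\[.+\])\s*$")

    line = "Word: wadi, Sinhala Words: ['වැඩි', 'වාඩි']"
    m = LINE_RE.match(line)
    roman = m.group(1).strip().lower()           # 'wadi'
    sinhala_list = ast.literal_eval(m.group(2))  # safe parse of the Python-style list
    print([(roman, s) for s in sinhala_list])    # [('wadi', 'වැඩි'), ('wadi', 'වාඩි')]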
seq2seq/train.py ADDED
@@ -0,0 +1,185 @@
+ """
+ Fine-tune google/byt5-small on Singlish → Sinhala word-level transliteration.
+
+ Input:  wsd_pairs.csv (romanized, sinhala)
+ Output: byt5-singlish-sinhala/ (HuggingFace model directory)
+
+ Training approach:
+ - Input  : romanized word (e.g. "wadi")
+ - Target : sinhala word (e.g. "වැඩි")
+ - Model  : ByT5-small (byte-level T5, no vocab issues with any script)
+ - Beam=5 at inference → top-5 candidates for MLM reranking
+
+ Tokenized dataset is saved to disk after the first run — restarts skip
+ straight to training without re-tokenizing.
+ """
+
+ from pathlib import Path
+
+ import torch
+ from datasets import Dataset, load_from_disk
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSeq2SeqLM,
+     Seq2SeqTrainer,
+     Seq2SeqTrainingArguments,
+     default_data_collator,
+ )
+
+ # ── Config ─────────────────────────────────────────────────────────────────
+
+ BASE_MODEL = "google/byt5-small"
+ DATA_PATH = Path(__file__).parent / "wsd_pairs.csv"
+ CACHE_DIR = Path(__file__).parent / "tokenized_cache"
+ OUTPUT_DIR = Path(__file__).parent / "byt5-singlish-sinhala"
+
+ MAX_SAMPLES = 1_000_000  # 1M pairs — more than enough for word transliteration
+ TRAIN_SPLIT = 0.97
+ MAX_INPUT_LEN = 64
+ MAX_TARGET_LEN = 64
+ BATCH_SIZE = 64          # 16GB VRAM — ByT5-small with seq_len=64
+ EPOCHS = 2
+ LR = 5e-4
+ SEED = 42
+
+
+ # ── Tokenize ────────────────────────────────────────────────────────────────
+
+ def tokenize_fn(batch, tokenizer):
+     # Pad to fixed max_length so all tensors have the same shape.
+     # This lets set_format("torch") work and default_data_collator just stacks.
+     model_inputs = tokenizer(
+         batch["romanized"],
+         max_length=MAX_INPUT_LEN,
+         truncation=True,
+         padding="max_length",
+     )
+     labels = tokenizer(
+         batch["sinhala"],
+         max_length=MAX_TARGET_LEN,
+         truncation=True,
+         padding="max_length",
+     )
+     # Replace pad token with -100 so it's ignored in cross-entropy loss
+     model_inputs["labels"] = [
+         [(t if t != tokenizer.pad_token_id else -100) for t in ids]
+         for ids in labels["input_ids"]
+     ]
+     return model_inputs
+
+
+ # ── Main ───────────────────────────────────────────────────────────────────
+
+ def main():
+     import os
+     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f" Device : {device}")
+     if device == "cuda":
+         print(f" GPU    : {torch.cuda.get_device_name(0)}")
+         print(f" VRAM   : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+     else:
+         print(" WARNING: No GPU detected — training will be very slow!")
+
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+     model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
+
+     train_cache = CACHE_DIR / "train"
+     eval_cache = CACHE_DIR / "eval"
+
+     if train_cache.exists() and eval_cache.exists():
+         print("Loading pre-tokenized dataset from disk cache …")
+         train_ds = load_from_disk(str(train_cache))
+         eval_ds = load_from_disk(str(eval_cache))
+         print(f"  train={len(train_ds):,}  eval={len(eval_ds):,}")
+     else:
+         print(f"Loading data from {DATA_PATH} …")
+         ds = Dataset.from_csv(str(DATA_PATH))
+         ds = ds.filter(lambda x: bool(x["romanized"]) and bool(x["sinhala"]))
+         print(f"  {len(ds):,} pairs — sampling {MAX_SAMPLES:,} …")
+
+         # Shuffle and take MAX_SAMPLES
+         ds = ds.shuffle(seed=SEED).select(range(min(MAX_SAMPLES, len(ds))))
+
+         split = ds.train_test_split(test_size=1 - TRAIN_SPLIT, seed=SEED)
+         train_raw = split["train"]
+         eval_raw = split["test"]
+         print(f"  train={len(train_raw):,}  eval={len(eval_raw):,}")
+
+         print("Tokenizing and saving to disk (one-time, ~5 min) …")
+         train_ds = train_raw.map(
+             lambda b: tokenize_fn(b, tokenizer),
+             batched=True,
+             batch_size=10_000,
+             num_proc=8,
+             keep_in_memory=True,
+             remove_columns=["romanized", "sinhala"],
+             desc="Tokenizing train",
+         )
+         eval_ds = eval_raw.map(
+             lambda b: tokenize_fn(b, tokenizer),
+             batched=True,
+             batch_size=10_000,
+             num_proc=8,
+             keep_in_memory=True,
+             remove_columns=["romanized", "sinhala"],
+             desc="Tokenizing eval",
+         )
+
+         CACHE_DIR.mkdir(parents=True, exist_ok=True)
+         train_ds.save_to_disk(str(train_cache))
+         eval_ds.save_to_disk(str(eval_cache))
+         print("  Saved to disk. Future runs will load instantly.")
+
+     train_ds.set_format("torch")
+     eval_ds.set_format("torch")
+
+     # All sequences are pre-padded to fixed length — just stack them
+     collator = default_data_collator
+     warmup_steps = int(0.05 * (len(train_ds) // BATCH_SIZE))
+
+     args = Seq2SeqTrainingArguments(
+         output_dir=str(OUTPUT_DIR),
+         num_train_epochs=EPOCHS,
+         per_device_train_batch_size=BATCH_SIZE,
+         per_device_eval_batch_size=BATCH_SIZE,
+         learning_rate=LR,
+         warmup_steps=warmup_steps,
+         weight_decay=0.01,
+         predict_with_generate=True,
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         load_best_model_at_end=True,
+         metric_for_best_model="eval_loss",
+         logging_steps=200,
+         dataloader_num_workers=0,  # 0 = main process only (most stable on Windows)
+         dataloader_pin_memory=True,
+         bf16=torch.cuda.is_bf16_supported(),
+         fp16=not torch.cuda.is_bf16_supported() and torch.cuda.is_available(),
+         seed=SEED,
+         report_to="none",
+     )
+
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=args,
+         train_dataset=train_ds,
+         eval_dataset=eval_ds,
+         processing_class=tokenizer,
+         data_collator=collator,
+     )
+
+     print("Starting training …")
+     trainer.train()
+
+     print(f"Saving model to {OUTPUT_DIR}/final …")
+     model.save_pretrained(OUTPUT_DIR / "final")
+     tokenizer.save_pretrained(OUTPUT_DIR / "final")
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     main()
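One operational caveat: the branch above loads tokenized_cache unconditionally whenever it exists, so a refreshed wsd_pairs.csv needs a manual cache wipe before retraining; e.g.:

    import shutil
    from pathlib import Path

    cache = Path("seq2seq/tokenized_cache")
    if cache.exists():
        shutil.rmtree(cache)  # forces re-tokenization on the next train.py run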
sincode_model.py ADDED
@@ -0,0 +1,16 @@
+ """
+ SinCode v3 — public API entry point.
+
+ Usage:
+     from sincode_model import BeamSearchDecoder
+     decoder = BeamSearchDecoder()
+     result, logs = decoder.decode("mema videowe bit rate eka godak wadi nisa buffer wenawa")
+ """
+
+ from core.decoder import BeamSearchDecoder, ScoredCandidate  # noqa: F401
+ from core.english import ENGLISH_VOCAB  # noqa: F401
+ from core.constants import (  # noqa: F401
+     DEFAULT_MLM_MODEL, DEFAULT_BYT5_MODEL, DEFAULT_MBART_MODEL,
+     MAX_CANDIDATES, MIN_ENGLISH_LEN,
+ )
+ from seq2seq.mbart_infer import SentenceTransliterator  # noqa: F401