KalanaPabasara committed on
Commit · f6f45d5
Parent(s):
SinCode v3 — ByT5 seq2seq + XLM-RoBERTa MLM reranker
- .gitignore +28 -0
- app.py +75 -0
- architecture.html +111 -0
- architecture.mmd +72 -0
- core/__init__.py +0 -0
- core/constants.py +35 -0
- core/decoder.py +248 -0
- core/english.py +73 -0
- core/mappings.py +8 -0
- english_20k.txt +0 -0
- misc/dataset_110.csv +111 -0
- misc/dataset_40.csv +41 -0
- misc/evaluate.py +446 -0
- misc/upload_mlm_to_hf.py +50 -0
- misc/v2_baseline.json +10 -0
- requirements.txt +6 -0
- seq2seq/__init__.py +1 -0
- seq2seq/compose_fix_map.json +46 -0
- seq2seq/finetune_corrections.py +270 -0
- seq2seq/infer.py +83 -0
- seq2seq/mbart_infer.py +126 -0
- seq2seq/prepare_data.py +94 -0
- seq2seq/train.py +185 -0
- sincode_model.py +16 -0
.gitignore
ADDED
@@ -0,0 +1,28 @@
# Virtual environment
.venv/

# Model weights — hosted on Hugging Face Hub
seq2seq/byt5-singlish-sinhala/
seq2seq/tokenized_cache/

# Large training data
seq2seq/wsd_pairs.csv

# Backup weights
*.safetensors.bak
*_backup.safetensors
final_pre_correction_backup.safetensors

# Python cache
__pycache__/
*.pyc
*.pyo
*.pyd
.Python

# Evaluation outputs (optional — remove if you want these tracked)
misc/v3_results_110.csv

# Misc
*.log
.DS_Store
app.py
ADDED
@@ -0,0 +1,75 @@
"""
SinCode v3 — Streamlit demo UI.

Architecture: ByT5-small (seq2seq candidate generation) +
              XLM-RoBERTa (MLM contextual reranking)

Two transliteration modes:
  • Code-Mixed — ByT5 + MLM; retains English words where contextually apt
  • Full Sinhala — mBart50 sentence-level; transliterates everything to Sinhala
"""

import streamlit as st
from sincode_model import BeamSearchDecoder, SentenceTransliterator

st.set_page_config(page_title="සිංCode v3", page_icon="🇱🇰", layout="centered")

st.title("සිංCode v3")
st.caption("ByT5 seq2seq + XLM-RoBERTa MLM reranking")


@st.cache_resource(show_spinner="Loading models (ByT5 + XLM-RoBERTa)…")
def load_decoder() -> BeamSearchDecoder:
    return BeamSearchDecoder()


@st.cache_resource(show_spinner="Loading mBart50 model…")
def load_transliterator() -> SentenceTransliterator:
    return SentenceTransliterator()


mode = st.radio(
    "Transliteration mode",
    options=["Code-Mixed Output", "Full Sinhala Output"],
    horizontal=True,
    help=(
        "**Code-Mixed**: keeps English technical/borrowed words where natural "
        "(e.g. *buffer*, *bit rate*). "
        "**Full Sinhala**: transliterates every word to Sinhala script "
        "(e.g. *business* → ව්යාපාරය)."
    ),
)

sentence = st.text_input(
    "Enter Singlish sentence",
    placeholder="e.g. mema videowe bit rate eka godak wadi nisa buffer wenawa",
)

show_trace = st.checkbox(
    "Show step-by-step trace",
    value=False,
    disabled=(mode == "Full Sinhala Output"),
    help="Trace is only available in Code-Mixed mode.",
)

if st.button("Transliterate", type="primary") and sentence.strip():
    if mode == "Full Sinhala Output":
        with st.spinner("Transliterating (mBart50)…"):
            transliterator = load_transliterator()
            result = transliterator.transliterate(sentence.strip())

        st.markdown("### Result")
        st.success(result)

    else:
        with st.spinner("Transliterating…"):
            decoder = load_decoder()
            result, trace_logs = decoder.decode(sentence.strip())

        st.markdown("### Result")
        st.success(result)

        if show_trace:
            st.markdown("### Trace")
            for log in trace_logs:
                st.markdown(log)
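Both model loaders are wrapped in st.cache_resource, so the Hub weights are fetched once per process and shared across Streamlit reruns; launching the demo locally is the standard `streamlit run app.py`.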
architecture.html
ADDED
@@ -0,0 +1,111 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <title>SinCode v3 — Architecture</title>
  <script src="https://cdn.jsdelivr.net/npm/mermaid@11/dist/mermaid.min.js"></script>
  <style>
    body {
      font-family: sans-serif;
      background: #f8f9fa;
      display: flex;
      flex-direction: column;
      align-items: center;
      padding: 2rem;
    }
    h1 { color: #2c3e50; margin-bottom: 0.25rem; }
    p { color: #666; margin-top: 0; margin-bottom: 2rem; }
    .mermaid {
      background: white;
      border-radius: 12px;
      padding: 2rem;
      box-shadow: 0 2px 12px rgba(0,0,0,0.08);
      max-width: 1200px;
      width: 100%;
    }
  </style>
</head>
<body>
  <h1>SinCode v3 — System Architecture</h1>
  <p>ByT5-small · XLM-RoBERTa · mBart50-large</p>

  <div class="mermaid">
flowchart TD
    UI["🖥️ Streamlit UI\napp.py"]
    MODE{Mode?}

    UI --> MODE

    subgraph MODE_FULL["Full Sinhala Mode"]
        direction TB
        ST["SentenceTransliterator\nseq2seq/mbart_infer.py"]
        MBART["mBart50-large\nKalana001/mbart50-large-singlish-sinhala\nHF Hub · 2.4 GB"]
        FIX["Compose Fix Map\nseq2seq/compose_fix_map.json\nZWJ / Virama corrections"]
        ST --> MBART
        MBART -->|"raw Sinhala output"| FIX
    end

    subgraph MODE_MIXED["Code-Mixed Mode"]
        direction TB

        subgraph PHASE1["Phase 1 · Word Classification"]
            direction LR
            P1A["Sinhala script?\n(U+0D80–0DFF)"]
            P1B["English vocab?\nenglish_20k.txt"]
            P1C["Singlish\n(everything else)"]
        end

        subgraph PHASE2["Phase 2 · Candidate Generation (single ByT5 batch)"]
            direction LR
            BYT5["ByT5-small\nKalana001/byt5-small-singlish-sinhala\nHF Hub · 1.2 GB\nbeam=5 → top-5 candidates"]
            SIN_PASS["Single candidate\n(word as-is)"]
            ENG_CAND["English word\n+ ByT5 Sinhala alternatives"]
            SIN_CAND["Top-5 ByT5\ncandidates"]
        end

        subgraph PHASE3["Phase 3 · Two-Pass MLM Reranking"]
            direction LR
            GREEDY["Pass 1 – Greedy\nBuild draft sentence\n(stale right context)"]
            RESCORE["Pass 2 – Rescore\nActual decoded output\nas right context"]
            MLM["XLM-RoBERTa\nKalana001/xlm-roberta-base-finetuned-sinhala\nHF Hub\nMulti-mask log-probability"]
            SOFTMAX["Softmax normalise\npick argmax"]
        end

        PHASE1 --> PHASE2
        P1A -->|Sinhala| SIN_PASS
        P1B -->|English| ENG_CAND
        P1C -->|Singlish| SIN_CAND
        BYT5 --> ENG_CAND
        BYT5 --> SIN_CAND
        PHASE2 --> PHASE3
        GREEDY --> MLM
        MLM --> SOFTMAX
        SOFTMAX --> RESCORE
        RESCORE --> MLM
    end

    MODE -->|"Full Sinhala Output"| MODE_FULL
    MODE -->|"Code-Mixed Output"| MODE_MIXED

    MODE_FULL --> OUT["✅ Sinhala Output"]
    MODE_MIXED --> OUT

    subgraph MODELS["Models on Hugging Face Hub (Kalana001)"]
        HF1["byt5-small-singlish-sinhala\n1.2 GB · ByT5-small"]
        HF2["xlm-roberta-base-finetuned-sinhala\nXLM-RoBERTa"]
        HF3["mbart50-large-singlish-sinhala\n2.4 GB · mBart50-large"]
    end

    style MODE_FULL fill:#e8f4fd,stroke:#4a9eda
    style MODE_MIXED fill:#fdf3e8,stroke:#e8974a
    style PHASE1 fill:#fff9e6,stroke:#cca800
    style PHASE2 fill:#e8fff0,stroke:#2ecc71
    style PHASE3 fill:#f4e8ff,stroke:#9b59b6
    style MODELS fill:#eaf4ee,stroke:#27ae60
  </div>

  <script>
    mermaid.initialize({ startOnLoad: true, theme: 'default', flowchart: { curve: 'basis' } });
  </script>
</body>
</html>
architecture.mmd
ADDED
@@ -0,0 +1,72 @@
flowchart TD
    UI["🖥️ Streamlit UI\napp.py"]
    MODE{Mode?}

    UI --> MODE

    subgraph MODE_FULL["Full Sinhala Mode"]
        direction TB
        ST["SentenceTransliterator\nseq2seq/mbart_infer.py"]
        MBART["mBart50-large\nKalana001/mbart50-large-singlish-sinhala\nHF Hub · 2.4 GB"]
        FIX["Compose Fix Map\nseq2seq/compose_fix_map.json\nZWJ / Virama corrections"]
        ST --> MBART
        MBART -->|"raw Sinhala output"| FIX
    end

    subgraph MODE_MIXED["Code-Mixed Mode"]
        direction TB

        subgraph PHASE1["Phase 1 · Word Classification"]
            direction LR
            P1A["Sinhala script?\n(U+0D80–0DFF)"]
            P1B["English vocab?\nenglish_20k.txt"]
            P1C["Singlish\n(everything else)"]
        end

        subgraph PHASE2["Phase 2 · Candidate Generation (single ByT5 batch)"]
            direction LR
            BYT5["ByT5-small\nKalana001/byt5-small-singlish-sinhala\nHF Hub · 1.2 GB\nbeam=5 → top-5 candidates"]
            SIN_PASS["Single candidate\n(word as-is)"]
            ENG_CAND["English word\n+ ByT5 Sinhala alternatives"]
            SIN_CAND["Top-5 ByT5\ncandidates"]
        end

        subgraph PHASE3["Phase 3 · Two-Pass MLM Reranking"]
            direction LR
            GREEDY["Pass 1 – Greedy\nBuild draft sentence\n(stale right context)"]
            RESCORE["Pass 2 – Rescore\nActual decoded output\nas right context"]
            MLM["XLM-RoBERTa\nKalana001/xlm-roberta-base-finetuned-sinhala\nHF Hub\nMulti-mask log-probability"]
            SOFTMAX["Softmax normalise\npick argmax"]
        end

        PHASE1 --> PHASE2
        P1A -->|Sinhala| SIN_PASS
        P1B -->|English| ENG_CAND
        P1C -->|Singlish| SIN_CAND
        BYT5 --> ENG_CAND
        BYT5 --> SIN_CAND
        PHASE2 --> PHASE3
        GREEDY --> MLM
        MLM --> SOFTMAX
        SOFTMAX --> RESCORE
        RESCORE --> MLM
    end

    MODE -->|"Full Sinhala Output"| MODE_FULL
    MODE -->|"Code-Mixed Output"| MODE_MIXED

    MODE_FULL --> OUT["✅ Sinhala Output"]
    MODE_MIXED --> OUT

    subgraph MODELS["Models on Hugging Face Hub (Kalana001)"]
        HF1["byt5-small-singlish-sinhala\n1.2 GB · ByT5-small"]
        HF2["xlm-roberta-base-finetuned-sinhala\nXLM-RoBERTa"]
        HF3["mbart50-large-singlish-sinhala\n2.4 GB · mBart50-large"]
    end

    style MODE_FULL fill:#e8f4fd,stroke:#4a9eda
    style MODE_MIXED fill:#fdf3e8,stroke:#e8974a
    style PHASE1 fill:#fff9e6,stroke:#cca800
    style PHASE2 fill:#e8fff0,stroke:#2ecc71
    style PHASE3 fill:#f4e8ff,stroke:#9b59b6
    style MODELS fill:#eaf4ee,stroke:#27ae60
core/__init__.py
ADDED
File without changes
core/constants.py
ADDED
@@ -0,0 +1,35 @@
"""
Configuration constants for SinCode v3.

Key difference from v2: no rule engine, no dictionary.
Candidate generation is fully handled by the ByT5 seq2seq model.
"""

import re

# ─── MLM Model Path ──────────────────────────────────────────────────────────
# XLM-RoBERTa fine-tuned on Sinhala — reranks ByT5 candidates by context
DEFAULT_MLM_MODEL = "Kalana001/xlm-roberta-base-finetuned-sinhala"

# ─── ByT5 Transliterator Model Path ──────────────────────────────────────────
# Fine-tuned on 1M Singlish→Sinhala pairs — hosted on Hugging Face Hub
DEFAULT_BYT5_MODEL = "Kalana001/byt5-small-singlish-sinhala"

# ─── mBart50 Transliterator Model Path ───────────────────────────────────────
# Full-sentence Singlish→Sinhala (no English retained) — Hugging Face Hub
DEFAULT_MBART_MODEL = "Kalana001/mbart50-large-singlish-sinhala"

# ─── Corpus ───────────────────────────────────────────────────────────────────
ENGLISH_CORPUS_URL = (
    "https://raw.githubusercontent.com/first20hours/google-10000-english/master/20k.txt"
)

# ─── Scoring Weights ─────────────────────────────────────────────────────────
# Pure MLM — no manual weights needed

# ─── Decoding Parameters ─────────────────────────────────────────────────────
MAX_CANDIDATES: int = 5   # ByT5 beam=5 → 5 candidates per word
MIN_ENGLISH_LEN: int = 3  # Min word length for English detection

# ─── Regex ───────────────────────────────────────────────────────────────────
# Splits a token into (leading punctuation, core word, trailing punctuation)
PUNCT_PATTERN = re.compile(r"^(\W*)(.*?)(\W*)$")
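As a quick illustration of PUNCT_PATTERN, which the decoder uses to peel punctuation off a token before transliterating the core (a minimal sketch; the sample token is arbitrary):

import re

PUNCT_PATTERN = re.compile(r"^(\W*)(.*?)(\W*)$")

# "(buffer!)" splits into leading punctuation, core word, trailing punctuation
prefix, core, suffix = PUNCT_PATTERN.match("(buffer!)").groups()
assert (prefix, core, suffix) == ("(", "buffer", "!)")

Because the lazy middle group yields as much as possible to the anchored \W* groups, the core keeps no surrounding punctuation, and prefix/suffix can be re-attached to each ByT5 candidate unchanged.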
core/decoder.py
ADDED
@@ -0,0 +1,248 @@
"""
SinCode v3 — ByT5 Seq2Seq + XLM-RoBERTa MLM Reranker.

Pipeline (per word):
  Sinhala script   → MLM scores in context (single candidate)
  English vocab    → ByT5 generates Sinhala alternatives + English kept; MLM picks
  Everything else  → ByT5 generates top-5 candidates; MLM picks best
"""

import math
import re
import torch
import logging
from typing import List, Tuple, Optional

from transformers import AutoTokenizer, AutoModelForMaskedLM

from core.constants import (
    DEFAULT_MLM_MODEL, DEFAULT_BYT5_MODEL,
    MAX_CANDIDATES, MIN_ENGLISH_LEN,
    PUNCT_PATTERN,
)
from core.english import ENGLISH_VOCAB
from seq2seq.infer import Transliterator

logger = logging.getLogger(__name__)

_SINHALA_RE = re.compile(r"[\u0D80-\u0DFF]")


class ScoredCandidate:
    __slots__ = ("text", "mlm_score")

    def __init__(self, text: str, mlm_score: float):
        self.text = text
        self.mlm_score = mlm_score


def _is_sinhala(text: str) -> bool:
    return bool(_SINHALA_RE.search(text))


class BeamSearchDecoder:
    """
    SinCode v3 contextual decoder.

    Replaces the rule engine + dictionary + hardcoded maps with a single
    ByT5-small seq2seq model fine-tuned on 1,000,000 Singlish→Sinhala pairs.
    XLM-RoBERTa reranks the top-5 beam candidates by masked-LM probability.
    """

    def __init__(
        self,
        mlm_model_name: str = DEFAULT_MLM_MODEL,
        byt5_model_path: str = DEFAULT_BYT5_MODEL,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        logger.info("Loading MLM reranker: %s", mlm_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(mlm_model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(mlm_model_name)
        self.model.to(self.device)
        self.model.eval()

        logger.info("Loading ByT5 transliterator: %s", byt5_model_path)
        self.transliterator = Transliterator(model_path=byt5_model_path, device=self.device)

    # ── Normalization ─────────────────────────────────────────────────────────

    @staticmethod
    def _softmax_normalize(raw_scores: List[float]) -> List[float]:
        if not raw_scores:
            return []
        if len(raw_scores) == 1:
            return [1.0]
        max_s = max(raw_scores)
        exps = [math.exp(s - max_s) for s in raw_scores]
        total = sum(exps)
        return [e / total for e in exps]

    # ── MLM batch scoring ─────────────────────────────────────────────────────

    def _batch_mlm_score(
        self,
        left_contexts: List[str],
        right_contexts: List[str],
        candidates: List[str],
    ) -> List[float]:
        """Score each candidate with XLM-RoBERTa multi-mask log-probability."""
        if not candidates:
            return []

        mask = self.tokenizer.mask_token
        mask_token_id = self.tokenizer.mask_token_id

        cand_token_ids: List[List[int]] = []
        for c in candidates:
            ids = self.tokenizer.encode(c, add_special_tokens=False)
            cand_token_ids.append(ids if ids else [self.tokenizer.unk_token_id])

        batch_texts: List[str] = []
        for i in range(len(candidates)):
            n_masks = len(cand_token_ids[i])
            mask_str = " ".join([mask] * n_masks)
            parts = [p for p in [left_contexts[i], mask_str, right_contexts[i]] if p]
            batch_texts.append(" ".join(parts))

        inputs = self.tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        scores: List[float] = []
        for i, target_ids in enumerate(cand_token_ids):
            token_ids = inputs.input_ids[i]
            mask_positions = (token_ids == mask_token_id).nonzero(as_tuple=True)[0]

            if mask_positions.numel() == 0 or not target_ids:
                scores.append(-100.0)
                continue

            n = min(len(target_ids), mask_positions.numel())
            total = 0.0
            for j in range(n):
                pos = mask_positions[j].item()
                log_probs = torch.log_softmax(logits[i, pos, :], dim=0)
                total += log_probs[target_ids[j]].item()

            scores.append(total / n)

        return scores

    # ── Public decode ─────────────────────────────────────────────────────────

    def decode(self, sentence: str) -> Tuple[str, List[str]]:
        """
        Decode a Singlish sentence word-by-word using ByT5 + XLM-RoBERTa MLM.
        Returns (transliterated_sentence, trace_logs).
        """
        words = sentence.split()
        if not words:
            return "", []

        # ── Phase 1: batch ByT5 candidate generation ──────────────────────────
        # Collect only the words that need ByT5 (non-Sinhala), run in one pass
        cores: List[str] = []
        core_meta: List[tuple] = []  # (index_into_words, prefix, core, suffix, core_lower)

        for i, raw in enumerate(words):
            match = PUNCT_PATTERN.match(raw)
            prefix, core, suffix = match.groups() if match else ("", raw, "")
            if not _is_sinhala(core):
                cores.append(core)
                core_meta.append((i, prefix, core, suffix, core.lower()))

        # Single ByT5 forward pass for all non-Sinhala words
        byt5_results: List[List[str]] = (
            self.transliterator.batch_candidates(cores, k=MAX_CANDIDATES)
            if cores else []
        )

        byt5_map: dict = {}  # word index → (prefix, suffix, core_lower, ByT5 candidates)
        for (i, prefix, core, suffix, core_lower), cands in zip(core_meta, byt5_results):
            byt5_map[i] = (prefix, suffix, core_lower, cands or [core])

        word_infos: List[dict] = []
        for i, raw in enumerate(words):
            match = PUNCT_PATTERN.match(raw)
            _, core, _ = match.groups() if match else ("", raw, "")

            if _is_sinhala(core):
                word_infos.append({"kind": "sinhala", "candidates": [raw]})
                continue

            prefix, suffix, core_lower, byt5_cands = byt5_map[i]
            sinhala_cands = [prefix + c + suffix for c in byt5_cands]

            if core_lower in ENGLISH_VOCAB and len(core_lower) >= MIN_ENGLISH_LEN:
                candidates = [raw] + [c for c in sinhala_cands if c != raw]
                word_infos.append({"kind": "english", "candidates": candidates[:MAX_CANDIDATES + 1]})
            else:
                word_infos.append({"kind": "singlish", "candidates": sinhala_cands})

        # ── Phase 2: greedy left-to-right pass (builds dynamic left context) ──
        # Right context is seeded from first ByT5 candidate (pre-decode estimate)
        stable_right = [info["candidates"][0] for info in word_infos]
        selected_words: List[str] = []

        for t, info in enumerate(word_infos):
            candidates = info["candidates"]
            left_ctx = " ".join(selected_words)
            right_ctx = " ".join(stable_right[t + 1:])
            raw_mlm = self._batch_mlm_score(
                [left_ctx] * len(candidates),
                [right_ctx] * len(candidates),
                candidates,
            )
            norm_mlm = self._softmax_normalize(raw_mlm)
            best = max(zip(candidates, norm_mlm), key=lambda x: x[1])
            selected_words.append(best[0])

        # ── Phase 3: re-score with full decoded sentence as context ───────────
        # Right context is now the actual decoded output, not the pre-decode estimate
        trace_logs: List[str] = []
        final_words: List[str] = []

        for t, info in enumerate(word_infos):
            raw_word = words[t]
            kind = info["kind"]
            candidates = info["candidates"]

            left_ctx = " ".join(final_words)
            right_ctx = " ".join(selected_words[t + 1:])

            raw_mlm = self._batch_mlm_score(
                [left_ctx] * len(candidates),
                [right_ctx] * len(candidates),
                candidates,
            )
            norm_mlm = self._softmax_normalize(raw_mlm)

            scored = sorted(
                [ScoredCandidate(text=c, mlm_score=norm_mlm[i]) for i, c in enumerate(candidates)],
                key=lambda x: x.mlm_score,
                reverse=True,
            )
            best = scored[0]
            final_words.append(best.text)

            if kind == "sinhala":
                trace_logs.append(
                    f"**Step {t+1}: `{raw_word}`** → `{best.text}` "
                    f"(Sinhala, MLM={best.mlm_score:.3f})\n"
                )
            else:
                trace_logs.append(
                    f"**Step {t+1}: `{raw_word}`** → `{best.text}` "
                    f"(MLM={best.mlm_score:.3f})\n"
                    + "\n".join(f" - `{s.text}` {s.mlm_score:.3f}" for s in scored)
                )

        return " ".join(final_words), trace_logs
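In use (mirroring app.py), the decoder returns the code-mixed sentence plus markdown trace lines for each word; a minimal sketch, assuming the Hub models are reachable:

from sincode_model import BeamSearchDecoder

decoder = BeamSearchDecoder()  # loads ByT5 + XLM-RoBERTa, CUDA if available
result, trace_logs = decoder.decode("mema videowe bit rate eka godak wadi nisa buffer wenawa")
print(result)            # code-mixed Sinhala; "bit rate"/"buffer" kept if the MLM prefers them
for line in trace_logs:  # per-word candidates with softmax-normalised MLM scores
    print(line)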
core/english.py
ADDED
@@ -0,0 +1,73 @@
"""
English vocabulary loader for SinCode v3.
Used for English passthrough detection in the decoder.
Loads purely from the 20k corpus file — no hardcoded word lists.
"""

import os
import logging
import requests
from typing import Set

from core.constants import ENGLISH_CORPUS_URL, MIN_ENGLISH_LEN

logger = logging.getLogger(__name__)


def _resolve_english_cache_path() -> str:
    override = os.getenv("SINCODE_ENGLISH_CACHE")
    if override:
        return override

    candidates = [
        os.path.join(os.getenv("HF_HOME", ""), "english_20k.txt") if os.getenv("HF_HOME") else "",
        os.path.join(os.getcwd(), "english_20k.txt"),
        os.path.join(os.getenv("TMPDIR", os.getenv("TEMP", "/tmp")), "english_20k.txt"),
    ]

    for path in candidates:
        if not path:
            continue
        parent = os.path.dirname(path) or "."
        try:
            os.makedirs(parent, exist_ok=True)
            # Probe writability by opening in append mode
            with open(path, "a", encoding="utf-8"):
                pass
            return path
        except OSError:
            continue

    return "english_20k.txt"


ENGLISH_CORPUS_CACHE = _resolve_english_cache_path()


def load_english_vocab() -> Set[str]:
    vocab: Set[str] = set()

    if not os.path.exists(ENGLISH_CORPUS_CACHE) or os.path.getsize(ENGLISH_CORPUS_CACHE) == 0:
        try:
            logger.info("Downloading English corpus...")
            response = requests.get(ENGLISH_CORPUS_URL, timeout=10)
            response.raise_for_status()
            with open(ENGLISH_CORPUS_CACHE, "wb") as f:
                f.write(response.content)
        except (requests.RequestException, OSError) as exc:
            logger.warning("Could not download English corpus: %s", exc)
            return vocab

    try:
        with open(ENGLISH_CORPUS_CACHE, "r", encoding="utf-8") as f:
            vocab.update(
                w for line in f
                if (w := line.strip().lower()) and len(w) >= MIN_ENGLISH_LEN
            )
    except OSError as exc:
        logger.warning("Could not read English corpus file: %s", exc)

    logger.info("English vocabulary loaded: %d words", len(vocab))
    return vocab


ENGLISH_VOCAB: Set[str] = load_english_vocab()
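Because the vocabulary is materialised as a set at import time, the decoder's English check is a constant-time lookup. A minimal sketch (word choices arbitrary; the SINCODE_ENGLISH_CACHE variable must be set before the import to take effect):

import os
os.environ.setdefault("SINCODE_ENGLISH_CACHE", "/tmp/english_20k.txt")  # optional override

from core.english import ENGLISH_VOCAB, ENGLISH_CORPUS_CACHE

print(ENGLISH_CORPUS_CACHE)        # resolved cache path
print("buffer" in ENGLISH_VOCAB)   # expected True once the 20k list has downloaded
print("mema" in ENGLISH_VOCAB)     # expected False, so it is treated as Singlish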
core/mappings.py
ADDED
@@ -0,0 +1,8 @@
"""
core/mappings.py — deprecated.

All manual Singlish→Sinhala mappings have been removed.
Correction pairs are in seq2seq/finetune_corrections.py and baked into
the ByT5 model weights via targeted correction fine-tuning.
Candidate generation is handled end-to-end by the ByT5 seq2seq model.
"""
english_20k.txt
ADDED
The diff for this file is too large to render. See raw diff.
misc/dataset_110.csv
ADDED
@@ -0,0 +1,111 @@
id,input,reference,split,has_code_mix,has_ambiguity,domain,notes
1,api kalin katha kala,අපි කලින් කතා කළා,test,0,0,general,pure singlish
2,eka honda wage thiyanawa,ඒක හොඳ වගේ තියෙනවා,test,0,1,general,wage=seems
3,meheta thadata wessa,මෙහෙට තදට වැස්සා,test,0,1,general,thadata=very
4,oya kiwwata mama giye,ඔයා කිව්වට මම ගියේ,test,0,0,general,contextual past
5,mama danne na eka gena,මම දන්නෙ නෑ ඒක ගැන,test,0,1,general,eka pronoun
6,oya awa wage na,ඔයා ආවා වගේ නෑ,test,0,1,general,wage=seems
7,ekat ynna bri,ඒකට යන්න බැරි,test,0,0,general,ad-hoc bri=bari
8,mama inne gedaradi,මම ඉන්නෙ ගෙදරදී,test,0,0,general,pure singlish
9,eka heta balamu,ඒක හෙට බලමු,test,0,0,general,eka pronoun
10,klya madi api passe yamu,කාලය මදි අපි පස්සෙ යමු,test,0,0,general,ad-hoc klya=kalaya
11,assignment eka ada submit karanna one,assignment එක අද submit කරන්න ඕනෙ,test,1,0,education,eka after English noun
12,exam hall eka nisa mama baya una,exam hall එක නිසා මම බය උනා,test,1,1,education,nisa=because
13,results blnna one,results බලන්න ඕනෙ,test,1,0,education,ad-hoc blnna=balanna
14,study group ekak hadamu,study group එකක් හදමු,test,1,0,education,ekak after English noun
15,viva ekta prepared wage na,viva එකට prepared වගේ නෑ,test,1,1,education,wage=seems
16,mta project ek submit krnna one,මට project එක submit කරන්න ඕනෙ,test,1,0,education,ad-hoc mta krnna
17,hta parikshanaya thiyanawa,හෙට පරික්ෂණය තියෙනවා,test,0,0,education,ad-hoc hta=heta
18,mama potha kiyawala iwara kala,මම පොත කියවලා ඉවර කළා,test,0,0,education,pure singlish
19,prkku nisa api kalin giya,පරක්කු නිසා අපි කලින් ගියා,test,0,1,education,nisa=because
20,prashnaya hondai wage penenawa,ප්රශ්නය හොඳයි වගේ පේනවා,test,0,1,education,wage=seems
21,deployments nisa site down wuna,deployments නිසා site down උනා,test,1,1,work,nisa=because
22,PR eka merge karanna one,PR එක merge කරන්න ඕනෙ,test,1,0,work,eka after English noun
23,backlog eka update kala,backlog එක update කළා,test,1,0,work,eka after English noun
24,server down nisa work karanna ba,server down නිසා work කරන්න බෑ,test,1,1,work,nisa=because
25,meeting eka tomorrow damu,meeting එක tomorrow දාමු,test,1,0,work,code-mix preserved
26,feedback nisa redo karanna una,feedback නිසා redo කරන්න උනා,test,1,1,work,nisa=because
27,ape wada ada iwara wenawa,අපේ වැඩ අද ඉවර වෙනවා,test,0,0,work,pure singlish
28,kalamanakaru hitpu nisa api katha kala,කලමනාකරු හිටපු නිසා අපි කතා කළා,test,0,1,work,nisa=because; known failure (complex OOV)
29,me wada hondai wage penawa,මේ වැඩ හොඳයි වගේ පේනවා,test,0,1,work,wage=seems
30,wada tika ada iwara karamu,වැඩ ටික අද ඉවර කරමු,test,0,0,work,pure singlish
31,story eke poll ekak damma,story එකේ poll එකක් දැම්මා,test,1,0,social,eke and ekak forms
32,oyata DM ekak yawwa,ඔයාට DM එකක් යැව්වා,test,1,0,social,ekak after English noun
33,comment eka delete kala nisa mama danne na,comment එක delete කළා නිසා මම දන්නෙ නෑ,test,1,1,social,"nisa=because; known failure (කළා/කල, දන්නෙ/දන්නේ)"
34,selfie ekak gannako,selfie එකක් ගන්නකෝ,test,1,0,social,ekak after English noun
35,post eka private nisa share karanna epa,post එක private නිසා share කරන්න එපා,test,1,1,social,nisa=because
36,oyta message krnna one,ඔයාට message කරන්න ඕනෙ,test,1,0,social,ad-hoc oyta krnna on=one
37,api passe katha karamu,අපි පස්සෙ කතා කරමු,test,0,0,social,pure singlish
38,eya laga pinthurayk thiyanawa,ඒයා ළඟ පින්තූරයක් තියෙනවා,test,0,0,social,ad-hoc pinthurayk
39,oya awa wage mata hithenawa,ඔයා ආවා වගේ මට හිතෙනවා,test,0,1,social,wage=seems
40,api passe hambawemu,අපි පස්සෙ හම්බවෙමු,test,0,0,social,pure singlish
41,phone eka charge karanna one,phone එක charge කරන්න ඕනෙ,test,1,0,general,NEW: general code-mix (gap fix)
42,bus eka late una,bus එක late උනා,test,1,0,general,NEW: general code-mix
43,mama online inne,මම online ඉන්නෙ,test,1,0,general,NEW: English mid-sentence
44,time nathi nisa heta yamu,time නැති නිසා හෙට යමු,test,1,1,general,NEW: English+nisa in general
45,oya call eka ganna,ඔයා call එක ගන්න,test,1,0,general,NEW: general code-mix eka pattern
46,api game yanawa heta,අපි ගමේ යනවා හෙට,test,0,1,general,NEW: game=ගමේ(village) ambig with English 'game'
47,man heta enne na,මන් හෙට එන්නෙ නෑ,test,0,1,general,NEW: man=මං(I) ambig with English 'man'
48,eka hari lassanai,ඒක හරි ලස්සනයි,test,0,1,general,NEW: hari=very (not OK/correct)
49,oya kiwwa hari,ඔයා කිව්වා හරි,test,0,1,general,NEW: hari=correct (not very)
50,kalaya ithuru krganna one,කලය ඉතුරු කරගන්න ඕනෙ,test,0,1,general,NEW: one=ඕනෙ(need) ambig with English 'one'
51,date eka fix karanna one,date එක fix කරන්න ඕනෙ,test,1,1,general,NEW: date=English preserve; one=ඕනෙ
52,rata yanna one,රට යන්න ඕනෙ,test,0,0,general,"NEW: rata=country, pure singlish"
53,game eke leaderboard eka balanna,game එකේ leaderboard එක බලන්න,test,1,1,social,NEW: game=English(video game) not ගමේ
54,api thamai hodama,අපි තමයි හොඳම,test,0,1,general,NEW: thamai=emphatic we; hodama=best; looks English but Singlish
55,mama heta udee enawa oya enakota message ekk dnna,මම හෙට උදේ එනවා ඔයා එනකොට message එකක් දාන්න,test,0,0,general,NEW: 8-word pure singlish
56,ape gedara langa thiyana kadeta yanna one,අපේ ගෙදර ළඟ තියෙන කඩේට යන්න ඕනෙ,test,0,0,general,NEW: 7-word with ළඟ
57,mama assignment eka karala submit karanawa ada raa,මම assignment එක කරලා submit කරනවා අද රෑ,test,1,0,education,NEW: 8-word code-mix long
58,oya enne naththe mokada kiyla mama danne na,ඔයා එන්නෙ නැත්තෙ මොකද කියලා මම දන්නෙ නෑ,test,0,0,general,NEW: 9-word complex clause
59,client ekka call karala feedback eka ahanna one,client එක්ක call කරලා feedback එක අහන්න ඕනෙ,test,1,0,work,NEW: 8-word heavy code-mix
60,mama gedara gihilla kewata passe call karannm,මම ගෙදර ගිහිල්ලා කෑවට පස්සෙ call කරන්නම්,test,1,0,general,NEW: 8-word code-mix + temporal
61,laptop eke software update karanna one,laptop එකේ software update කරන්න ඕනෙ,test,1,0,work,NEW: 3 English words consecutive
62,office eke wifi password eka mokakda,office එකේ wifi password එක මොකක්ද,test,1,0,work,NEW: 3 English words; question
63,online order eka track karanna ba,online order එක track කරන්න බෑ,test,1,0,general,NEW: 3 English words
64,email eke attachment eka download karanna,email එකේ attachment එක download කරන්න,test,1,0,work,NEW: 3 English words + double eka
65,Instagram story eke filter eka hadanna,Instagram story එකේ filter එක හදන්න,test,1,0,social,NEW: 4 English words; social media
66,oyge wada iwra krd,ඔයාගෙ වැඩ ඉවර කරාද,test,0,0,general,NEW: extreme vowel omission
67,mge phone ek hack una,මගේ phone එක hack උනා,test,1,0,general,"NEW: heavy ad-hoc mmge=mage, hrk=hack"
68,handawata ynna wenwa,හැන්දෑවට යන්න වෙනවා,test,0,0,general,"NEW: ad-hoc hndta=handeta, wenwa=wenawa"
69,prashnya krnna oni,ප්රශ්නය කරන්න ඕනි,test,0,0,education,NEW: replaced extreme ad-hoc with more readable form
70,apita gdra ynna oni,අපිට ගෙදර යන්න ඕනි,test,0,0,general,NEW: ad-hoc gdra=gedara
71,mama oyata kiwwa,මම ඔයාට කිව්වා,test,0,0,general,"NEW: common words only (mama, oyata)"
72,oya hari hondai,ඔයා හරි හොඳයි,test,0,1,general,NEW: hari=very; common words
73,api heta yamu,අපි හෙට යමු,test,0,0,general,NEW: common words bypass test
74,app eka crash wenawa phone eke,app එක crash වෙනවා phone එකේ,test,1,0,technology,NEW: tech domain
75,code eka push karanna github ekata,code එක push කරන්න github එකට,test,1,0,technology,NEW: dev workflow code-mix
76,database eka slow nisa query eka optimize karanna one,database එක slow නිසා query එක optimize කරන්න ඕනෙ,test,1,1,technology,NEW: heavy tech code-mix + nisa; long
77,bug eka fix kala merge karanna,bug එක fix කළා merge කරන්න,test,1,0,technology,NEW: sequential actions code-mix
78,internet eka slow wage thiyanawa,internet එක slow වගේ තියෙනවා,test,1,1,technology,NEW: tech + wage ambiguity
79,kema hodai ada,කෑම හොඳයි අද,test,0,0,daily_life,NEW: daily life; short
80,mama bus eke enawa,මම bus එකේ එනවා,test,1,0,daily_life,NEW: transport code-mix
81,ganu depala ekka market giya,ගෑනු දෙපල එක්ක market ගියා,test,1,0,daily_life,NEW: colloquial + code-mix
82,watura bonna one,වතුර බොන්න ඕනෙ,test,0,0,daily_life,NEW: health advice singlish
83,shop eke sugar nati nisa mama giye na,shop එකේ sugar නැති නිසා මම ගියේ නෑ,test,1,1,daily_life,NEW: daily code-mix + nisa; negative
84,hri hari,හරි හරි,test,0,0,general,NEW: 2-word repetition; common expression + ad-hoc hri=hari
85,mta ep,මට එපා,test,0,0,general,NEW: ad-hoc mta=mata ep=epa
86,ok hari,ok හරි,test,1,0,general,NEW: 2-word code-mix
87,ape game hari dewal wenne,අපේ ගමේ හරි දේවල් වෙන්නේ,test,0,1,general,"NEW: game=village, hari=nice; looks English"
88,mta dan one na,මට දැන් ඕනෙ නෑ,test,0,1,general,NEW: man+one look English but Singlish
89,eka hari hondai wage dnuna nisa mama giya,ඒක හරි හොඳයි වගේ දැනුනා නිසා මම ගියා,test,0,1,general,NEW: hari+wage+nisa triple ambiguity; ref corrected to හොඳයි
90,game eke mission hari amarui,game එකේ mission හරි අමාරුයි,test,0,1,general,NEW: game=video game hari=very amarui=difficult; looks English but Singlish
91,mama heta yanawa,මම හෙට යනවා,test,0,0,general,NEW: future tense
92,ey iye aawa,එයා ඊයේ ආවා,test,0,0,general,NEW: past tense
93,api dan yanawa,අපි දැන් යනවා,test,0,0,general,NEW: present tense
94,video eka balanna one,video එක බලන්න ඕනෙ,test,1,0,social,NEW: eka definite article
95,video ekak hadamu,video එකක් හදමු,test,1,0,social,NEW: ekak indefinite
96,video eke comment eka balanna,video එකේ comment එක බලන්න,test,1,0,social,NEW: eke possessive + double eka
97,video ekata like ekak danna,video එකට like එකක් දාන්න,test,1,0,social,NEW: ekata dative case
98,lecture eka record karala share karanna,lecture එක record කරලා share කරන්න,test,1,0,education,NEW: sequential code-mix actions
99,research paper eka liyanna one heta wge,research paper එක ලියන්න ඕනෙ හෙට වගේ,test,1,0,education,NEW: long + temporal; 8 words
100,exam eka hari amarui,exam එක හරි අමාරුයි,test,1,1,education,NEW: hari=very; difficulty context
101,sprint eka plan karamu Monday,sprint එක plan කරමු Monday,test,1,0,work,NEW: day name preserved
102,ape team eka deadline ekata kala,අපේ team එක deadline එකට කළා,test,1,0,work,NEW: possessive + double English
103,standup eke mokada kiwwe,standup එකේ මොකද කිව්වෙ,test,1,0,work,NEW: question form code-mix
104,reel eka viral una,reel එක viral උනා,test,1,0,social,NEW: social media terminology
105,group chat eke mokada wenne,group chat එකේ මොකද වෙන්නෙ,test,1,0,social,NEW: compound English + question
106,oyge profile picture eka lassanai,ඔයාගෙ profile picture එක ලස්සනයි,test,1,0,social,NEW: compound English noun + eka; ref corrected to ඔයාගෙ
107,mama enne na heta,මම එන්නෙ නෑ හෙට,test,0,0,general,NEW: negation at end
108,eka karanna epa,ඒක කරන්න එපා,test,0,0,general,NEW: prohibition form
109,kawruwath enne na,කවුරුවත් එන්නෙ නෑ,test,0,0,general,NEW: nobody negation
110,oya koheda ynne,ඔයා කොහේද යන්නේ,test,0,0,general,NEW: question form where
misc/dataset_40.csv
ADDED
@@ -0,0 +1,41 @@
id,input,reference,split,has_code_mix,has_ambiguity,domain,notes
1,api kalin katha kala,අපි කලින් කතා කළා,train,0,0,general,pure singlish
2,eka honda wage thiyanawa,ඒක හොඳ වගේ තියෙනවා,train,0,1,general,wage=seems
3,pola nisa gedara thiyanawa,පොල නිසා ගෙදර තියෙනවා,train,0,1,general,nisa=because
4,oya kiwwata mama giye,ඔයා කිව්වට මම ගියේ,train,0,0,general,contextual past
5,mama danne na eka gena,මම දන්නෙ නෑ ඒක ගැන,train,0,1,general,eka pronoun
6,oya awa wage na,ඔයා ආවා වගේ නෑ,train,0,1,general,wage=seems
7,ekat ynna bri,ඒකට යන්න බැරි,train,0,0,general,ad hoc bri=bari
8,mama inne gedaradi,මම ඉන්නෙ ගෙදරදී,train,0,0,general,pure singlish
9,eka heta balamu,ඒක හෙට බලමු,train,0,0,general,eka pronoun
10,klya madi api passe yamu,කාලය මදි අපි පස්සෙ යමු,train,0,0,general,ad hoc klya=kalaya
11,assignment eka ada submit karanna one,assignment එක අද submit කරන්න ඕනෙ,train,1,0,education,eka after English noun
12,exam hall eka nisa mama baya una,exam hall එක නිසා මම බය උනා,train,1,1,education,nisa=because
13,results blnna one,results බලන්න ඕනෙ,train,1,0,education,ad hoc blnna=balanna
14,study group ekak hadamu,study group එකක් හදමු,train,1,0,education,ekak after English noun
15,viva ekta prepared wage na,viva එකට prepared වගේ නෑ,train,1,1,education,wage=seems
16,mta project ek submit krnna one,මට project එක submit කරන්න ඕනෙ,train,1,0,education,ad hoc mta krnna
17,hta parikshanaya thiyanawa,හෙට පරික්ෂණය තියෙනවා,train,0,0,education,ad hoc hta=heta
18,mama poth kiyawala iwara kala,මම පොත කියවලා ඉවර කළා,train,0,0,education,pure singlish
19,guruwaraya nisa api kalin giya,ගුරුවරයා නිසා අපි කලින් ගියා,train,0,1,education,nisa=because
20,prashnaya honda wage penenawa,ප්රශ්නය හොඳ වගේ පේනවා,train,0,1,education,wage=seems
21,deploy nisa site down wuna,deploy නිසා site down උනා,train,1,1,work,nisa=because
22,PR eka merge karanna one,PR එක merge කරන්න ඕනෙ,train,1,0,work,eka after English noun
23,backlog eka update kala,backlog එක update කළා,train,1,0,work,eka after English noun
24,server down nisa work karanna ba,server down නිසා work කරන්න බෑ,train,1,1,work,nisa=because
25,meeting eka tomorrow damu,meeting එක tomorrow දාමු,train,1,0,work,code mix preserved
26,feedback nisa redo karanna una,feedback නිසා redo කරන්න උනා,train,1,1,work,nisa=because
27,ape wada ada iwara wenawa,අපේ වැඩ අද ඉවර වෙනවා,train,0,0,work,pure singlish
28,kalamanakaru apu nisa api katha kala,කලමණාකරු ආපු නිසා අපි කතා කලා,train,0,1,work,nisa=because
29,me wada honda wage penenawa,මේ වැඩ හොඳ වගේ පේනවා,train,0,1,work,wage=seems
30,wada tika ada iwara karamu,වැඩ ටික අද ඉවර කරමු,train,0,0,work,pure singlish
31,story eke poll ekak damma,story එකේ poll එකක් දැම්මා,train,1,0,social,eke and ekak forms
32,oyata DM ekak yewwa,ඔයාට DM එකක් යැව්වා,train,1,0,social,ekak after English noun
33,comment eka delete kala nisa mama danne na,comment එක delete කල නිසා මම දන්නේ නෑ,train,1,1,social,nisa=because
34,selfie ekak gannako,selfie එකක් ගන්නකෝ,train,1,0,social,ekak after English noun
35,post eka private nisa share karanna epa,post එක private නිසා share කරන්න එපා,train,1,1,social,nisa=because
36,oyta message krnna on,ඔයාට message කරන්න ඕනෙ,train,1,0,social,ad hoc oyta krnna
37,oya passe katha karamu,ඔයා පස්සෙ කතා කරමු,train,0,0,social,pure singlish
38,eya laga pinthurayk thiyanawa,ඒයා ළඟ පින්තූරයක් තියෙනවා,train,0,0,social,ad hoc pinthurayk
39,oya awa wage mata hithenawa,ඔයා ආවා වගේ මට හිතෙනවා,train,0,1,social,wage=seems
40,api passe hambawemu,අපි පස්සෙ හම්බවෙමු,train,0,0,social,pure singlish
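The evaluation script that follows scores predictions with CER, WER, token accuracy, BLEU, and exact match. A minimal sketch of how the Levenshtein-based metrics behave, assuming the repo root is on sys.path so misc/evaluate.py imports as a module:

from misc.evaluate import cer, wer, token_accuracy, exact_match

ref  = "api kalin katha kala"
pred = "api kalin katha kal"       # final character dropped

print(cer(pred, ref))              # 1 edit / 20 reference chars = 0.05
print(wer(pred, ref))              # 1 token differs / 4 tokens  = 0.25
print(token_accuracy(pred, ref))   # 3 of 4 aligned tokens match = 0.75
print(exact_match(pred, ref))      # 0.0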
misc/evaluate.py
ADDED
@@ -0,0 +1,446 @@
| 1 |
+
"""
|
| 2 |
+
SinCode v3 — Evaluation Script
|
| 3 |
+
|
| 4 |
+
Supports two evaluation modes selected via --mode:
|
| 5 |
+
|
| 6 |
+
system Full v3 pipeline (ByT5 + two-pass MLM). Default.
|
| 7 |
+
ablation Side-by-side comparison of two configurations:
|
| 8 |
+
(A) ByT5 top-1 only — no MLM reranking
|
| 9 |
+
(B) ByT5 + MLM — full Code-Mixed pipeline
|
| 10 |
+
Proves the contribution of the XLM-RoBERTa reranker.
|
| 11 |
+
|
| 12 |
+
Note: mBart50 is intentionally excluded from evaluation here because the
|
| 13 |
+
reference dataset uses code-mixed targets (English words preserved). mBart50
|
| 14 |
+
produces full-Sinhala output by design, making a metric comparison against
|
| 15 |
+
code-mixed references invalid. Evaluate mBart50 separately with a dataset
|
| 16 |
+
whose references are fully in Sinhala script.
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
python misc/evaluate.py --dataset misc/dataset_110.csv
|
| 20 |
+
python misc/evaluate.py --dataset misc/dataset_110.csv --mode ablation
|
| 21 |
+
python misc/evaluate.py --dataset misc/dataset_110.csv --mode ablation --out misc/results.csv
|
| 22 |
+
|
| 23 |
+
CSV columns required: id, input, reference
|
| 24 |
+
Optional columns (used for grouping): category, domain, has_code_mix, has_ambiguity
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import csv
|
| 31 |
+
import json
|
| 32 |
+
import logging
|
| 33 |
+
import math
|
| 34 |
+
import os
|
| 35 |
+
import sys
|
| 36 |
+
from collections import defaultdict
|
| 37 |
+
from dataclasses import dataclass
|
| 38 |
+
from typing import Dict, List, Optional
|
| 39 |
+
|
| 40 |
+
# ── Path setup ────────────────────────────────────────────────────────────────
|
| 41 |
+
|
| 42 |
+
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 43 |
+
if ROOT not in sys.path:
|
| 44 |
+
sys.path.insert(0, ROOT)
|
| 45 |
+
|
| 46 |
+
logging.basicConfig(level=logging.WARNING)
|
| 47 |
+
|
| 48 |
+
# ── Metrics ───────────────────────────────────────────────────────────────────
|
| 49 |
+
|
| 50 |
+
def _levenshtein(a: str, b: str) -> int:
|
| 51 |
+
if not a: return len(b)
|
| 52 |
+
if not b: return len(a)
|
| 53 |
+
prev = list(range(len(b) + 1))
|
| 54 |
+
for i, ca in enumerate(a, 1):
|
| 55 |
+
curr = [i] + [0] * len(b)
|
| 56 |
+
for j, cb in enumerate(b, 1):
|
| 57 |
+
cost = 0 if ca == cb else 1
|
| 58 |
+
curr[j] = min(prev[j] + 1, curr[j-1] + 1, prev[j-1] + cost)
|
| 59 |
+
prev = curr
|
| 60 |
+
return prev[-1]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _levenshtein_tokens(a: list, b: list) -> int:
|
| 64 |
+
if not a: return len(b)
|
| 65 |
+
if not b: return len(a)
|
| 66 |
+
prev = list(range(len(b) + 1))
|
| 67 |
+
for i, ta in enumerate(a, 1):
|
| 68 |
+
curr = [i] + [0] * len(b)
|
| 69 |
+
for j, tb in enumerate(b, 1):
|
| 70 |
+
cost = 0 if ta == tb else 1
|
| 71 |
+
curr[j] = min(prev[j] + 1, curr[j-1] + 1, prev[j-1] + cost)
|
| 72 |
+
prev = curr
|
| 73 |
+
return prev[-1]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def cer(pred: str, ref: str) -> float:
|
| 77 |
+
if not ref: return 0.0 if not pred else 1.0
|
| 78 |
+
return _levenshtein(pred, ref) / max(len(ref), 1)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def wer(pred: str, ref: str) -> float:
|
| 82 |
+
pt, rt = pred.split(), ref.split()
|
| 83 |
+
if not rt: return 0.0 if not pt else 1.0
|
| 84 |
+
return _levenshtein_tokens(pt, rt) / max(len(rt), 1)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def token_accuracy(pred: str, ref: str) -> float:
|
| 88 |
+
pt, rt = pred.split(), ref.split()
|
| 89 |
+
if not rt: return 0.0 if pt else 1.0
|
| 90 |
+
return sum(p == r for p, r in zip(pt, rt)) / max(len(rt), 1)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def bleu(pred: str, ref: str, max_n: int = 4) -> float:
|
| 94 |
+
from collections import Counter
|
| 95 |
+
pt, rt = pred.split(), ref.split()
|
| 96 |
+
if not pt or not rt: return 0.0
|
| 97 |
+
n_max = min(max_n, len(pt), len(rt))
|
| 98 |
+
if n_max == 0: return 0.0
|
| 99 |
+
brevity = min(1.0, len(pt) / len(rt))
|
| 100 |
+
log_avg = 0.0
|
| 101 |
+
for n in range(1, n_max + 1):
|
| 102 |
+
pc = Counter(tuple(pt[i:i+n]) for i in range(len(pt)-n+1))
|
| 103 |
+
rc = Counter(tuple(rt[i:i+n]) for i in range(len(rt)-n+1))
|
| 104 |
+
clipped = sum(min(c, rc[ng]) for ng, c in pc.items())
|
| 105 |
+
total = max(sum(pc.values()), 1)
|
| 106 |
+
prec = clipped / total
|
| 107 |
+
if prec == 0: return 0.0
|
| 108 |
+
log_avg += math.log(prec) / n_max
|
| 109 |
+
return brevity * math.exp(log_avg)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def exact_match(pred: str, ref: str) -> float:
|
| 113 |
+
return 1.0 if pred.strip() == ref.strip() else 0.0
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ── Data model ────────────────────────────────────────────────────────────────
|
| 117 |
+
|
| 118 |
+
@dataclass
|
| 119 |
+
class TestCase:
|
| 120 |
+
id: int
|
| 121 |
+
input: str
|
| 122 |
+
reference: str
|
| 123 |
+
domain: str = "general"
|
| 124 |
+
has_code_mix: bool = False
|
| 125 |
+
has_ambiguity: bool = False
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@dataclass
|
| 129 |
+
class Result:
|
| 130 |
+
test_case: TestCase
|
| 131 |
+
system: str
|
| 132 |
+
prediction: str
|
| 133 |
+
cer_score: float
|
| 134 |
+
wer_score: float
|
| 135 |
+
token_acc: float
|
| 136 |
+
bleu_score: float
|
| 137 |
+
exact: float
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _score(tc: TestCase, pred: str, system: str) -> Result:
|
| 141 |
+
return Result(
|
| 142 |
+
test_case=tc,
|
| 143 |
+
system=system,
|
| 144 |
+
prediction=pred,
|
| 145 |
+
cer_score=cer(pred, tc.reference),
|
| 146 |
+
wer_score=wer(pred, tc.reference),
|
| 147 |
+
token_acc=token_accuracy(pred, tc.reference),
|
| 148 |
+
bleu_score=bleu(pred, tc.reference),
|
| 149 |
+
exact=exact_match(pred, tc.reference),
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ── Test set loader ───────────────────────────────────────────────────────────

def load_dataset(csv_path: str) -> List[TestCase]:
    cases = []
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        fields = set(reader.fieldnames or [])
        if not {"input", "reference"}.issubset(fields):
            raise ValueError(f"CSV must have 'input' and 'reference' columns. Found: {fields}")
        for row in reader:
            inp = (row.get("input") or "").strip().replace("\n", " ")
            ref = (row.get("reference") or "").strip().replace("\n", " ")
            if not inp or not ref:
                continue
            cases.append(TestCase(
                id=int(row.get("id") or 0),
                input=inp,
                reference=ref,
                domain=(row.get("domain") or row.get("category") or "general").strip(),
                has_code_mix=bool(int(row.get("has_code_mix") or 0)),
                has_ambiguity=bool(int(row.get("has_ambiguity") or 0)),
            ))
    return cases

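# Illustrative row in the expected CSV shape (column names as consumed by
# load_dataset above; the sample values are hypothetical):
#   id,input,reference,domain,has_code_mix,has_ambiguity
#   1,mn heta enawa,මං හෙට එනවා,general,0,0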
# ── Model loaders ─────────────────────────────────────────────────────────────

def _load_v3_decoder():
    from sincode_model import BeamSearchDecoder
    print(" Loading ByT5 + XLM-RoBERTa (Code-Mixed pipeline)...")
    return BeamSearchDecoder()


def _byt5_top1_predict(decoder, sentence: str) -> str:
    """ByT5 top-1 only — pick first beam candidate, skip MLM reranking."""
    from core.constants import PUNCT_PATTERN
    from core.decoder import _is_sinhala

    words = sentence.split()
    output = []
    cores = [re.sub(r"^\W*|\W*$", "", w) for w in words]
    non_sinhala = [c for c in cores if not _is_sinhala(c) and c]

    if not non_sinhala:
        return sentence

    byt5_results = decoder.transliterator.batch_candidates(non_sinhala, k=1)
    byt5_iter = iter(byt5_results)

    for raw, core in zip(words, cores):
        m = PUNCT_PATTERN.match(raw)
        prefix, _, suffix = m.groups() if m else ("", raw, "")
        if _is_sinhala(core) or not core:
            output.append(raw)
        else:
            cands = next(byt5_iter, [core])
            output.append(prefix + (cands[0] if cands else core) + suffix)
    return " ".join(output)

# ── Reporting ─────────────────────────────────────────────────────────────────

def _avg(vals: List[float]) -> float:
    return sum(vals) / len(vals) if vals else 0.0


def _print_table(label: str, results: List[Result]):
    print(f"\n{'='*74}")
    print(f" {label} (n={len(results)})")
    print(f"{'='*74}")
    print(f" {'ID':<5} {'Domain':<14} {'CM':>3} {'Am':>3} {'CER':>6} {'WER':>6} {'TokAcc':>7} {'BLEU':>6} {'EM':>4}")
    print(f" {'-'*66}")
    for r in results:
        tc = r.test_case
        print(
            f" {tc.id:<5} {tc.domain[:13]:<14} {'Y' if tc.has_code_mix else 'N':>3} "
            f"{'Y' if tc.has_ambiguity else 'N':>3} "
            f"{r.cer_score:>6.3f} {r.wer_score:>6.3f} {r.token_acc:>7.3f} "
            f"{r.bleu_score:>6.3f} {r.exact:>4.0f}"
        )
    print(f" {'-'*66}")
    print(
        f" {'AVERAGE':<26} "
        f"{_avg([r.cer_score for r in results]):>6.3f} "
        f"{_avg([r.wer_score for r in results]):>6.3f} "
        f"{_avg([r.token_acc for r in results]):>7.3f} "
        f"{_avg([r.bleu_score for r in results]):>6.3f} "
        f"{_avg([r.exact for r in results]):>4.2f}"
    )

    # Per-domain breakdown
    by_domain: Dict[str, List[Result]] = defaultdict(list)
    for r in results:
        by_domain[r.test_case.domain].append(r)
    if len(by_domain) > 1:
        print(f"\n Per-domain averages (CER / WER / TokAcc):")
        for dom, rs in sorted(by_domain.items()):
            print(
                f"   {dom:<18} n={len(rs):<4} "
                f"CER={_avg([r.cer_score for r in rs]):.3f} "
                f"WER={_avg([r.wer_score for r in rs]):.3f} "
                f"TokAcc={_avg([r.token_acc for r in rs]):.3f}"
            )

    # Code-mixed vs pure Singlish
    cm_r = [r for r in results if r.test_case.has_code_mix]
    pure_r = [r for r in results if not r.test_case.has_code_mix]
    if cm_r and pure_r:
        print(
            f"\n Code-mixed    (n={len(cm_r):<3}): "
            f"CER={_avg([r.cer_score for r in cm_r]):.3f} "
            f"WER={_avg([r.wer_score for r in cm_r]):.3f}"
        )
        print(
            f" Pure Singlish (n={len(pure_r):<3}): "
            f"CER={_avg([r.cer_score for r in pure_r]):.3f} "
            f"WER={_avg([r.wer_score for r in pure_r]):.3f}"
        )

def _print_ablation(a_res: List[Result], b_res: List[Result]):
    print(f"\n{'='*74}")
    print(" ABLATION STUDY — MLM Reranking Contribution")
    print(f" (A) ByT5 top-1 only | (B) ByT5 + XLM-RoBERTa MLM reranking")
    print(f"{'='*74}")
    print(f" {'Metric':<22} {'(A) ByT5-top1':>14} {'(B) ByT5+MLM':>13} {'Δ (B−A)':>10}")
    print(f" {'-'*64}")

    metrics = [
        ("CER (↓ better)", [r.cer_score for r in a_res], [r.cer_score for r in b_res], True),
        ("WER (↓ better)", [r.wer_score for r in a_res], [r.wer_score for r in b_res], True),
        ("Token Acc (↑)", [r.token_acc for r in a_res], [r.token_acc for r in b_res], False),
        ("BLEU (↑ better)", [r.bleu_score for r in a_res], [r.bleu_score for r in b_res], False),
        ("Exact Match (↑)", [r.exact for r in a_res], [r.exact for r in b_res], False),
    ]

    for label, a_vals, b_vals, lower_is_better in metrics:
        a_avg, b_avg = _avg(a_vals), _avg(b_vals)
        delta = b_avg - a_avg
        improved = (delta < 0) if lower_is_better else (delta > 0)
        print(
            f" {label:<22} {a_avg:>14.4f} {b_avg:>13.4f} "
            f" {'✓' if improved else '✗'}{delta:>+8.4f}"
        )

    print(f"\n ✓ B vs A isolates the contribution of XLM-RoBERTa MLM reranking.")
    print(f" ✓ If B > A: the two-pass reranker justifies its computational cost.")

    # Subcategory breakdown
    for sublabel, filter_fn in [
        ("Code-mixed only", lambda r: r.test_case.has_code_mix),
        ("Ambiguous only", lambda r: r.test_case.has_ambiguity),
        ("Pure Singlish", lambda r: not r.test_case.has_code_mix),
    ]:
        a_sub = [r for r in a_res if filter_fn(r)]
        b_sub = [r for r in b_res if filter_fn(r)]
        if not a_sub:
            continue
        print(f"\n {sublabel} (n={len(a_sub)}):")
        print(f" {'':20} {'(A)':>10} {'(B)':>10} {'Δ':>10}")
        for ml, getter, low in [
            ("CER", lambda r: r.cer_score, True),
            ("WER", lambda r: r.wer_score, True),
            ("TokAcc", lambda r: r.token_acc, False),
        ]:
            av, bv = _avg([getter(r) for r in a_sub]), _avg([getter(r) for r in b_sub])
            d = bv - av
            imp = (d < 0) if low else (d > 0)
            print(
                f" {ml:<20} {av:>10.4f} {bv:>10.4f} "
                f" {'✓' if imp else '✗'}{d:>+7.4f}"
            )


def _load_baseline(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def _print_v2_comparison(b_res: List[Result], baseline: dict):
    n = len(b_res)
    v3 = {
        "exact_match": _avg([r.exact for r in b_res]),
        "cer": _avg([r.cer_score for r in b_res]),
        "wer": _avg([r.wer_score for r in b_res]),
        "bleu": _avg([r.bleu_score for r in b_res]),
        "token_acc": _avg([r.token_acc for r in b_res]),
    }
    v2_label = baseline.get("system", "v2 baseline")

    print(f"\n{'='*74}")
    print(f" SinCode v2 vs SinCode v3 — Head-to-Head (n={n})")
    print(f" v2: {v2_label}")
    print(f" v3: ByT5-small seq2seq + XLM-RoBERTa MLM reranking")
    print(f"{'='*74}")
    print(f" {'Metric':<22} {'v2 (baseline)':>14} {'v3 (ours)':>10} {'Δ (v3−v2)':>12}")
    print(f" {'-'*62}")

    metrics = [
        ("Exact Match (↑)", "exact_match", False),
        ("CER (↓ better)", "cer", True),
        ("WER (↓ better)", "wer", True),
        ("BLEU (↑ better)", "bleu", False),
        ("Token Acc (↑)", "token_acc", False),
    ]
    for label, key, lower_is_better in metrics:
        v2v = baseline.get(key, 0.0)
        v3v = v3[key]
        delta = v3v - v2v
        improved = (delta < 0) if lower_is_better else (delta > 0)
        arrow = "↑" if (delta > 0) else ("↓" if delta < 0 else "=")
        # The arrow already carries the direction, so print the magnitude unsigned
        print(
            f" {label:<22} {v2v:>14.4f} {v3v:>10.4f} "
            f" {'✓' if improved else '✗'} {arrow}{abs(delta):>8.4f}"
        )

    if baseline.get("notes"):
        print(f"\n Note: {baseline['notes']}")

def _save_csv(results_by_system: Dict[str, List[Result]], out_path: str):
    rows = []
    for system, results in results_by_system.items():
        for r in results:
            rows.append({
                "system": system,
                "id": r.test_case.id,
                "domain": r.test_case.domain,
                "has_code_mix": int(r.test_case.has_code_mix),
                "has_ambiguity": int(r.test_case.has_ambiguity),
                "input": r.test_case.input,
                "reference": r.test_case.reference,
                "prediction": r.prediction,
                "cer": f"{r.cer_score:.4f}",
                "wer": f"{r.wer_score:.4f}",
                "token_acc": f"{r.token_acc:.4f}",
                "bleu": f"{r.bleu_score:.4f}",
                "exact_match": f"{r.exact:.0f}",
            })
    if not rows:
        # Nothing collected, e.g. an empty dataset; avoids IndexError on rows[0] below
        print("\n No results to save — skipping CSV output.")
        return
    with open(out_path, "w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        w.writeheader()
        w.writerows(rows)
    print(f"\n Results saved -> {out_path}")


# ── Main ──────────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(description="SinCode v3 evaluation")
    parser.add_argument("--dataset", required=True,
                        help="Path to evaluation CSV (dataset_110.csv or dataset_40.csv)")
    parser.add_argument("--mode", default="system",
                        choices=["system", "ablation"],
                        help="Evaluation mode (default: system)")
    parser.add_argument("--out", default=None,
                        help="Optional path to save results CSV")
    parser.add_argument("--baseline", default=None,
                        help="Path to v2 baseline JSON (e.g. misc/v2_baseline.json) for head-to-head comparison")
    args = parser.parse_args()

    print(f"\nLoading dataset: {args.dataset}")
    test_cases = load_dataset(args.dataset)
    print(f" {len(test_cases)} test cases loaded.")

    results_by_system: Dict[str, List[Result]] = {}
    a_results: List[Result] = []
    b_results: List[Result] = []

    decoder = _load_v3_decoder()

    if args.mode == "ablation":
        print("\nRunning (A) ByT5 top-1 only...")
        a_results = [_score(tc, _byt5_top1_predict(decoder, tc.input), "byt5_top1") for tc in test_cases]
        results_by_system["byt5_top1"] = a_results

        print("\nRunning (B) ByT5 + MLM reranking...")
        b_results = [_score(tc, decoder.decode(tc.input)[0], "byt5_mlm") for tc in test_cases]
        results_by_system["byt5_mlm"] = b_results
    else:
        # system mode: run only the full pipeline so reporting and CSV export have results
        print("\nRunning ByT5 + MLM reranking...")
        b_results = [_score(tc, decoder.decode(tc.input)[0], "byt5_mlm") for tc in test_cases]
        results_by_system["byt5_mlm"] = b_results

    if args.mode == "system":
        _print_table("v3 Code-Mixed Pipeline (ByT5 + XLM-RoBERTa MLM)", b_results)
    elif args.mode == "ablation":
        _print_table("(A) ByT5 top-1 only", a_results)
        _print_table("(B) ByT5 + MLM reranking", b_results)
        _print_ablation(a_results, b_results)

    if args.baseline:
        baseline = _load_baseline(args.baseline)
        _print_v2_comparison(b_results, baseline)

    if args.out:
        _save_csv(results_by_system, args.out)


if __name__ == "__main__":
    main()

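Typical invocations (file names as referenced in the argparse help above):

    python misc/evaluate.py --dataset misc/dataset_110.csv --mode system --baseline misc/v2_baseline.json
    python misc/evaluate.py --dataset misc/dataset_110.csv --mode ablation --out misc/v3_results_110.csv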
misc/upload_mlm_to_hf.py
ADDED
@@ -0,0 +1,50 @@
r"""
Upload the fine-tuned XLM-RoBERTa MLM model to HuggingFace Hub.
Run from: C:\Y5_Docs\FYP\SinCode\SinCode_v3
Usage:    python misc/upload_mlm_to_hf.py --token YOUR_HF_WRITE_TOKEN
"""

import argparse
from pathlib import Path
from huggingface_hub import HfApi

MODEL_LOCAL_PATH = Path(
    r"C:\Y5_Docs\FYP\SinCode\SinCode_v2-20260315T161648Z-1-001"
    r"\SinCode_v2\SinCode\SinCode\xlm-roberta-sinhala-v5-strict-full\final"
)
REPO_ID = "Kalana001/xlm-roberta-base-finetuned-sinhala"

FILES_TO_UPLOAD = [
    "config.json",
    "model.safetensors",
    "tokenizer.json",
    "tokenizer_config.json",
]

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--token", required=True, help="HuggingFace write-access token")
    args = parser.parse_args()

    api = HfApi(token=args.token)
    # Ensure the target repo exists before uploading (no-op if it already does)
    api.create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)

    print(f"Uploading to: {REPO_ID}")
    for filename in FILES_TO_UPLOAD:
        local_file = MODEL_LOCAL_PATH / filename
        if not local_file.exists():
            print(f" SKIP (not found): {filename}")
            continue
        size_mb = round(local_file.stat().st_size / 1024 / 1024, 1)
        print(f" Uploading {filename} ({size_mb} MB)...")
        api.upload_file(
            path_or_fileobj=str(local_file),
            path_in_repo=filename,
            repo_id=REPO_ID,
            repo_type="model",
        )
        print(f" Done: {filename}")

    print(f"\nAll files uploaded to https://huggingface.co/{REPO_ID}")

if __name__ == "__main__":
    main()
misc/v2_baseline.json
ADDED
@@ -0,0 +1,10 @@
{
  "system": "SinCode v2 (rule-based + dictionary)",
  "samples": 110,
  "exact_match": 0.8364,
  "cer": 0.0122,
  "wer": 0.0407,
  "bleu": 0.8861,
  "token_acc": 0.9593,
  "notes": "Measured on dataset_110.csv test split. Avg time per sentence: 0.03s (3.34s total)."
}
requirements.txt
ADDED
@@ -0,0 +1,6 @@
transformers>=4.40.0
torch>=2.2.0
sentencepiece
datasets
streamlit
pandas
seq2seq/__init__.py
ADDED
@@ -0,0 +1 @@
# seq2seq — ByT5 and mBart50 inference wrappers
seq2seq/compose_fix_map.json
ADDED
@@ -0,0 +1,46 @@
{
  "ක්\\sර": "ක්ර",
  "ක්\\sරි": "ක්රි",
  "ක්\\sරී": "ක්රී",
  "ක්\\sරා": "ක්රා",
  "ද්\\sර": "ද්ර",
  "ද්\\sරි": "ද්රි",
  "ද්\\sරී": "ද්රී",
  "ද්\\sරා": "ද්රා",
  "ට්\\sර": "ට්ර",
  "ට්\\sරි": "ට්රි",
  "ට්\\sරී": "ට්රී",
  "ට්\\sරා": "ට්රා",
  "ත්\\sර": "ත්ර",
  "ත්\\sරි": "ත්රි",
  "ත්\\sරී": "ත්රී",
  "ත්\\sරා": "ත්රා",
  "ප්\\sර": "ප්ර",
  "ප්\\sරි": "ප්රි",
  "ප්\\sරී": "ප්රී",
  "ප්\\sරා": "ප්රා",
  "බ්\\sර": "බ්ර",
  "බ්\\sරි": "බ්රි",
  "බ්\\sරී": "බ්රී",
  "බ්\\sරා": "බ්රා",
  "ග්\\sර": "ග්ර",
  "ග්\\sරි": "ග්රි",
  "ග්\\sරී": "ග්රී",
  "ග්\\sරා": "ග්රා",
  "ෂ්\\sර": "ෂ්ර",
  "ෂ්\\sරි": "ෂ්රි",
  "ෂ්\\sරී": "ෂ්රී",
  "ෂ්\\sරා": "ෂ්රා",
  "ශ්\\sර": "ශ්ර",
  "ශ්\\sරි": "ශ්රි",
  "ශ්\\sරී": "ශ්රී",
  "ශ්\\sරා": "ශ්රා",
  "ව්\\sය": "ව්ය",
  "ව්\\sයා": "ව්යා",
  "ද්\\sය": "ද්ය",
  "ද්\\sයා": "ද්යා",
  "න්\\sය": "න්ය",
  "ධ්\\sය": "ධ්ය",
  "ධ්\\sයා": "ධ්යා",
  "ද්\\sයු": "ද්යු"
}
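Each key above is a regular expression ("\\s" in JSON deserializes to the regex \s), so a conjunct cluster that the model emits with a stray space gets rejoined. A minimal sketch of how seq2seq/mbart_infer.py applies the map (the sample string is hypothetical):

    import json
    import re

    # Load the fix map; keys are regex patterns, values are composed clusters
    with open("seq2seq/compose_fix_map.json", encoding="utf-8") as f:
        fix_map = json.load(f)

    text = "ශ් රී ලංකා"  # conjunct split by a stray space
    for pattern, replacement in fix_map.items():
        text = re.sub(pattern, replacement, text)
    print(text)  # ශ්රී ලංකා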
seq2seq/finetune_corrections.py
ADDED
@@ -0,0 +1,270 @@
"""
seq2seq/finetune_corrections.py

Targeted correction fine-tune for the already-trained ByT5 model.

Problem:  ByT5 struggles with short/ambiguous tokens like "na"→නෑ, "ba"→බෑ,
          extreme abbreviations like "mn"→මං, and colloquial negations.

Solution: Inject high-confidence correction pairs (from core/mappings.py),
          heavily repeated, mixed with a random sample of the original
          training data to prevent catastrophic forgetting.

The output is saved to byt5-singlish-sinhala/final/ (overwrites in place).
Run from the project root:
    python seq2seq/finetune_corrections.py
"""

from __future__ import annotations

import random
import sys
from pathlib import Path

ROOT = Path(__file__).parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)

# ── Config ────────────────────────────────────────────────────────────────────

MODEL_PATH = ROOT / "seq2seq" / "byt5-singlish-sinhala" / "final"
DATA_PATH = ROOT / "seq2seq" / "wsd_pairs.csv"
OUTPUT_DIR = ROOT / "seq2seq" / "byt5-singlish-sinhala" / "final"  # overwrite in place

REPEAT = 500        # how many times each correction pair is repeated
BG_SAMPLES = 50_000 # random background pairs from wsd_pairs.csv to prevent forgetting
MAX_INPUT_LEN = 64
MAX_TARGET_LEN = 64
BATCH_SIZE = 32
LR = 5e-5           # low LR — gentle correction, not retraining
EPOCHS = 1
SEED = 42

# ── Correction pairs (sourced from core/mappings.py) ─────────────────────────
# Only include pairs where ByT5 is known to be unreliable.
# English-safe tokens (pr, dm, ai…) are excluded — they never reach ByT5.

CORRECTIONS = [
    # negation — most critical
    ("na", "නෑ"),
    ("naa", "නෑ"),
    ("ba", "බෑ"),
    ("bari", "බැරි"),
    ("bri", "බැරි"),
    ("nathi", "නැති"),
    ("nati", "නැති"),
    ("naththe", "නැත්තෙ"),
    ("epa", "එපා"),
    ("ep", "එපා"),
    # pronouns / first person
    ("mn", "මං"),
    ("mama", "මම"),
    ("mage", "මගේ"),
    ("mge", "මගේ"),
    ("oya", "ඔයා"),
    ("oyaa", "ඔයා"),
    ("api", "අපි"),
    ("mata", "මට"),
    ("mta", "මට"),
    ("oyata", "ඔයාට"),
    ("oyta", "ඔයාට"),
    ("oyage", "ඔයාගේ"),
    ("oyge", "ඔයාගෙ"),
    ("ape", "අපේ"),
    # common particles
    ("one", "ඕනෙ"),
    ("oney", "ඕනේ"),
    ("on", "ඕනෙ"),
    ("oni", "ඕනි"),
    ("hari", "හරි"),
    ("hri", "හරි"),
    ("wage", "වගේ"),
    ("nisa", "නිසා"),
    ("dan", "දැන්"),
    ("gena", "ගැන"),
    # time
    ("heta", "හෙට"),
    ("hta", "හෙට"),
    ("ada", "අද"),
    ("iye", "ඊයේ"),
    ("kalin", "කලින්"),
    ("passe", "පස්සෙ"),
    # abbreviations ("mn" already appears under pronouns above)
    ("ek", "එක"),
    ("ekta", "එකට"),
    ("eke", "එකේ"),
    ("me", "මේ"),
    # common words
    ("honda", "හොඳ"),
    ("hodai", "හොඳයි"),
    ("gedara", "ගෙදර"),
    ("wada", "වැඩ"),
    ("kema", "කෑම"),
    ("kama", "කෑම"),
    ("inne", "ඉන්නෙ"),
    ("inna", "ඉන්න"),
    ("madi", "මදි"),
    ("iwara", "ඉවර"),
    ("iwra", "ඉවර"),
    # verbal
    ("awa", "ආවා"),
    ("aawa", "ආවා"),
    ("giya", "ගියා"),
    ("una", "උනා"),
    ("wuna", "උනා"),
    ("kiwa", "කිව්වා"),
    ("kiwwa", "කිව්වා"),
    ("yewwa", "යැව්වා"),
    ("yawwa", "යැව්වා"),
    ("damma", "දැම්මා"),
    ("karanna", "කරන්න"),
    ("krnna", "කරන්න"),
    ("balanna", "බලන්න"),
    ("blnna", "බලන්න"),
    ("hadanna", "හදන්න"),
    ("karamu", "කරමු"),
    ("balamu", "බලමු"),
    ("yamu", "යමු"),
    ("hadamu", "හදමු"),
    ("damu", "දාමු"),
    ("wenawa", "වෙනවා"),
    ("wenwa", "වෙනවා"),
    ("thiyanawa", "තියෙනවා"),
    ("enawa", "එනවා"),
    ("yanawa", "යනවා"),
]


# ── Dataset builder ───────────────────────────────────────────────────────────

def build_dataset(tokenizer) -> Dataset:
    import csv

    pairs: list[dict] = []

    # 1. Correction pairs repeated REPEAT times
    for romanized, sinhala in CORRECTIONS:
        for _ in range(REPEAT):
            pairs.append({"romanized": romanized, "sinhala": sinhala})

    correction_count = len(pairs)
    print(f" Correction pairs: {len(CORRECTIONS)} × {REPEAT} = {correction_count:,}")

    # 2. Background sample from original training data
    bg: list[dict] = []
    with open(DATA_PATH, encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            r = (row.get("romanized") or "").strip()
            s = (row.get("sinhala") or "").strip()
            if r and s:
                bg.append({"romanized": r, "sinhala": s})

    random.seed(SEED)
    random.shuffle(bg)
    bg = bg[:BG_SAMPLES]
    pairs.extend(bg)
    print(f" Background pairs: {len(bg):,}")
    print(f" Total dataset   : {len(pairs):,}")

    random.shuffle(pairs)

    ds = Dataset.from_list(pairs)

    def tokenize(batch):
        inputs = tokenizer(
            batch["romanized"],
            max_length=MAX_INPUT_LEN,
            truncation=True,
            padding="max_length",
        )
        targets = tokenizer(
            batch["sinhala"],
            max_length=MAX_TARGET_LEN,
            truncation=True,
            padding="max_length",
        )
        inputs["labels"] = [
            [(t if t != tokenizer.pad_token_id else -100) for t in ids]
            for ids in targets["input_ids"]
        ]
        return inputs

    ds = ds.map(tokenize, batched=True, batch_size=5_000,
                remove_columns=["romanized", "sinhala"], desc="Tokenizing")
    ds.set_format("torch")
    return ds


# ── Main ──────────────────────────────────────────────────────────────────────

def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nDevice : {device}")
    if device == "cpu":
        print("WARNING: running on CPU — this will take ~30-60 min.")

    print(f"Loading model from {MODEL_PATH} ...")
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
    model = AutoModelForSeq2SeqLM.from_pretrained(str(MODEL_PATH))

    print("\nBuilding correction dataset ...")
    ds = build_dataset(tokenizer)

    split = ds.train_test_split(test_size=0.02, seed=SEED)
    train_ds = split["train"]
    eval_ds = split["test"]
    print(f" train={len(train_ds):,} eval={len(eval_ds):,}")

    warmup = max(100, len(train_ds) // (BATCH_SIZE * 20))

    args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LR,
        warmup_steps=warmup,
        weight_decay=0.01,
        predict_with_generate=False,  # faster eval — we only care about loss
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_steps=100,
        dataloader_num_workers=0,
        seed=SEED,
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported() and torch.cuda.is_available(),
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=default_data_collator,
    )

    print("\nStarting correction fine-tune ...")
    trainer.train()

    print(f"\nSaving corrected model to {OUTPUT_DIR} ...")
    model.save_pretrained(str(OUTPUT_DIR))
    tokenizer.save_pretrained(str(OUTPUT_DIR))
    print("Done.")


if __name__ == "__main__":
    main()
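With the defaults above, the roughly 80 correction pairs repeated 500 times contribute on the order of 40,000 rows against the 50,000-row background sample, so corrections make up a bit under half of the fine-tune mix; the exact counts are printed by build_dataset at run time.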
seq2seq/infer.py
ADDED
@@ -0,0 +1,83 @@
"""
Inference helper — given a romanized word, return top-K Sinhala candidates
using beam search on the fine-tuned ByT5 model.

Usage:
    from seq2seq.infer import Transliterator
    t = Transliterator()
    print(t.candidates("videowe", k=5))
    # ['වීඩියොවේ', 'වීඩියොවී', 'වීඩියොව', ...]
"""

from __future__ import annotations
from pathlib import Path
from typing import Optional

import torch
from transformers import ByT5Tokenizer, T5ForConditionalGeneration

DEFAULT_MODEL_PATH = Path(__file__).parent / "byt5-singlish-sinhala" / "final"


class Transliterator:
    def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: Optional[str] = None):
        # Keep as string — Path() would convert '/' to '\' on Windows, breaking HF Hub IDs
        model_path = str(model_path)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = ByT5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

    def candidates(self, word: str, k: int = 5) -> list[str]:
        """Return top-k Sinhala transliteration candidates for a single word."""
        return self.batch_candidates([word], k=k)[0]

    def batch_candidates(self, words: list[str], k: int = 5) -> list[list[str]]:
        """
        Return top-k Sinhala candidates for each word in a single forward pass.
        Much faster than calling candidates() per word on a long sentence.
        """
        lowered = [w.lower() for w in words]
        inputs = self.tokenizer(
            lowered,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=64,
        ).to(self.device)

        n = len(words)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                num_beams=max(k, 5),
                num_return_sequences=k,
                max_new_tokens=64,
                early_stopping=True,
            )

        # outputs shape: (n * k, seq_len) — k sequences per input, grouped
        results: list[list[str]] = []
        for i in range(n):
            seen: set[str] = set()
            cands: list[str] = []
            for seq in outputs[i * k : (i + 1) * k]:
                text = self.tokenizer.decode(seq, skip_special_tokens=True).strip()
                if text and text not in seen:
                    seen.add(text)
                    cands.append(text)
            results.append(cands)

        return results


if __name__ == "__main__":
    import sys
    words = sys.argv[1:] if len(sys.argv) > 1 else ["wadi"]
    t = Transliterator()
    for word in words:
        print(f"Candidates for '{word}':")
        for c in t.candidates(word):
            print(f"  {c}")
        print()
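A minimal batched-use sketch for the class above (the sample words are illustrative):

    from seq2seq.infer import Transliterator

    t = Transliterator()
    # One forward pass for both words; each result is a top-3 candidate list
    for word, cands in zip(["wadi", "gedara"], t.batch_candidates(["wadi", "gedara"], k=3)):
        print(word, "->", cands)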
seq2seq/mbart_infer.py
ADDED
@@ -0,0 +1,126 @@
"""
mBart50-based Sentence Transliterator for SinCode v3.

Full-sentence Singlish → Sinhala transliteration.
Unlike the ByT5 word-by-word pipeline, mBart50 operates on the whole input
sentence and produces fully Sinhalized output — no English words are retained.

Use-case: "mn heta business ekak start karanawa"
          → "මන් හෙට ව්යාපාරයක් පටන් ගන්නවා"
"""

from __future__ import annotations

import json
import logging
import re
from pathlib import Path
from typing import Optional

import torch
from transformers import MBart50Tokenizer, MBartForConditionalGeneration

from core.constants import DEFAULT_MBART_MODEL

logger = logging.getLogger(__name__)

# ── Fix-map (ZWJ / Virama composition) ───────────────────────────────────────

_FIX_MAP_PATH = Path(__file__).parent / "compose_fix_map.json"

_fix_map_cache: dict[str, str] | None = None


def _load_fix_map() -> dict[str, str]:
    global _fix_map_cache
    if _fix_map_cache is None:
        with open(_FIX_MAP_PATH, "r", encoding="utf-8") as f:
            _fix_map_cache = json.load(f)
    return _fix_map_cache


# ── Input cleaning ────────────────────────────────────────────────────────────

# Scripts that are not Sinhala, Latin, numbers, or symbols — filtered out
_UNSUPPORTED_SCRIPT = re.compile(
    r"[\u0B80-\u0BFF"  # Tamil
    r"\u0900-\u097F"   # Devanagari
    r"\u4E00-\u9FFF"   # CJK Unified Ideographs
    r"\u3040-\u309F"   # Hiragana
    r"\u30A0-\u30FF"   # Katakana
    r"\u0E00-\u0E7F"   # Thai
    r"\u0600-\u06FF"   # Arabic
    r"\u0590-\u05FF"   # Hebrew
    r"\uAC00-\uD7AF]"  # Hangul
)


def _clean(text: str) -> str | None:
    """Remove words in unsupported scripts; return None if nothing remains."""
    words = text.strip().split()
    filtered = [w for w in words if not _UNSUPPORTED_SCRIPT.search(w)]
    return " ".join(filtered) if filtered else None


def _apply_fixes(text: str) -> str:
    """Apply ZWJ/virama composition fixes to mBart50 output."""
    for pattern, replacement in _load_fix_map().items():
        text = re.sub(pattern, replacement, text)
    return text


# ── Transliterator ────────────────────────────────────────────────────────────

class SentenceTransliterator:
    """
    Full-sentence Singlish → Sinhala transliterator (mBart50).

    Loads from Hugging Face Hub on first instantiation.
    Thread-safe for inference (no mutable state after __init__).
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MBART_MODEL,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        logger.info("Loading mBart50 transliterator: %s", model_name)
        self.tokenizer = MBart50Tokenizer.from_pretrained(model_name)
        # Set the source language once here so transliterate() does not mutate
        # tokenizer state per call, keeping the thread-safety note above accurate.
        self.tokenizer.src_lang = "si_LK"
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    def transliterate(self, text: str) -> str:
        """
        Transliterate a Singlish sentence to fully-Sinhalized output.

        Args:
            text: Input Singlish sentence (Romanized Sinhala / English mix).

        Returns:
            Sinhala-script output. Returns original text if input is empty
            or consists entirely of unsupported-script characters.
        """
        cleaned = _clean(text)
        if not cleaned:
            return text

        inputs = self.tokenizer(
            cleaned,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        ).to(self.device)

        with torch.no_grad():
            tokens = self.model.generate(
                **inputs,
                forced_bos_token_id=self.tokenizer.lang_code_to_id["si_LK"],
            )

        output = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
        return _apply_fixes(output)
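A minimal usage sketch (the model name is resolved from core.constants.DEFAULT_MBART_MODEL; the input sentence is illustrative):

    from seq2seq.mbart_infer import SentenceTransliterator

    st = SentenceTransliterator()
    # Full-sentence transliteration; compose fixes are applied to the output
    print(st.transliterate("mn heta gedara yanawa"))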
seq2seq/prepare_data.py
ADDED
@@ -0,0 +1,94 @@
"""
Parse WSD.txt into a CSV training dataset for ByT5 fine-tuning.

Input format (WSD.txt):
    Word: <romanized>, Sinhala Words: ['<s1>', '<s2>', ...]

Output (wsd_pairs.csv):
    romanized,sinhala
    wadi,වෑඩි
    wadi,වාඩි
    ...

One row per (romanized, sinhala) pair. Duplicate sinhala entries per
word are kept since ByT5 learns from all valid transliterations.
"""

import ast
import csv
import re
import sys
from pathlib import Path

WSD_PATH = Path(r"C:\Y5_Docs\FYP\WSD.txt")
OUT_PATH = Path(__file__).parent / "wsd_pairs.csv"

LINE_RE = re.compile(r"^Word:\s*(.+?),\s*Sinhala Words:\s*(\[.+\])\s*$")

MIN_ROMAN_LEN = 2   # skip single-char romanized entries
MAX_ROMAN_LEN = 40  # skip obviously malformed long entries


def parse_wsd(wsd_path: Path) -> list[tuple[str, str]]:
    pairs: list[tuple[str, str]] = []
    skipped = 0

    with wsd_path.open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue

            m = LINE_RE.match(line)
            if not m:
                skipped += 1
                continue

            roman = m.group(1).strip().lower()
            if not (MIN_ROMAN_LEN <= len(roman) <= MAX_ROMAN_LEN):
                skipped += 1
                continue

            try:
                sinhala_list = ast.literal_eval(m.group(2))
            except (ValueError, SyntaxError):
                skipped += 1
                continue

            for sinhala in sinhala_list:
                sinhala = sinhala.strip()
                if sinhala:
                    pairs.append((roman, sinhala))

            if lineno % 100_000 == 0:
                print(f" processed {lineno:,} lines, {len(pairs):,} pairs so far…")

    print(f" skipped {skipped:,} malformed lines")
    return pairs


def write_csv(pairs: list[tuple[str, str]], out_path: Path) -> None:
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["romanized", "sinhala"])
        writer.writerows(pairs)


def main() -> None:
    print(f"Parsing {WSD_PATH} …")
    pairs = parse_wsd(WSD_PATH)
    print(f"\nTotal pairs: {len(pairs):,}")

    print(f"Writing to {OUT_PATH} …")
    write_csv(pairs, OUT_PATH)
    print("Done.")

    # Quick sanity check
    print("\nSample rows:")
    for roman, sinhala in pairs[:5]:
        print(f" {roman!r:20s} → {sinhala}")


if __name__ == "__main__":
    main()
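A quick hypothetical check of LINE_RE against the documented WSD.txt line format (the sample line is made up):

    import re

    LINE_RE = re.compile(r"^Word:\s*(.+?),\s*Sinhala Words:\s*(\[.+\])\s*$")
    m = LINE_RE.match("Word: wadi, Sinhala Words: ['වැඩි', 'වාඩි']")
    print(m.group(1), m.group(2))  # wadi ['වැඩි', 'වාඩි']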
seq2seq/train.py
ADDED
@@ -0,0 +1,185 @@
"""
Fine-tune google/byt5-small on Singlish → Sinhala word-level transliteration.

Input:  wsd_pairs.csv (romanized, sinhala)
Output: byt5-singlish-sinhala/ (HuggingFace model directory)

Training approach:
  - Input  : romanized word (e.g. "wadi")
  - Target : sinhala word (e.g. "වැඩි")
  - Model  : ByT5-small (byte-level T5, no vocab issues with any script)
  - Beam=5 at inference → top-5 candidates for MLM reranking

Tokenized dataset is saved to disk after first run — restarts skip
straight to training without re-tokenizing.
"""

from pathlib import Path

import torch
from datasets import Dataset, load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)

# ── Config ─────────────────────────────────────────────────────────────────

BASE_MODEL = "google/byt5-small"
DATA_PATH = Path(__file__).parent / "wsd_pairs.csv"
CACHE_DIR = Path(__file__).parent / "tokenized_cache"
OUTPUT_DIR = Path(__file__).parent / "byt5-singlish-sinhala"

MAX_SAMPLES = 1_000_000  # 1M pairs — more than enough for word transliteration
TRAIN_SPLIT = 0.97
MAX_INPUT_LEN = 64
MAX_TARGET_LEN = 64
BATCH_SIZE = 64          # 16GB VRAM — ByT5-small with seq_len=64
EPOCHS = 2
LR = 5e-4
SEED = 42


# ── Tokenize ────────────────────────────────────────────────────────────────

def tokenize_fn(batch, tokenizer):
    # Pad to fixed max_length so all tensors have the same shape.
    # This lets set_format("torch") work and default_data_collator just stacks.
    model_inputs = tokenizer(
        batch["romanized"],
        max_length=MAX_INPUT_LEN,
        truncation=True,
        padding="max_length",
    )
    labels = tokenizer(
        batch["sinhala"],
        max_length=MAX_TARGET_LEN,
        truncation=True,
        padding="max_length",
    )
    # Replace pad token with -100 so it's ignored in cross-entropy loss
    model_inputs["labels"] = [
        [(t if t != tokenizer.pad_token_id else -100) for t in ids]
        for ids in labels["input_ids"]
    ]
    return model_inputs


# ── Main ───────────────────────────────────────────────────────────────────

def main():
    import os
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f" Device : {device}")
    if device == "cuda":
        print(f" GPU    : {torch.cuda.get_device_name(0)}")
        print(f" VRAM   : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print(" WARNING: No GPU detected — training will be very slow!")

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)

    train_cache = CACHE_DIR / "train"
    eval_cache = CACHE_DIR / "eval"

    if train_cache.exists() and eval_cache.exists():
        print("Loading pre-tokenized dataset from disk cache …")
        train_ds = load_from_disk(str(train_cache))
        eval_ds = load_from_disk(str(eval_cache))
        print(f" train={len(train_ds):,} eval={len(eval_ds):,}")
    else:
        print(f"Loading data from {DATA_PATH} …")
        ds = Dataset.from_csv(str(DATA_PATH))
        ds = ds.filter(lambda x: bool(x["romanized"]) and bool(x["sinhala"]))
        print(f" {len(ds):,} pairs — sampling {MAX_SAMPLES:,} …")

        # Shuffle and take MAX_SAMPLES
        ds = ds.shuffle(seed=SEED).select(range(min(MAX_SAMPLES, len(ds))))

        split = ds.train_test_split(test_size=1 - TRAIN_SPLIT, seed=SEED)
        train_raw = split["train"]
        eval_raw = split["test"]
        print(f" train={len(train_raw):,} eval={len(eval_raw):,}")

        print("Tokenizing and saving to disk (one-time, ~5 min) …")
        train_ds = train_raw.map(
            lambda b: tokenize_fn(b, tokenizer),
            batched=True,
            batch_size=10_000,
            num_proc=8,
            keep_in_memory=True,
            remove_columns=["romanized", "sinhala"],
            desc="Tokenizing train",
        )
        eval_ds = eval_raw.map(
            lambda b: tokenize_fn(b, tokenizer),
            batched=True,
            batch_size=10_000,
            num_proc=8,
            keep_in_memory=True,
            remove_columns=["romanized", "sinhala"],
            desc="Tokenizing eval",
        )

        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        train_ds.save_to_disk(str(train_cache))
        eval_ds.save_to_disk(str(eval_cache))
        print(" Saved to disk. Future runs will load instantly.")

    train_ds.set_format("torch")
    eval_ds.set_format("torch")

    # All sequences are pre-padded to fixed length — just stack them
    collator = default_data_collator
    warmup_steps = int(0.05 * (len(train_ds) // BATCH_SIZE))

    args = Seq2SeqTrainingArguments(
        output_dir=str(OUTPUT_DIR),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LR,
        warmup_steps=warmup_steps,
        weight_decay=0.01,
        predict_with_generate=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_steps=200,
        dataloader_num_workers=0,  # 0 = main process only (most stable on Windows)
        dataloader_pin_memory=True,
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported() and torch.cuda.is_available(),
        seed=SEED,
        report_to="none",
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        data_collator=collator,
    )

    print("Starting training …")
    trainer.train()

    print(f"Saving model to {OUTPUT_DIR}/final …")
    model.save_pretrained(OUTPUT_DIR / "final")
    tokenizer.save_pretrained(OUTPUT_DIR / "final")
    print("Done.")


if __name__ == "__main__":
    main()
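For scale: if the full 1M-pair sample is available, the 97% split leaves about 970,000 training pairs, so one epoch at batch size 64 is 970000 // 64 = 15,156 optimizer steps, and warmup_steps evaluates to int(0.05 × 15156) = 757.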
sincode_model.py
ADDED
@@ -0,0 +1,16 @@
"""
SinCode v3 — public API entry point.

Usage:
    from sincode_model import BeamSearchDecoder
    decoder = BeamSearchDecoder()
    result, logs = decoder.decode("mema videowe bit rate eka godak wadi nisa buffer wenawa")
"""

from core.decoder import BeamSearchDecoder, ScoredCandidate  # noqa: F401
from core.english import ENGLISH_VOCAB  # noqa: F401
from core.constants import (  # noqa: F401
    DEFAULT_MLM_MODEL, DEFAULT_BYT5_MODEL, DEFAULT_MBART_MODEL,
    MAX_CANDIDATES, MIN_ENGLISH_LEN,
)
from seq2seq.mbart_infer import SentenceTransliterator  # noqa: F401