bhsinghgrid commited on
Commit
0ffbe78
·
verified ·
1 Parent(s): 7f30d22

Update inference cleanup + model card + runtime

Browse files
Files changed (2) hide show
  1. README.md +15 -47
  2. inference.py +517 -85
README.md CHANGED
@@ -9,16 +9,13 @@ tags:
9
  - diffusion
10
  - d3pm
11
  - pytorch
12
- pipeline_tag: text-generation
13
  ---
14
 
15
  # Sanskrit D3PM Paraphrase Model
16
 
17
  Roman/IAST Sanskrit input to Devanagari output using a D3PM cross-attention model.
18
 
19
- This is a **custom PyTorch architecture** (not a native `transformers.AutoModel` checkpoint).
20
- You can still use it in a transformer-like workflow (load once, pass text, get generated text) via `inference_api.py`.
21
-
22
  ## Files Included
23
 
24
  - `best_model.pt` — trained checkpoint
@@ -28,7 +25,6 @@ You can still use it in a transformer-like workflow (load once, pass text, get g
28
  - `handler.py` — Hugging Face Endpoint handler
29
  - `model/`, `diffusion/` — architecture modules
30
  - `sanskrit_src_tokenizer.json`, `sanskrit_tgt_tokenizer.json` — tokenizers
31
- - `LOCAL_SETUP_GUIDE.md` — full laptop setup and execution guide
32
 
33
  ## Quick Local Test
34
 
@@ -37,52 +33,30 @@ from inference_api import predict
37
  print(predict("dharmo rakṣati rakṣitaḥ")["output"])
38
  ```
39
 
40
- ## Transformer-Style Usage (Recommended)
41
 
42
- Use this model like a transformer pipeline pattern: load once, call `generate(text)` many times.
 
43
 
44
  ```python
45
  import torch
46
  from config import CONFIG
47
- from inference import load_model, _build_tokenizers
 
48
 
49
- cfg = CONFIG
50
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
51
 
52
- model, cfg = load_model("best_model.pt", cfg, device)
53
- src_tok, tgt_tok = _build_tokenizers(cfg)
54
-
55
- def generate(text: str):
56
- input_ids = torch.tensor([src_tok.encode(text)], dtype=torch.long, device=device)
57
- output_ids = model.generate(
58
- input_ids,
59
- num_steps=cfg["inference"]["num_steps"],
60
- temperature=cfg["inference"]["temperature"],
61
- top_k=cfg["inference"]["top_k"],
62
- repetition_penalty=cfg["inference"]["repetition_penalty"],
63
- diversity_penalty=cfg["inference"]["diversity_penalty"],
64
- )
65
- ids = [x for x in output_ids[0].tolist() if x > 4]
66
- return tgt_tok.decode(ids).strip()
67
-
68
- print(generate("yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ"))
69
- ```
70
-
71
- ### Minimal 3-Step Pattern
72
-
73
- 1. `load_model(...)` once at app startup
74
- 2. `encode -> model.generate(...) -> decode` for each request
75
- 3. Reuse loaded model/tokenizers for all requests
76
 
77
- ## About `transformers` Compatibility
 
 
 
 
78
 
79
- - This repo does not expose `config.json` + `model.safetensors` in `transformers` format.
80
- - This is not a PEFT/LoRA adapter repository.
81
- - If you want full `AutoModel`/`pipeline` compatibility, you must create a wrapper architecture and export weights into HF Transformers conventions.
82
- - For production today, use:
83
- - `inference_api.py` for Python apps
84
- - `handler.py` for HF Inference Endpoints
85
- - `space_repo/app.py` for Gradio UI
86
 
87
  ## Endpoint Payload
88
 
@@ -113,9 +87,3 @@ git add .
113
  git commit -m "Initial model release"
114
  git push -u origin main
115
  ```
116
-
117
- ## Full Local Laptop Guide
118
-
119
- For complete setup (training, inference, UI, tasks 1-5, ablation, and deployment), see:
120
-
121
- - `LOCAL_SETUP_GUIDE.md`
 
9
  - diffusion
10
  - d3pm
11
  - pytorch
12
+ pipeline_tag: text2text-generation
13
  ---
14
 
15
  # Sanskrit D3PM Paraphrase Model
16
 
17
  Roman/IAST Sanskrit input to Devanagari output using a D3PM cross-attention model.
18
 
 
 
 
19
  ## Files Included
20
 
21
  - `best_model.pt` — trained checkpoint
 
25
  - `handler.py` — Hugging Face Endpoint handler
26
  - `model/`, `diffusion/` — architecture modules
27
  - `sanskrit_src_tokenizer.json`, `sanskrit_tgt_tokenizer.json` — tokenizers
 
28
 
29
  ## Quick Local Test
30
 
 
33
  print(predict("dharmo rakṣati rakṣitaḥ")["output"])
34
  ```
35
 
36
+ ## Transformer-Style Usage (Custom Runtime)
37
 
38
+ This checkpoint is a custom D3PM architecture (`.pt`), not a native `transformers` `AutoModel` format.
39
+ Use it in a transformer-like way via the provided runtime:
40
 
41
  ```python
42
  import torch
43
  from config import CONFIG
44
+ from inference import load_model, run_inference, _decode_clean
45
+ from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
46
 
 
47
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+ model, cfg = load_model("best_model.pt", CONFIG, device)
49
 
50
+ src_tok = SanskritSourceTokenizer(vocab_size=16000, max_len=cfg["model"]["max_seq_len"])
51
+ tgt_tok = SanskritTargetTokenizer(vocab_size=16000, max_len=cfg["model"]["max_seq_len"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ text = "dharmo rakṣati rakṣitaḥ"
54
+ ids = torch.tensor([src_tok.encode(text)], dtype=torch.long, device=device)
55
+ out = run_inference(model, ids, cfg)
56
+ print(_decode_clean(tgt_tok, out[0].tolist()))
57
+ ```
58
 
59
+ If you need full `transformers` compatibility (`AutoModel.from_pretrained`), export weights to a Hugging Face Transformers model format first.
 
 
 
 
 
 
60
 
61
  ## Endpoint Payload
62
 
 
87
  git commit -m "Initial model release"
88
  git push -u origin main
89
  ```
 
 
 
 
 
 
inference.py CHANGED
@@ -1,122 +1,554 @@
1
- import copy
 
 
 
2
 
 
 
 
 
 
 
 
 
 
3
  import torch
4
- import torch.nn.functional as F
 
 
 
5
 
 
6
  from config import CONFIG
7
 
8
 
9
- def _resolve_device(cfg: dict) -> torch.device:
10
- requested = cfg["training"]["device"]
11
- if requested == "cuda" and not torch.cuda.is_available():
12
- requested = "cpu"
13
- if requested == "mps" and not torch.backends.mps.is_available():
14
- requested = "cpu"
15
- cfg["training"]["device"] = requested
16
- return torch.device(requested)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
- def _build_tokenizers(cfg):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  src_tok = SanskritSourceTokenizer(
23
- vocab_size=cfg["model"].get("src_vocab_size", 16000),
24
- max_len=cfg["model"]["max_seq_len"],
25
  )
26
  tgt_tok = SanskritTargetTokenizer(
27
- vocab_size=cfg["model"].get("tgt_vocab_size", 16000),
28
- max_len=cfg["model"]["max_seq_len"],
29
  )
30
- return src_tok, tgt_tok
31
 
 
 
 
32
 
33
- def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
34
- from model.sanskrit_model import SanskritModel
35
-
36
- cfg = copy.deepcopy(base_cfg)
37
- state = torch.load(ckpt_path, map_location="cpu")
 
 
38
 
39
- emb_key = "model.src_embed.token_emb.weight"
40
- if emb_key in state:
41
- vocab, d_model = state[emb_key].shape
42
- cfg["model"]["src_vocab_size"] = vocab
43
- cfg["model"]["d_model"] = d_model
44
- cfg["model"]["d_ff"] = d_model * 4
 
 
 
45
 
46
- layer_ids = {int(k.split(".")[2]) for k in state if k.startswith("model.encoder_blocks.")}
47
- if layer_ids:
48
- cfg["model"]["n_layers"] = max(layer_ids) + 1
49
 
50
- pos_key = "model.src_embed.pos_enc.pe"
51
- if pos_key in state:
52
- cfg["model"]["max_seq_len"] = state[pos_key].shape[1]
53
 
54
- d_model = cfg["model"]["d_model"]
55
- n_heads = cfg["model"].get("n_heads", 8)
56
- if d_model % n_heads != 0:
57
- n_heads = next(h for h in [8, 6, 4, 2, 1] if d_model % h == 0)
58
- cfg["model"]["n_heads"] = n_heads
59
 
60
- model = SanskritModel(cfg).to(device)
61
- model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
62
- model.eval()
63
- return model, cfg
64
 
 
 
 
 
65
 
66
- def run_inference(model, input_ids, cfg):
67
- inf = cfg["inference"]
68
- device = input_ids.device
69
- bsz, seqlen = input_ids.shape
70
- inner = model.model
71
 
72
- total_steps = inner.scheduler.num_timesteps
73
- steps = int(inf["num_steps"])
74
- step_size = max(1, total_steps // max(steps, 1))
75
- timesteps = list(range(total_steps - 1, -1, -step_size))
76
- if timesteps[-1] != 0:
77
- timesteps.append(0)
78
 
79
- x0_est = torch.full((bsz, seqlen), inner.mask_token_id, dtype=torch.long, device=device)
80
- hint = None
 
 
 
 
 
 
81
 
82
- with torch.no_grad():
83
- for i, t_val in enumerate(timesteps):
84
- is_last = i == len(timesteps) - 1
85
- t = torch.full((bsz,), t_val, dtype=torch.long, device=device)
 
 
86
 
87
- logits, _ = model(input_ids, x0_est, t, x0_hint=hint, inference_mode=True)
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- if inf["repetition_penalty"] != 1.0:
90
- from model.d3pm_model_cross_attention import _apply_repetition_penalty
91
 
92
- logits = _apply_repetition_penalty(logits, x0_est, float(inf["repetition_penalty"]))
93
- if inf["diversity_penalty"] > 0.0:
94
- from model.d3pm_model_cross_attention import _apply_diversity_penalty_fixed
 
 
 
 
 
 
95
 
96
- logits = _apply_diversity_penalty_fixed(logits, float(inf["diversity_penalty"]))
 
 
 
 
 
 
 
 
 
97
 
98
- logits = logits / max(float(inf["temperature"]), 1e-5)
99
- if int(inf["top_k"]) > 0:
100
- from model.d3pm_model_cross_attention import _top_k_filter
 
 
 
 
 
101
 
102
- logits = _top_k_filter(logits, int(inf["top_k"]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- probs = F.softmax(logits, dim=-1)
105
- if is_last:
106
- x0_est = torch.argmax(probs, dim=-1)
107
- else:
108
- from model.d3pm_model_cross_attention import _batch_multinomial
109
 
110
- x0_est = _batch_multinomial(probs)
111
- hint = x0_est
 
 
112
 
113
- return x0_est
114
 
 
 
 
 
 
 
 
 
115
 
116
- __all__ = [
117
- "CONFIG",
118
- "_resolve_device",
119
- "_build_tokenizers",
120
- "load_model",
121
- "run_inference",
122
- ]
 
1
+ """
2
+ inference.py
3
+ ============
4
+ Correct D3PM inference for Sanskrit paraphrase generation.
5
 
6
+ The model's forward() takes CLEAN tgt and noises it internally.
7
+ So inference passes x0_estimate (starting all-[MASK]) as tgt each step,
8
+ letting the model noise it and then predict a cleaner version.
9
+
10
+ Also includes: robust checkpoint loading (auto-detects architecture
11
+ from saved weights — no CONFIG mismatch crashes).
12
+ """
13
+
14
+ import json
15
  import torch
16
+ import os, sys
17
+ import re
18
+ from tqdm import tqdm
19
+ from torch.utils.data import DataLoader, Subset
20
 
21
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
22
  from config import CONFIG
23
 
24
 
25
+ # ── Checkpoint loader ─────────────────────────────────────────────────
26
+
27
+ def _resolve_device(cfg_device: str) -> torch.device:
28
+ cfg_device = (cfg_device or "").lower()
29
+ if cfg_device == "cuda" and torch.cuda.is_available():
30
+ return torch.device("cuda")
31
+ if cfg_device == "mps" and torch.backends.mps.is_available():
32
+ return torch.device("mps")
33
+ if cfg_device in {"cpu", "cuda", "mps"}:
34
+ return torch.device("cpu")
35
+ if torch.cuda.is_available():
36
+ return torch.device("cuda")
37
+ if torch.backends.mps.is_available():
38
+ return torch.device("mps")
39
+ return torch.device("cpu")
40
+
41
def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """Load a SanskritModel checkpoint, auto-detecting its architecture.

    Inspects weight shapes in the saved state dict (embedding, encoder block
    indices, positional-encoding buffer) and patches a deep copy of
    ``base_cfg`` so construction never fails on a CONFIG/checkpoint mismatch.
    Shape-mismatched tensors are skipped rather than crashing the load.

    Args:
        ckpt_path: path to the ``.pt`` state-dict file.
        base_cfg: CONFIG-style nested dict; not mutated.
        device: target device for the constructed model.

    Returns:
        (model, cfg): the eval-ready model is NOT set to eval here (callers
        do that), and ``cfg`` is the patched copy reflecting the checkpoint.
    """
    import copy
    from model.sanskrit_model import SanskritModel

    cfg = copy.deepcopy(base_cfg)
    # Load ONCE on CPU: load_state_dict copies tensors into the (possibly
    # GPU-resident) model, so a second torch.load at `device` is wasted I/O.
    state = torch.load(ckpt_path, map_location='cpu')

    # d_model + vocab size from the source embedding table.
    ek = 'model.src_embed.token_emb.weight'
    if ek in state:
        vocab, d = state[ek].shape
        # NOTE(review): writes 'vocab_size', while tokenizer construction
        # elsewhere reads 'src_vocab_size'/'tgt_vocab_size' — confirm which
        # key SanskritModel consumes.
        cfg['model']['vocab_size'] = vocab
        cfg['model']['d_model'] = d
        cfg['model']['d_ff'] = d * 4

    # n_layers from the highest encoder-block index present.
    ids = {int(k.split('.')[2]) for k in state if k.startswith('model.encoder_blocks.')}
    if ids:
        cfg['model']['n_layers'] = max(ids) + 1

    # max_seq_len from the positional-encoding buffer.
    pk = 'model.src_embed.pos_enc.pe'
    if pk in state:
        cfg['model']['max_seq_len'] = state[pk].shape[1]

    # Choose a head count that divides d_model.
    d = cfg['model']['d_model']
    h = cfg['model'].get('n_heads', 6)
    if d % h != 0:
        h = next(x for x in [8, 6, 4, 2, 1] if d % x == 0)
    cfg['model']['n_heads'] = h

    print(f"🔍 Detected: d_model={cfg['model']['d_model']}, "
          f"n_layers={cfg['model']['n_layers']}, "
          f"max_seq_len={cfg['model']['max_seq_len']}, "
          f"n_heads={cfg['model']['n_heads']}")

    model = SanskritModel(cfg).to(device)
    model_state = model.state_dict()
    filtered_state = {}
    skipped_mismatch = []
    for k, v in state.items():
        if k in model_state and hasattr(v, "shape") and hasattr(model_state[k], "shape"):
            if tuple(v.shape) != tuple(model_state[k].shape):
                skipped_mismatch.append((k, tuple(v.shape), tuple(model_state[k].shape)))
                continue
        filtered_state[k] = v

    missing, unexpected = model.load_state_dict(filtered_state, strict=False)

    # hint_gate may be absent in older checkpoints — initialise safely below.
    allowed = {'model.hint_gate.0.weight', 'model.hint_gate.0.bias'}
    real_missing = [k for k in missing if k not in allowed]
    if real_missing:
        print(f"⚠️ Missing keys: {real_missing[:3]} …")
    if unexpected:
        print(f"⚠️ Unexpected keys: {unexpected[:3]} …")
    if skipped_mismatch:
        print(f"⚠️ Shape-mismatched keys skipped: {len(skipped_mismatch)}")

    # Enable compact-attention branch only when checkpoint actually provides it.
    has_compact = any(".compact_out_proj.weight" in k for k in filtered_state)
    if has_compact and hasattr(model, "model") and hasattr(model.model, "decoder_blocks"):
        for block in model.model.decoder_blocks:
            if hasattr(block, "cross_attn") and hasattr(block.cross_attn, "use_compact"):
                block.cross_attn.use_compact = True
        print("ℹ️ Compact cross-attention branch enabled from checkpoint.")

    if hasattr(model.model, 'hint_gate') and 'model.hint_gate.0.weight' in missing:
        with torch.no_grad():
            w = model.model.hint_gate[0].weight
            torch.nn.init.zeros_(model.model.hint_gate[0].bias)
            # Identity when square so an uninitialised gate passes hints through.
            if w.shape[0] == w.shape[1]:
                torch.nn.init.eye_(w)
            else:
                torch.nn.init.xavier_uniform_(w)
        print("ℹ️ hint_gate initialised to identity (not in checkpoint).")

    print("✅ Model loaded.")
    return model, cfg
123
+
124
+
125
+ # ── Core inference function (same path as validation) ────────────────
126
+
127
@torch.no_grad()
def run_inference(model, input_ids, cfg):
    """Reverse-diffusion sampling (clean path).

    Prefers the model's cached reverse-diffusion sampler when available,
    otherwise falls back to ``model.generate``. Optionally retries once with
    stronger anti-repetition settings and de-duplicates extreme token runs.
    """
    inf = cfg['inference']
    model.eval()

    sample_kwargs = dict(
        num_steps=inf['num_steps'],
        temperature=inf['temperature'],
        top_k=inf['top_k'],
        repetition_penalty=inf.get('repetition_penalty', 1.2),
        diversity_penalty=inf.get('diversity_penalty', 0.0),
    )

    def _sample(kw):
        # Cached sampler when the model offers one, else the plain path.
        sampler = getattr(model, "generate_cached", None)
        if sampler is None:
            sampler = model.generate
        return sampler(input_ids, **kw)

    out = _sample(sample_kwargs)

    # Optional retry with stronger anti-repetition settings.
    if inf.get("auto_retry_on_repetition", True):
        repeat_threshold = float(inf.get("repeat_ratio_threshold", 0.40))
        max_run = int(inf.get("max_repeat_run", 4))
        if _mean_repeat_ratio(out) >= repeat_threshold:
            stronger = dict(sample_kwargs)
            stronger["temperature"] = max(0.6, float(sample_kwargs["temperature"]) - 0.1)
            stronger["top_k"] = max(20, int(sample_kwargs["top_k"]) - 10)
            stronger["repetition_penalty"] = max(float(sample_kwargs["repetition_penalty"]), 1.6)
            stronger["diversity_penalty"] = max(float(sample_kwargs["diversity_penalty"]), 0.3)
            retry = _sample(stronger)
            # Keep the retry only if it actually repeats less.
            if _mean_repeat_ratio(retry) < _mean_repeat_ratio(out):
                out = retry
        out = _dedup_repeated_ids(out, max_repeat_run=max_run)

    return out
166
+
167
+
168
+ def _mean_repeat_ratio(ids_tensor: torch.Tensor) -> float:
169
+ if ids_tensor is None or ids_tensor.numel() == 0:
170
+ return 0.0
171
+ ratios = []
172
+ for row in ids_tensor:
173
+ ids = [int(x) for x in row.tolist() if int(x) > 4]
174
+ if len(ids) < 2:
175
+ ratios.append(0.0)
176
+ continue
177
+ repeats = sum(1 for i in range(1, len(ids)) if ids[i] == ids[i - 1])
178
+ ratios.append(repeats / max(1, len(ids) - 1))
179
+ return float(sum(ratios) / max(1, len(ratios)))
180
+
181
+
182
+ def _dedup_repeated_ids(ids_tensor: torch.Tensor, max_repeat_run: int = 4) -> torch.Tensor:
183
+ """
184
+ Keep generation path unchanged, but clean extreme run-on token loops in final output ids.
185
+ """
186
+ if ids_tensor is None or ids_tensor.numel() == 0:
187
+ return ids_tensor
188
+ cleaned_rows = []
189
+ for row in ids_tensor.tolist():
190
+ out = []
191
+ prev = None
192
+ run = 0
193
+ for tok in row:
194
+ if tok <= 4:
195
+ out.append(tok)
196
+ prev = tok
197
+ run = 1
198
+ continue
199
+ if tok == prev:
200
+ run += 1
201
+ if run > max_repeat_run:
202
+ continue
203
+ else:
204
+ run = 1
205
+ out.append(tok)
206
+ prev = tok
207
+ # Preserve original length for downstream decode assumptions.
208
+ if len(out) < len(row):
209
+ out.extend([1] * (len(row) - len(out)))
210
+ else:
211
+ out = out[:len(row)]
212
+ cleaned_rows.append(out)
213
+ return torch.tensor(cleaned_rows, dtype=ids_tensor.dtype, device=ids_tensor.device)
214
+
215
+
216
+ def _decode_clean(tgt_tok, ids):
217
+ out = []
218
+ for x in ids:
219
+ if x in (1, 4) and out:
220
+ break
221
+ if x > 4:
222
+ out.append(x)
223
+ text = tgt_tok.decode(out).strip()
224
+ return _clean_repetition_text(text)
225
+
226
+
227
+ def _clean_repetition_text(text: str, max_repeat_run: int = 3) -> str:
228
+ words = [w for w in text.split() if w.strip()]
229
+ if not words:
230
+ return text.strip()
231
+ cleaned = []
232
+ prev = None
233
+ run = 0
234
+ for w in words:
235
+ if w == prev:
236
+ run += 1
237
+ if run > max_repeat_run:
238
+ continue
239
+ else:
240
+ run = 1
241
+ cleaned.append(w)
242
+ prev = w
243
+ return " ".join(cleaned).strip()
244
+
245
+
246
+ # ── Cleanup heuristics from UI inference pipeline ─────────────────────
247
+
248
+ _IAST_VOWELS = [
249
+ ("ai", "ऐ"), ("au", "औ"),
250
+ ("ā", "आ"), ("ī", "ई"), ("ū", "ऊ"),
251
+ ("ṛ", "ऋ"), ("ṝ", "ॠ"), ("ḷ", "ऌ"), ("ḹ", "ॡ"),
252
+ ("a", "अ"), ("i", "इ"), ("u", "उ"),
253
+ ("e", "ए"), ("o", "ओ"),
254
+ ]
255
+ _IAST_MATRAS = [
256
+ ("ai", "ै"), ("au", "ौ"),
257
+ ("ā", "ा"), ("ī", "ी"), ("ū", "ू"),
258
+ ("ṛ", "ृ"), ("ṝ", "ॄ"), ("ḷ", "ॢ"), ("ḹ", "ॣ"),
259
+ ("a", ""), ("i", "ि"), ("u", "ु"),
260
+ ("e", "े"), ("o", "ो"),
261
+ ]
262
+ _IAST_CONS = [
263
+ ("kṣ", "क��ष"), ("jñ", "ज्ञ"), ("tr", "त्र"),
264
+ ("kh", "ख"), ("gh", "घ"), ("ch", "छ"), ("jh", "झ"),
265
+ ("ṭh", "ठ"), ("ḍh", "ढ"), ("th", "थ"), ("dh", "ध"),
266
+ ("ph", "फ"), ("bh", "भ"),
267
+ ("ṅ", "ङ"), ("ñ", "ञ"), ("ṭ", "ट"), ("ḍ", "ड"),
268
+ ("ṇ", "ण"), ("ś", "श"), ("ṣ", "ष"), ("ḥ", "ः"),
269
+ ("ṃ", "ं"), ("ṁ", "ं"),
270
+ ("y", "य"), ("r", "र"), ("l", "ल"), ("v", "व"),
271
+ ("s", "स"), ("h", "ह"),
272
+ ("k", "क"), ("g", "ग"), ("c", "च"), ("j", "ज"),
273
+ ("t", "त"), ("d", "द"), ("n", "न"),
274
+ ("p", "प"), ("b", "ब"), ("m", "म"),
275
+ ]
276
+ _PUNCT = {".": "।", "|": "।", "||": "॥", ",": ",", "?": "?", "!": "!"}
277
+
278
+
279
+ def _iast_to_deva(text: str) -> str:
280
+ s = (text or "").lower()
281
+ out = []
282
+ i = 0
283
+ pending_consonant = False
284
+
285
+ def _match_any(pairs, pos):
286
+ for k, v in pairs:
287
+ if s.startswith(k, pos):
288
+ return k, v
289
+ return None, None
290
+
291
+ while i < len(s):
292
+ if s[i].isspace():
293
+ pending_consonant = False
294
+ out.append(s[i])
295
+ i += 1
296
+ continue
297
+ if s[i:i+2] == "||":
298
+ pending_consonant = False
299
+ out.append(_PUNCT["||"])
300
+ i += 2
301
+ continue
302
+ if s[i] in _PUNCT:
303
+ pending_consonant = False
304
+ out.append(_PUNCT[s[i]])
305
+ i += 1
306
+ continue
307
+
308
+ v_key, v_deva = _match_any(_IAST_VOWELS, i)
309
+ if v_key:
310
+ if pending_consonant:
311
+ _, v_matra = _match_any(_IAST_MATRAS, i)
312
+ out[-1] = out[-1] + (v_matra or "")
313
+ pending_consonant = False
314
+ else:
315
+ out.append(v_deva)
316
+ i += len(v_key)
317
+ continue
318
+
319
+ c_key, c_deva = _match_any(_IAST_CONS, i)
320
+ if c_key:
321
+ if pending_consonant:
322
+ out[-1] = out[-1] + "्"
323
+ out.append(c_deva)
324
+ pending_consonant = True
325
+ i += len(c_key)
326
+ continue
327
+
328
+ out.append(s[i])
329
+ pending_consonant = False
330
+ i += 1
331
+
332
+ return "".join(out).strip()
333
 
334
 
335
+ def _compute_cer(pred: str, ref: str) -> float:
336
+ if pred == ref:
337
+ return 0.0
338
+ if not pred or not ref:
339
+ return 1.0
340
+ m, n = len(pred), len(ref)
341
+ dp = list(range(n + 1))
342
+ for i in range(1, m + 1):
343
+ prev = dp[0]
344
+ dp[0] = i
345
+ for j in range(1, n + 1):
346
+ temp = dp[j]
347
+ cost = 0 if pred[i - 1] == ref[j - 1] else 1
348
+ dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
349
+ prev = temp
350
+ return dp[n] / max(m, n)
351
+
352
+
353
+ def _cleanup_thresholds(temperature: float, top_k: int):
354
+ temp = float(temperature)
355
+ k = max(1, int(top_k))
356
+ t_norm = max(0.0, min((temp - 0.4) / 0.6, 1.0))
357
+ k_norm = max(0.0, min((k - 20) / 80.0, 1.0))
358
+ diversity = 0.6 * t_norm + 0.4 * k_norm
359
+ cer_threshold = 0.10 + 0.18 * diversity
360
+ deva_ratio_threshold = 0.60 - 0.20 * diversity
361
+ return cer_threshold, deva_ratio_threshold
362
+
363
+
364
def _decode_with_cleanup(tgt_tok, ids, src_text: str, inf_cfg: dict):
    """Decode model output; fall back to rule-based transliteration when the
    model text looks degenerate.

    Fallback triggers when the model output has too little Devanagari, is
    more than twice as long as the rule-based rendering, or diverges from it
    beyond a CER threshold (both thresholds scale with sampling diversity).
    """
    model_out = _decode_clean(tgt_tok, ids)
    rule_out = _iast_to_deva(src_text.strip())

    # Fraction of characters inside the Devanagari Unicode block.
    deva_ratio = sum("\u0900" <= ch <= "\u097F" for ch in model_out) / max(1, len(model_out))
    cer = _compute_cer(model_out, rule_out)
    cer_thr, ratio_thr = _cleanup_thresholds(
        inf_cfg.get("temperature", 0.8),
        inf_cfg.get("top_k", 40),
    )

    too_little_deva = deva_ratio < ratio_thr
    too_long = len(model_out) > 2.0 * max(1, len(rule_out))
    too_divergent = cer > cer_thr
    if too_little_deva or too_long or too_divergent:
        return rule_out
    return model_out
377
+
378
+
379
+ # ── Interactive demo ──────────────────────────────────────────────────
380
+
381
def interactive_demo(checkpoint=None, single_text=None):
    """Interactive REPL: read an IAST verse, print the model's paraphrase.

    Args:
        checkpoint: optional path to a ``.pt`` checkpoint; defaults to the
            results directory derived from CONFIG.
        single_text: when given, process exactly this one input and exit
            (non-interactive one-shot mode).

    Raises:
        FileNotFoundError: if no checkpoint exists at the resolved path.
    """
    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    cfg = CONFIG
    device = _resolve_device(cfg['training'].get('device', 'cpu'))

    # Default checkpoint path mirrors the training experiment layout.
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    ckpt = checkpoint or f"results/{model_name}_neg_{has_neg}/best_model.pt"

    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt} — train first.")

    # load_model patches cfg to match the checkpoint's detected architecture.
    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    # NOTE(review): tokenizers are rebuilt from config sizes; assumes this
    # reproduces the vocabulary used at training time — confirm.
    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg['model'].get('src_vocab_size', 16000),
        max_len=cfg['model']['max_seq_len'],
    )
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
        max_len=cfg['model']['max_seq_len'],
    )

    print("\n" + "="*55)
    print("Sanskrit D3PM Paraphrase — type verse, get paraphrase")
    print("="*55 + "\n")

    while True:
        try:
            text = (single_text if single_text is not None else input("INPUT > ")).strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not text or text.lower() in ('quit', 'exit', 'q'):
            break

        # Truncate to the model's positional-encoding capacity.
        ids = torch.tensor(
            [src_tok.encode(text)[:cfg['model']['max_seq_len']]],
            dtype=torch.long, device=device
        )
        out = run_inference(model, ids, cfg)
        # Decode with the degenerate-output fallback heuristics.
        cleaned = _decode_with_cleanup(tgt_tok, out[0].tolist(), text, cfg["inference"])
        print(f"PARAPHRASE → {cleaned}\n")
        if single_text is not None:
            break
 
 
 
 
428
 
429
+ # ── Batch evaluation ──────────────────────────────────────────────────
 
 
430
 
431
def batch_evaluate(sample_size=500, checkpoint=None):
    """Evaluate the model on (up to) ``sample_size`` test examples.

    Generates paraphrases for a deterministic slice of the test split,
    computes BLEU and BERTScore (both best-effort — metric failures are
    swallowed and leave the score at 0.0), then writes a human-readable
    summary and a JSONL prediction dump into the experiment directory.

    Args:
        sample_size: maximum number of test examples to evaluate.
        checkpoint: optional checkpoint path; defaults to the experiment's
            ``best_model.pt``.

    Returns:
        (all_preds, all_refs): lists of predicted and reference strings.

    Raises:
        FileNotFoundError: if the resolved checkpoint does not exist.
    """
    from data.dataset import OptimizedSanskritDataset
    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    cfg = CONFIG
    device = _resolve_device(cfg['training'].get('device', 'cpu'))

    # Experiment directory mirrors the training layout.
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    exp_dir = f"results/{model_name}_neg_{has_neg}"
    ckpt = checkpoint or f"{exp_dir}/best_model.pt"

    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt}")

    # load_model patches cfg to match the checkpoint's detected architecture.
    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    # NOTE(review): tokenizers are rebuilt from config sizes; assumes this
    # reproduces the training-time vocabulary — confirm.
    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg['model'].get('src_vocab_size', 16000),
        max_len=cfg['model']['max_seq_len'],
    )
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg['model'].get('tgt_vocab_size', 16000),
        max_len=cfg['model']['max_seq_len'],
    )

    def collate(batch):
        # Keep raw text alongside ids so predictions can be compared to refs.
        return {
            'input_ids': torch.stack([b['input_ids'].long() for b in batch]),
            'target_text': [b['target_text'] for b in batch],
            'input_text': [b['input_text'] for b in batch],
        }

    dataset = OptimizedSanskritDataset(
        split='test',
        max_len=cfg['model']['max_seq_len'],
        cfg=cfg,
        src_tokenizer=src_tok,
        tgt_tokenizer=tgt_tok,
    )
    # Deterministic evaluation: first N examples, no shuffling.
    indices = list(range(min(sample_size, len(dataset))))
    loader = DataLoader(
        Subset(dataset, indices),
        batch_size=cfg['training']['batch_size'],
        shuffle=False, collate_fn=collate
    )

    all_preds, all_refs, all_inputs = [], [], []
    print(f"⏳ Generating {len(indices)} paraphrases …")

    for batch in tqdm(loader):
        ids = batch['input_ids'].to(device)
        out = run_inference(model, ids, cfg)
        for i in range(out.size(0)):
            # Decode with the degenerate-output fallback heuristics.
            all_preds.append(_decode_with_cleanup(
                tgt_tok, out[i].tolist(), batch['input_text'][i], cfg["inference"]
            ))
            all_refs.append(batch['target_text'][i].strip())
            all_inputs.append(batch['input_text'][i].strip())

    # Metrics — both are optional dependencies; failures leave scores at 0.
    bleu_score, bert_f1 = 0.0, 0.0
    try:
        from nltk.translate.bleu_score import corpus_bleu
        bleu_score = corpus_bleu(
            [[r.split()] for r in all_refs],
            [p.split() for p in all_preds]
        )
    except Exception:
        pass

    try:
        import evaluate as hf_eval
        # lang='hi': Devanagari-script proxy — presumably no Sanskrit-specific
        # BERTScore model is available; verify this choice.
        res = hf_eval.load('bertscore').compute(
            predictions=all_preds, references=all_refs, lang='hi'
        )
        bert_f1 = sum(res['f1']) / len(res['f1'])
    except Exception:
        pass

    # Save a human-readable summary plus a machine-readable JSONL dump.
    out_path = f"{exp_dir}/evaluation_results.txt"
    pred_path = f"{exp_dir}/evaluation_predictions.jsonl"
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(f"Model : {model_name}\n")
        f.write(f"Negatives: {has_neg}\n")
        f.write(f"Steps : {cfg['inference']['num_steps']}\n")
        f.write(f"Temp : {cfg['inference']['temperature']}\n")
        f.write(f"RepPen : {cfg['inference']['repetition_penalty']}\n")
        f.write(f"DivPen : {cfg['inference']['diversity_penalty']}\n")
        f.write(f"BLEU : {bleu_score:.4f}\n")
        f.write(f"BERTScore: {bert_f1:.4f}\n\n")
        f.write("=== SAMPLES ===\n")
        # Only the first 20 examples are included in the text summary.
        for i in range(min(20, len(all_preds))):
            f.write(f"IN : {all_inputs[i]}\n")
            f.write(f"REF : {all_refs[i]}\n")
            f.write(f"PRED: {all_preds[i]}\n")
            f.write("-" * 60 + "\n")

    with open(pred_path, 'w', encoding='utf-8') as f:
        for src, ref, pred in zip(all_inputs, all_refs, all_preds):
            row = {"input": src, "reference": ref, "prediction": pred}
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"\n✅ Results → {out_path}")
    print(f"🗂️ Saved predictions → {pred_path}")
    print(f"📊 BLEU: {bleu_score:.4f} | BERTScore: {bert_f1:.4f}")
    return all_preds, all_refs
540
 
 
541
 
542
if __name__ == '__main__':
    import argparse

    # CLI entry point: interactive demo (default) or batch evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['demo', 'eval'], default='demo')
    parser.add_argument('--samples', type=int, default=500)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--text', type=str, default=None, help='Run one-shot demo input and exit')
    cli = parser.parse_args()

    if cli.mode == 'demo':
        interactive_demo(checkpoint=cli.checkpoint, single_text=cli.text)
    else:
        batch_evaluate(cli.samples, checkpoint=cli.checkpoint)