bhsinghgrid commited on
Commit
f8437ec
·
verified ·
1 Parent(s): 96d6f92

Upload 27 files

Browse files
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+ from datetime import datetime
5
+
6
+ import gradio as gr
7
+ import torch
8
+
9
+ from config import CONFIG
10
+ from inference import load_model, run_inference, _build_tokenizers, _resolve_device
11
+
12
+
13
+ RESULTS_DIR = "generated_results"
14
+ os.makedirs(RESULTS_DIR, exist_ok=True)
15
+
16
+
17
def discover_checkpoints():
    """Scan the known result directories for trained checkpoints.

    Looks for ``<root>/<experiment>/best_model.pt`` under each candidate
    root and returns one descriptor dict per hit (label, path, experiment
    name, root), sorted by experiment name within each root.
    """
    discovered = []
    for root in ("ablation_results", "results7", "results"):
        if not os.path.isdir(root):
            continue
        for name in sorted(os.listdir(root)):
            candidate = os.path.join(root, name, "best_model.pt")
            if os.path.exists(candidate):
                discovered.append(
                    {
                        "label": f"{name} [{root}]",
                        "path": candidate,
                        "experiment": name,
                        "root": root,
                    }
                )
    return discovered
33
+
34
+
35
def default_checkpoint_label():
    """Pick the dropdown's initial value.

    Prefers the T4 ablation run when present; otherwise falls back to the
    first discovered checkpoint, or None when nothing is on disk.
    """
    available = discover_checkpoints()
    if not available:
        return None
    preferred = next(
        (entry["label"] for entry in available
         if entry["path"].endswith("ablation_results/T4/best_model.pt")),
        None,
    )
    return preferred if preferred is not None else available[0]["label"]
43
+
44
+
45
def checkpoint_map():
    """Return a fresh mapping of checkpoint display label -> descriptor."""
    mapping = {}
    for entry in discover_checkpoints():
        mapping[entry["label"]] = entry
    return mapping
47
+
48
+
49
def infer_model_type(experiment_name: str, root: str = "") -> str:
    """Recover the model architecture name from an experiment name.

    Every ablation run uses the cross-attention D3PM; other runs encode
    the architecture as a prefix of their experiment name. Unrecognized
    names fall back to the global CONFIG default.
    """
    if root == "ablation_results":
        return "d3pm_cross_attention"
    known_types = (
        "d3pm_cross_attention",
        "d3pm_encoder_decoder",
        "baseline_cross_attention",
        "baseline_encoder_decoder",
    )
    for prefix in known_types:
        if experiment_name.startswith(prefix):
            return prefix
    return CONFIG["model_type"]
61
+
62
+
63
def infer_include_negative(experiment_name: str, root: str = "") -> bool:
    """Recover the include-negative-examples training flag from a run name.

    Ablation runs were all trained without negatives; other runs encode
    the flag in their name as ``_neg_True`` / ``_neg_False``. Names with
    no marker fall back to the global CONFIG default.
    """
    if root == "ablation_results":
        return False
    for marker, flag in (("_neg_True", True), ("_neg_False", False)):
        if marker in experiment_name:
            return flag
    return CONFIG["data"]["include_negative_examples"]
71
+
72
+
73
def build_runtime_cfg(ckpt_path: str):
    """Derive a per-checkpoint runtime config from its path.

    The experiment name and root are recovered from the on-disk layout
    ``<root>/<experiment>/best_model.pt``. Ablation runs named ``T<N>``
    additionally override both diffusion-step counts with N.

    Returns:
        (cfg, device, experiment_name)
    """
    exp_dir = os.path.dirname(ckpt_path)
    experiment = os.path.basename(exp_dir)
    root = os.path.basename(os.path.dirname(exp_dir))

    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = infer_model_type(experiment, root=root)
    cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root)

    is_step_ablation = (
        root == "ablation_results"
        and experiment.startswith("T")
        and experiment[1:].isdigit()
    )
    if is_step_ablation:
        steps = int(experiment[1:])
        cfg["model"]["diffusion_steps"] = steps
        cfg["inference"]["num_steps"] = steps

    return cfg, _resolve_device(cfg), experiment
85
+
86
+
87
def load_selected_model(checkpoint_label):
    """Load the checkpoint behind a dropdown label and assemble a runtime bundle.

    Args:
        checkpoint_label: display label chosen in the checkpoints dropdown.

    Returns:
        Tuple of (bundle, status_markdown, model_info_dict, num_steps):
        `bundle` is stored in gr.State and later consumed by
        generate_from_ui(); `num_steps` seeds the Inference Steps slider.

    Raises:
        gr.Error: if the label no longer maps to a checkpoint on disk.
    """
    mapping = checkpoint_map()
    if checkpoint_label not in mapping:
        raise gr.Error("Selected checkpoint was not found. Refresh the dropdown.")

    ckpt_path = mapping[checkpoint_label]["path"]
    cfg, device, experiment = build_runtime_cfg(ckpt_path)
    # load_model returns a possibly-updated cfg alongside the model.
    model, cfg = load_model(ckpt_path, cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)

    # Everything generation needs later, kept in a gr.State.
    bundle = {
        "ckpt_path": ckpt_path,
        "experiment": experiment,
        "device": str(device),
        "cfg": cfg,
        "model": model,
        "src_tok": src_tok,
        "tgt_tok": tgt_tok,
    }

    # Read-only summary rendered in the "Loaded Model Info" JSON panel.
    model_info = {
        "checkpoint": ckpt_path,
        "experiment": experiment,
        "model_type": cfg["model_type"],
        "include_negatives": cfg["data"]["include_negative_examples"],
        "device": str(device),
        "max_seq_len": cfg["model"]["max_seq_len"],
        "diffusion_steps": cfg["model"]["diffusion_steps"],
        "d_model": cfg["model"]["d_model"],
        "n_layers": cfg["model"]["n_layers"],
        "n_heads": cfg["model"]["n_heads"],
    }
    status = f"Loaded `{experiment}` on `{device}`."
    return bundle, status, model_info, cfg["inference"]["num_steps"]
121
+
122
+
123
def apply_preset(preset_name):
    """Map a preset name to slider values.

    Returns a (temperature, top_k, repetition_penalty, diversity_penalty,
    num_steps) tuple; unknown names resolve to the "Balanced" settings.
    """
    balanced = (0.70, 40, 1.20, 0.0, 64)
    presets = {
        "Manual": (0.70, 40, 1.20, 0.0, 64),
        "Literal": (0.60, 20, 1.25, 0.0, 64),
        "Balanced": balanced,
        "Creative": (0.85, 80, 1.20, 0.2, 64),
    }
    return presets.get(preset_name, balanced)
131
+
132
+
133
def task_notes_md():
    # Static markdown rendered inside the "Task Details" accordion.
    # Pure constant text — no runtime data is interpolated.
    return """
### Task Notes

**Task 1: KV Cache**
- Benchmark encoder caching vs standard generation.
- Best for engineering evaluation, not language quality evaluation.

**Task 2: Attention + Drift**
- Shows internal attention maps and output stabilization over diffusion steps.
- Useful for diagnostics and mentor discussion of model behavior.

**Task 3: Concept Vectors**
- Experimental PCA steering over decoder hidden states.
- Current outputs are exploratory, not strong semantic evidence yet.

**Task 4: Step Ablation**
- Requires retraining separate checkpoints for each diffusion step count.
- Use this UI for generation only; ablation analysis runs from `analysis/run_analysis.py`.

**Task 5: Quality Guidance**
- Advanced experimental feature in the analysis pipeline.
- Not exposed in this UI because the current evidence is still under validation.
    """
157
+
158
+
159
def save_generation(experiment, record):
    """Append one generation record to today's per-experiment JSON log.

    The log lives at ``RESULTS_DIR/<experiment>_ui_<YYYYMMDD>.json`` and
    holds a JSON array of record dicts.

    Fix: the original crashed the UI with JSONDecodeError when the log
    file was truncated/corrupt (e.g. an interrupted write). A corrupt or
    non-list log is now replaced with a fresh array instead — acceptable
    for these convenience logs.

    Args:
        experiment: experiment name used in the log filename.
        record: JSON-serializable dict describing one generation.

    Returns:
        Path of the log file that was written.
    """
    ts = datetime.now().strftime("%Y%m%d")
    path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
    existing = []
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            # Guard against a log whose top-level value is not an array.
            if isinstance(loaded, list):
                existing = loaded
        except (json.JSONDecodeError, OSError):
            # Corrupt/unreadable log: start fresh rather than breaking generation.
            existing = []
    existing.append(record)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(existing, f, ensure_ascii=False, indent=2)
    return path
170
+
171
+
172
def clean_generated_text(text: str, max_consecutive: int = 2, max_occurrence_ratio: float = 0.15) -> str:
    """
    Lightweight cleanup for repetitive diffusion outputs.
    Keeps Sanskrit tokens but trims pathological token loops.
    """
    # Collapse all whitespace runs to single spaces.
    text = " ".join(text.split())
    if not text:
        return text

    # Pass 1: cap consecutive repeats of the same token at max_consecutive.
    deduped = []
    last_tok = None
    streak = 0
    for tok in text.split():
        streak = streak + 1 if tok == last_tok else 1
        last_tok = tok
        if streak <= max_consecutive:
            deduped.append(tok)

    # Pass 2: cap each token's total count — collapse-mode outputs tend to
    # repeat one token everywhere. The ceiling is never below 3.
    if deduped:
        ceiling = max(3, int(len(deduped) * max_occurrence_ratio))
        seen = {}
        kept = []
        for tok in deduped:
            seen[tok] = seen.get(tok, 0) + 1
            if seen[tok] <= ceiling:
                kept.append(tok)
        deduped = kept

    # Reattach danda punctuation to the preceding word, then re-normalize.
    result = " ".join(deduped).replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(result.split())
212
+
213
+
214
def generate_from_ui(
    model_bundle,
    input_text,
    temperature,
    top_k,
    repetition_penalty,
    diversity_penalty,
    num_steps,
    clean_output,
):
    """Run one generation with the loaded model and UI slider settings.

    Args:
        model_bundle: dict produced by load_selected_model() (gr.State).
        input_text: Roman/IAST source text from the textbox.
        temperature, top_k, repetition_penalty, diversity_penalty,
        num_steps: sampling settings from the sliders.
        clean_output: if truthy, post-process with clean_generated_text().

    Returns:
        (output_text, status_markdown, record_dict) — the record is also
        appended to the per-experiment JSON log via save_generation().

    Raises:
        gr.Error: when no model is loaded or the input is blank.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    if not input_text.strip():
        raise gr.Error("Enter input text first.")

    # Copy so slider overrides never mutate the bundle's stored cfg.
    cfg = copy.deepcopy(model_bundle["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)

    src_tok = model_bundle["src_tok"]
    tgt_tok = model_bundle["tgt_tok"]
    device = torch.device(model_bundle["device"])

    # Batch of one: [1, src_len] token ids on the model's device.
    input_ids = torch.tensor(
        [src_tok.encode(input_text.strip())],
        dtype=torch.long,
        device=device,
    )
    out = run_inference(model_bundle["model"], input_ids, cfg)
    # Drop ids <= 4 before decoding — presumably the special tokens
    # (PAD/BOS/EOS/UNK/MASK); TODO confirm against the tokenizer.
    clean = [x for x in out[0].tolist() if x > 4]
    raw_output_text = tgt_tok.decode(clean).strip()
    output_text = clean_generated_text(raw_output_text) if clean_output else raw_output_text
    if not output_text:
        output_text = "(empty output)"

    # Full provenance record: both raw and cleaned text plus all settings.
    record = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "experiment": model_bundle["experiment"],
        "checkpoint": model_bundle["ckpt_path"],
        "input_text": input_text,
        "raw_output_text": raw_output_text,
        "output_text": output_text,
        "clean_output": bool(clean_output),
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
    }
    log_path = save_generation(model_bundle["experiment"], record)
    status = f"Generated with `{model_bundle['experiment']}`. Saved to `{log_path}`."
    return output_text, status, record
269
+
270
+
271
# ── Gradio UI ─────────────────────────────────────────────────────────
# Declarative Blocks app: checkpoint selection/loading, generation
# controls with presets, quick examples, and event wiring. Built at
# import time; launched from the __main__ guard below.
with gr.Blocks(title="Sanskrit D3PM Studio") as demo:
    # Holds the loaded-model bundle dict from load_selected_model().
    model_state = gr.State(None)

    gr.Markdown(
        """
        # Sanskrit D3PM Studio

        Load any available checkpoint, generate Devanagari output from Roman/IAST Sanskrit,
        and inspect the settings used for evaluation or demos.
        """
    )

    # Row 1: checkpoint picker plus refresh/load buttons.
    with gr.Row():
        with gr.Column(scale=2):
            checkpoint_dropdown = gr.Dropdown(
                label="Available Checkpoints",
                choices=list(checkpoint_map().keys()),
                value=default_checkpoint_label(),
                interactive=True,
            )
        with gr.Column(scale=1):
            refresh_btn = gr.Button("Refresh List")
            load_btn = gr.Button("Load Model", variant="primary")

    load_status = gr.Markdown("Select a checkpoint and load it.")
    model_info = gr.JSON(label="Loaded Model Info")

    # Row 2: input/output textboxes on the left, sampling controls on the right.
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Input Text (Roman / IAST Sanskrit)",
                placeholder="dharmo rakṣati rakṣitaḥ",
                lines=4,
            )
            output_text = gr.Textbox(
                label="Generated Output (Devanagari)",
                lines=6,
                interactive=False,
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            # Preset radio rewrites all five sliders via apply_preset().
            preset = gr.Radio(
                ["Manual", "Literal", "Balanced", "Creative"],
                value="Balanced",
                label="Inference Preset",
            )
            temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
            top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
            repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
            diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
            num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
            clean_output = gr.Checkbox(value=True, label="Clean Output (dedupe loops)")

    run_status = gr.Markdown("")
    run_record = gr.JSON(label="Last Generation Metadata")

    with gr.Accordion("Task Details and Evaluation Notes", open=False):
        task_notes = gr.Markdown(task_notes_md())

    # Clicking an example fills the input textbox.
    gr.Examples(
        examples=[
            ["dharmo rakṣati rakṣitaḥ"],
            ["satyameva jayate"],
            ["ahaṃ brahmāsmi"],
            ["yatra nāryastu pūjyante"],
        ],
        inputs=[input_text],
        label="Quick Examples",
    )

    def refresh_checkpoints():
        # Re-scan disk and rebuild the dropdown; default to the first hit.
        choices = list(checkpoint_map().keys())
        value = choices[0] if choices else None
        return gr.Dropdown(choices=choices, value=value)

    # Event wiring.
    refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown])
    # Loading a model also resets the steps slider to the checkpoint's default.
    load_btn.click(
        fn=load_selected_model,
        inputs=[checkpoint_dropdown],
        outputs=[model_state, load_status, model_info, num_steps],
    )
    preset.change(
        fn=apply_preset,
        inputs=[preset],
        outputs=[temperature, top_k, repetition_penalty, diversity_penalty, num_steps],
    )
    generate_btn.click(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
    # Pressing Enter in the input box triggers the same generation path.
    input_text.submit(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
385
+
386
+
387
if __name__ == "__main__":
    # Honor an externally configured port (e.g. hosting platforms);
    # otherwise let Gradio pick its default.
    env_port = os.environ.get("GRADIO_SERVER_PORT")
    port = int(env_port) if env_port is not None else None
    demo.launch(server_name="127.0.0.1", server_port=port, share=False)
attention_viz.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/attention_viz.py
3
+ ==========================
4
+ Task 2: Attention weight capture and visualization across diffusion steps.
5
+
6
+ How it works (no retraining needed):
7
+ MultiHeadAttention now has two attributes:
8
+ - capture_weights: bool — set True to start storing weights
9
+ - last_attn_weights: Tensor — [B, n_heads, Lq, Lk], updated each forward call
10
+
11
+ AttentionCapture:
12
+ - Sets capture_weights=True on all cross-attention layers
13
+ - Hooks into generate_cached() to record weights at every diffusion step
14
+ - Returns a dict: {t_val: [layer_0_weights, layer_1_weights, ...]}
15
+
16
+ Visualization:
17
+ - plot_attn_heatmap(): shows src→tgt alignment at a single step
18
+ - plot_attn_evolution(): shows how one src→tgt pair evolves over T steps
19
+ - plot_all_layers(): grid of heatmaps per layer at a given step
20
+
21
+ Usage:
22
+ from analysis.attention_viz import AttentionCapture, plot_attn_heatmap
23
+
24
+ capturer = AttentionCapture(model)
+ weights = capturer.capture(src_ids, capture_every=10)
+ plot_attn_heatmap(weights, t_val=0, layer=0, src_tokens=..., tgt_tokens=...)
27
+ """
28
+
29
+ import torch
30
+ import numpy as np
31
+ import os
32
+ from typing import List, Dict, Optional
33
+
34
+
35
+ # ── Attention capture ─────────────────────────────────────────────────
36
+
37
class AttentionCapture:
    """
    Captures cross-attention weights from all decoder layers at every
    diffusion step during cached generation.

    Works by:
        1. Setting capture_weights=True on each DecoderBlock.cross_attn
        2. Running the cached denoising loop (encoder runs once via KV cache)
        3. After each denoising step, reading last_attn_weights from each layer
        4. Storing as {t_val: list_of_layer_weights}

    Zero retraining required — uses the flag added to MultiHeadAttention.
    """

    def __init__(self, model):
        """
        Args:
            model : SanskritModel wrapper (must be D3PMCrossAttention)

        Raises:
            ValueError: if the wrapped model exposes no cross-attention layers.
        """
        self.model = model
        self.inner = model.model  # D3PMCrossAttention
        self._cross_attns = []

        # Collect all cross-attention modules from decoder blocks.
        if hasattr(self.inner, 'decoder_blocks'):
            for block in self.inner.decoder_blocks:
                if hasattr(block, 'cross_attn'):
                    self._cross_attns.append(block.cross_attn)

        if not self._cross_attns:
            raise ValueError(
                "No cross-attention layers found. "
                "AttentionCapture only works with D3PMCrossAttention."
            )

        print(f"AttentionCapture: found {len(self._cross_attns)} cross-attention layers.")

    def _enable(self):
        """Turn on weight capture for all cross-attention layers."""
        for ca in self._cross_attns:
            ca.capture_weights = True

    def _disable(self):
        """Turn off weight capture (restores zero overhead)."""
        for ca in self._cross_attns:
            ca.capture_weights = False
            ca.last_attn_weights = None

    def _read_weights(self) -> List[np.ndarray]:
        """
        Read current last_attn_weights from all layers.
        Each stored tensor is [B, n_heads, Lq, Lk]; heads are averaged to
        produce one [B, Lq, Lk] array per layer.
        """
        weights = []
        for ca in self._cross_attns:
            if ca.last_attn_weights is not None:
                # Average over attention heads → [B, Lq, Lk].
                w = ca.last_attn_weights.float().mean(dim=1)
                # FIX: move to host memory first — Tensor.numpy() raises a
                # TypeError on CUDA/MPS tensors, so capture previously failed
                # whenever the model ran off-CPU.
                weights.append(w.cpu().numpy())
        return weights

    @torch.no_grad()
    def capture(
        self,
        src: torch.Tensor,
        capture_every: int = 10,
    ) -> Dict[int, List[np.ndarray]]:
        """
        Run full generation while capturing attention every `capture_every` steps.

        Args:
            src           : [1, src_len] or [B, src_len] IAST token ids
            capture_every : capture weights every N steps (default 10)
                            Use 1 to capture every step (slow, high memory).

        Returns:
            step_weights : dict mapping t_val → list of [B, Lq, Lk] arrays,
                           one array per decoder layer; keys are the t values
                           T-1, T-1-N, ..., plus 0 (always captured last).
        """
        # Hoisted out of the per-step loop (was re-imported every iteration).
        import torch.nn.functional as F

        if src.dim() == 1:
            src = src.unsqueeze(0)

        inner = self.inner
        T = inner.scheduler.num_timesteps
        device = src.device

        # KV cache: encode source once.
        memory, src_pad_mask = inner.encode_source(src)

        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id

        # Start fully masked; refine the estimate over T reverse steps.
        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None

        step_weights: Dict[int, List[np.ndarray]] = {}

        self._enable()
        try:
            inner.eval()
            for t_val in range(T - 1, -1, -1):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (t_val == 0)

                logits, _ = inner.forward_cached(
                    memory, src_pad_mask, x0_est, t,
                    x0_hint=hint, inference_mode=True,
                )

                # Capture at this step if scheduled, or always at the final step.
                if (T - 1 - t_val) % capture_every == 0 or is_last:
                    step_weights[t_val] = self._read_weights()

                # Final step is greedy; intermediate steps sample.
                probs = F.softmax(logits / 0.8, dim=-1)
                x0_est = torch.argmax(probs, dim=-1) if is_last else \
                    _multinomial_sample(probs)
                hint = x0_est

        finally:
            self._disable()  # always restore — even if exception raised

        print(f"Captured attention at {len(step_weights)} steps "
              f"({len(self._cross_attns)} layers each).")
        return step_weights
170
+
171
+
172
+ def _multinomial_sample(probs: torch.Tensor) -> torch.Tensor:
173
+ B, L, V = probs.shape
174
+ flat = probs.view(B * L, V).clamp(min=1e-9)
175
+ flat = flat / flat.sum(dim=-1, keepdim=True)
176
+ return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
177
+
178
+
179
+ # ── Visualization ─────────────────────────────────────────────────────
180
+
181
def plot_attn_heatmap(
    step_weights: Dict[int, List[np.ndarray]],
    t_val: int,
    layer: int,
    src_tokens: List[str],
    tgt_tokens: List[str],
    sample_idx: int = 0,
    save_path: Optional[str] = None,
    title: Optional[str] = None,
):
    """
    Plot cross-attention heatmap for a single step and layer.

    X-axis = source (IAST) tokens
    Y-axis = target (Devanagari) positions
    Color = attention weight (brighter = stronger attention)

    Args:
        step_weights : output of AttentionCapture.capture()
        t_val        : which diffusion step to visualize
        layer        : which decoder layer (0 = first, -1 = last)
        src_tokens   : list of IAST token strings for x-axis labels
        tgt_tokens   : list of Devanagari token strings for y-axis labels
        sample_idx   : which batch item to visualize (default 0)
        save_path    : if given, save figure to this path
        title        : custom plot title

    Raises:
        ValueError: if t_val was not among the captured steps.
    """
    # Removed unused `matplotlib.ticker` import from the original version.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    if t_val not in step_weights:
        available = sorted(step_weights.keys())
        raise ValueError(
            f"t_val={t_val} not in captured steps. "
            f"Available: {available[:5]}...{available[-5:]}"
        )

    layers = step_weights[t_val]
    weights = layers[layer][sample_idx]  # [Lq, Lk]

    # Trim to actual token lengths.
    n_src = min(len(src_tokens), weights.shape[1])
    n_tgt = min(len(tgt_tokens), weights.shape[0])
    weights = weights[:n_tgt, :n_src]

    # Figure grows with token counts so labels stay legible.
    fig, ax = plt.subplots(figsize=(max(8, n_src * 0.4), max(6, n_tgt * 0.35)))
    im = ax.imshow(weights, aspect='auto', cmap='YlOrRd', interpolation='nearest')

    ax.set_xticks(range(n_src))
    ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=9)
    ax.set_yticks(range(n_tgt))
    ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=9)

    ax.set_xlabel("Source (IAST)", fontsize=11)
    ax.set_ylabel("Target position (Devanagari)", fontsize=11)

    plot_title = title or f"Cross-Attention | t={t_val} | Layer {layer}"
    ax.set_title(plot_title, fontsize=12, pad=10)

    plt.colorbar(im, ax=ax, label="Attention weight")
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
254
+
255
+
256
def plot_attn_evolution(
    step_weights: Dict[int, List[np.ndarray]],
    src_token_idx: int,
    tgt_token_idx: int,
    layer: int = -1,
    sample_idx: int = 0,
    src_token_str: str = "",
    tgt_token_str: str = "",
    save_path: Optional[str] = None,
):
    """
    Plot how attention between one specific src↔tgt token pair evolves
    across all captured diffusion steps (T → 0).

    Reveals whether a token pair is 'locked' (stable from early steps)
    or 'flexible' (weight fluctuates until final steps).

    Args:
        step_weights  : output of AttentionCapture.capture()
        src_token_idx : index of source token to track
        tgt_token_idx : index of target position to track
        layer         : decoder layer index
        sample_idx    : batch item
        src_token_str : string label for the source token (plot title)
        tgt_token_str : string label for the target token (plot title)
        save_path     : if given, save figure to this path
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    t_vals = sorted(step_weights.keys(), reverse=True)  # T-1 → 0
    series = []
    for step in t_vals:
        w = step_weights[step][layer][sample_idx]  # [Lq, Lk]
        in_range = tgt_token_idx < w.shape[0] and src_token_idx < w.shape[1]
        # Out-of-range indices contribute 0 instead of raising.
        series.append(w[tgt_token_idx, src_token_idx] if in_range else 0.0)

    fig, ax = plt.subplots(figsize=(12, 4))
    positions = range(len(t_vals))
    ax.plot(positions, series, linewidth=1.5, color='steelblue')
    ax.fill_between(positions, series, alpha=0.2, color='steelblue')

    # Label roughly every 10th tick; leave the rest blank.
    stride = max(1, len(t_vals) // 10)
    labels = [str(step) if idx % stride == 0 else ""
              for idx, step in enumerate(t_vals)]
    ax.set_xticks(range(len(t_vals)))
    ax.set_xticklabels(labels, fontsize=8)
    ax.set_xlabel("Diffusion step (T → 0)", fontsize=11)
    ax.set_ylabel("Attention weight", fontsize=11)

    pair_str = f"src[{src_token_idx}]={src_token_str!r} → tgt[{tgt_token_idx}]={tgt_token_str!r}"
    ax.set_title(f"Attention evolution | {pair_str} | Layer {layer}", fontsize=11)
    ax.set_xlim(0, len(t_vals) - 1)
    ax.set_ylim(0, None)
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
325
+
326
+
327
def plot_all_layers(
    step_weights: Dict[int, List[np.ndarray]],
    t_val: int,
    src_tokens: List[str],
    tgt_tokens: List[str],
    sample_idx: int = 0,
    save_path: Optional[str] = None,
):
    """
    Plot attention heatmaps for ALL decoder layers at a single diffusion step.
    Shows how different layers specialize their attention patterns.

    Args:
        step_weights : output of AttentionCapture.capture()
        t_val        : diffusion step to visualize (must be a captured key;
                       raises KeyError otherwise)
        src_tokens   : IAST token strings for x-axis labels
        tgt_tokens   : Devanagari token strings for y-axis labels
        sample_idx   : which batch item to visualize (default 0)
        save_path    : if given, save figure to this path; else show it
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    layers = step_weights[t_val]  # one [B, Lq, Lk] array per decoder layer
    n_layers = len(layers)
    # Grid layout: up to 4 columns, as many rows as needed.
    n_cols = min(4, n_layers)
    n_rows = (n_layers + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(n_cols * 5, n_rows * 4))
    # For a single layer plt.subplots returns a bare Axes; normalize to a list.
    axes = np.array(axes).flatten() if n_layers > 1 else [axes]

    # Trim the weight matrices to the actual token counts.
    n_src = min(len(src_tokens), layers[0][sample_idx].shape[1])
    n_tgt = min(len(tgt_tokens), layers[0][sample_idx].shape[0])

    for i, (ax, layer_w) in enumerate(zip(axes, layers)):
        w = layer_w[sample_idx][:n_tgt, :n_src]
        # Per-layer color scale (vmax from this layer's own maximum).
        im = ax.imshow(w, aspect='auto', cmap='YlOrRd', interpolation='nearest',
                       vmin=0, vmax=w.max())
        ax.set_title(f"Layer {i}", fontsize=10)
        ax.set_xticks(range(n_src))
        ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=7)
        ax.set_yticks(range(n_tgt))
        ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=7)

    # Hide any unused grid cells.
    for ax in axes[n_layers:]:
        ax.set_visible(False)

    fig.suptitle(f"All layers at t={t_val}", fontsize=13, y=1.02)
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1baa03568c2bed42621da115f6e6971411b59cc9dec6b58cf8f2ed87ba2e770
3
+ size 1077681643
concept_vectors.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/concept_vectors.py
3
+ ============================
4
+ Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity
5
+
6
+ No retraining required. Uses decoder hidden states already computed
7
+ during generate_cached() — stored in model.model._last_hidden after
8
+ each forward_cached() call.
9
+
10
+ Steps:
11
+ 1. Collect hidden states from N examples at a fixed diffusion step
12
+ 2. Pool sequence dimension → [N, d_model] representation per example
13
+ 3. PCA → find principal directions in concept space
14
+ 4. Identify "diversity direction" (PC that best separates short/long outputs)
15
+ 5. Steer: at inference, shift hidden states along diversity direction
16
+ before the output head projection
17
+ 6. Generate at 5 points along the direction, measure output diversity
18
+
19
+ Key insight: the diversity direction is found purely from model outputs
20
+ (no human annotation needed). We use output length as a proxy:
21
+ short output → low diversity (model collapsed to simple token)
22
+ long output → high diversity (model exploring more of the space)
23
+ """
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import numpy as np
29
+ from typing import List, Dict, Optional, Tuple
30
+
31
+
32
+ # ── Hidden state collection ───────────────────────────────────────────
33
+
34
@torch.no_grad()
def collect_hidden_states(
    model,
    src_list: List[torch.Tensor],
    t_capture: int = 0,
    temperature: float = 0.8,
    top_k: int = 40,
    max_samples: int = 1000,
    tgt_tok=None,
) -> Tuple[np.ndarray, List[str]]:
    """
    Run the cached denoising loop on a list of source tensors, collecting
    the decoder hidden state at timestep t_capture for each sample.

    Fix: the original computed the output token ids and then discarded
    them, so `output_texts` was always returned empty despite the
    documented contract. Outputs are now recorded, one per captured
    hidden state, keeping the two return values aligned.

    Args:
        model      : SanskritModel (D3PMCrossAttention)
        src_list   : list of [1, src_len] tensors, one per sample
        t_capture  : which diffusion step to capture hidden states at
                     0 = final (clean), T-1 = noisy start
        temperature: sampling temperature
        top_k      : top-k filter (0 disables filtering)
        max_samples: cap at this many samples
        tgt_tok    : optional target tokenizer (new, backward-compatible);
                     when given, outputs are decoded to text, otherwise
                     space-joined token ids are recorded

    Returns:
        hidden_matrix : np.ndarray [N, d_model] — pooled hidden states
        output_texts  : list of N output strings, aligned with hidden_matrix
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device

    hidden_list = []
    output_list = []

    n = min(len(src_list), max_samples)
    print(f"Collecting hidden states from {n} examples at t={t_capture}...")

    for i, src in enumerate(src_list[:n]):
        if i % 100 == 0:
            print(f"  {i}/{n}")

        if src.dim() == 1:
            src = src.unsqueeze(0)
        src = src.to(device)

        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id

        # KV cache: encode the source once per sample.
        memory, src_pad_mask = inner.encode_source(src)

        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None
        captured_hidden = None

        for t_val in range(T - 1, -1, -1):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            is_last = (t_val == 0)

            logits, _ = inner.forward_cached(
                memory, src_pad_mask, x0_est, t,
                x0_hint=hint, inference_mode=True,
            )

            # Capture the decoder hidden state at the requested step.
            if t_val == t_capture and hasattr(inner, '_last_hidden'):
                captured_hidden = inner._last_hidden.detach().cpu()

            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                V = logits.shape[-1]
                if top_k < V:
                    vals, _ = torch.topk(logits, top_k, dim=-1)
                    logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = F.softmax(logits, dim=-1)
            # Final step is greedy; intermediate steps sample.
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
            hint = x0_est

        # Pool hidden state over non-PAD positions → [d_model].
        if captured_hidden is not None:
            # NOTE(review): ids <= 1 treated as padding here, mirroring the
            # original `> 1` filter — confirm against the tokenizer's ids.
            non_pad = (x0_est[0] > 1).cpu()  # [tgt_len] bool
            if non_pad.sum() > 0:
                h = captured_hidden[0][non_pad].mean(dim=0)  # [d_model]
            else:
                h = captured_hidden[0].mean(dim=0)
            hidden_list.append(h.numpy())

            # Record the output for this sample (FIX — previously dropped).
            # ids <= 4 are filtered as special tokens, as elsewhere in the repo.
            ids = [x for x in x0_est[0].tolist() if x > 4]
            if tgt_tok is not None:
                output_list.append(tgt_tok.decode(ids).strip())
            else:
                output_list.append(" ".join(str(x) for x in ids))

    print(f"Collected {len(hidden_list)} hidden states.")
    return np.stack(hidden_list), output_list
+ return np.stack(hidden_list), output_list
127
+
128
+
129
+ # ── PCA on hidden states ──────────────────────────────────────────────
130
+
131
def fit_pca(
    hidden_matrix: np.ndarray,
    n_components: int = 50,
) -> object:
    """
    Fit a PCA model on a matrix of pooled hidden states.

    Args:
        hidden_matrix : [N, d_model] array of hidden states
        n_components  : requested component count, clamped so it never
                        exceeds what the data can support

    Returns:
        the fitted sklearn PCA object
    """
    from sklearn.decomposition import PCA

    n_samples, n_features = hidden_matrix.shape
    n_comp = min(n_components, n_samples - 1, n_features)

    model = PCA(n_components=n_comp)
    model.fit(hidden_matrix)

    pct = model.explained_variance_ratio_.sum() * 100
    print(f"PCA fit: {n_comp} components explain "
          f"{pct:.1f}% of variance.")
    return model
152
+
153
+
154
def find_diversity_direction(
    hidden_matrix: np.ndarray,
    output_lengths: List[int],
    pca: object,
) -> Tuple[np.ndarray, int, float]:
    """
    Find the PCA direction that best correlates with output diversity
    (measured by output length as a proxy).

    Projects hidden states into PCA space, then picks the principal
    component whose scores have the highest |Spearman correlation| with
    the output lengths.

    Args:
        hidden_matrix  : [N, d_model] pooled hidden states
        output_lengths : N output lengths (diversity proxy)
        pca            : fitted PCA-like object (``transform`` / ``components_``)

    Returns:
        direction : np.ndarray [d_model] — unit diversity direction
        best_pc   : index of the selected principal component
        best_corr : |Spearman r| of that component with output length
    """
    from scipy.stats import spearmanr

    projected = pca.transform(hidden_matrix)  # [N, n_components]
    lengths = np.array(output_lengths)

    correlations = []
    for pc_idx in range(projected.shape[1]):
        r, _ = spearmanr(projected[:, pc_idx], lengths)
        # BUG FIX: spearmanr returns NaN for constant columns, and NaN wins
        # np.argmax below — treat it as zero correlation instead.
        correlations.append(0.0 if np.isnan(r) else abs(r))

    best_pc = int(np.argmax(correlations))
    print(f"Diversity direction: PC {best_pc} "
          f"(|r|={correlations[best_pc]:.3f} with output length)")

    # Map back to original d_model space and normalise to unit length.
    direction = pca.components_[best_pc]  # [d_model]
    direction = direction / (np.linalg.norm(direction) + 1e-8)
    # BUG FIX: the annotation previously claimed ``-> np.ndarray`` while the
    # function has always returned this 3-tuple; the annotation now matches.
    return direction, best_pc, correlations[best_pc]
187
+
188
+
189
+ # ── Steered generation ────────────────────────────────────────────────
190
+
191
@torch.no_grad()
def generate_steered(
    model,
    src: torch.Tensor,
    direction: np.ndarray,
    alpha: float = 0.0,
    temperature: float = 0.8,
    top_k: int = 40,
) -> torch.Tensor:
    """
    Generate output while steering hidden states along a diversity direction.

    At each diffusion step, after the decoder stack runs, the hidden state is
    shifted by ``alpha * direction`` before projection to logits:

        alpha > 0 → push toward high-diversity output
        alpha < 0 → push toward low-diversity output
        alpha = 0 → standard generation (no steering)

    NOTE(review): this re-implements the preprocessing done inside
    ``inner.forward_cached`` (embedding, time MLP, hint gate, decoder blocks,
    head) so the hidden state can be intercepted. It must stay in sync with
    the model's own forward path — verify against d3pm_model_cross_attention.

    Args:
        model     : SanskritModel wrapper; the D3PM net lives at ``model.model``
        src       : [1, src_len] (or [src_len]) source token ids
        direction : [d_model] diversity direction from find_diversity_direction()
        alpha     : steering strength
        temperature / top_k: sampling parameters

    Returns:
        x0_est : [B, tgt_len] generated token ids (B=1 for a single source)
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device

    # Accept both [src_len] and [1, src_len] inputs.
    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    dir_tensor = torch.tensor(direction, dtype=torch.float32, device=device)

    # Encode once; all reverse-diffusion steps reuse this memory.
    memory, src_pad_mask = inner.encode_source(src)
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    inner.eval()
    # Reverse diffusion: from the noisiest step (T-1) down to the clean step (0).
    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = (t_val == 0)

        # Standard forward_cached but we intercept hidden states
        PAD = 1  # NOTE(review): assigned but unused below — candidate for removal
        tgt_pad_mask = None  # inference_mode: target padding mask suppressed

        # Re-noise the current estimate except at t=0, where x0_est is used
        # as-is (mirrors the model's "q_sample skipped at t=0" fix).
        _, x_t_ids = inner.forward_process.q_sample(x0_est, t) if t_val > 0 else \
                     (None, x0_est)
        x = inner.tgt_embed(x_t_ids)
        # Normalised timestep → time embedding added to every position.
        t_norm = t.float() / inner.scheduler.num_timesteps
        t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
        x = x + t_emb.unsqueeze(1)

        # Gated injection of the previous step's estimate (self-conditioning).
        if hint is not None:
            hint_emb = inner.tgt_embed(hint)
            gate = inner.hint_gate(x)
            x = x + gate * hint_emb

        for block in inner.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)

        # ── STEERING: shift hidden states along diversity direction ───
        if alpha != 0.0:
            x = x + alpha * dir_tensor.unsqueeze(0).unsqueeze(0)

        # Project to logits using the head
        logits = inner.head(x)

        logits = logits / max(temperature, 1e-8)
        if top_k > 0:
            V = logits.shape[-1]
            if top_k < V:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

        probs = F.softmax(logits, dim=-1)
        # Greedy at the final step; stochastic sampling otherwise.
        x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est

    return x0_est
281
+
282
+
283
def generate_diversity_spectrum(
    model,
    src: torch.Tensor,
    direction: np.ndarray,
    tgt_tokenizer,
    alphas: Optional[List[float]] = None,
    temperature: float = 0.8,
    top_k: int = 40,
) -> Dict[float, str]:
    """
    Generate outputs at several points along the diversity direction.

    Args:
        model        : SanskritModel wrapper (inner D3PM at ``model.model``)
        src          : [1, src_len] source token ids
        direction    : [d_model] steering direction
        tgt_tokenizer: tokenizer with ``decode`` for the target script
        alphas       : steering strengths (negative = low diversity,
                       positive = high); defaults to (-2, -1, 0, 1, 2)
        temperature  : sampling temperature
        top_k        : top-k logit filter

    Returns:
        dict mapping alpha → decoded output string
    """
    # BUG FIX: the previous signature used a mutable list as a default
    # argument (shared across calls); use None plus an immutable fallback.
    if alphas is None:
        alphas = (-2.0, -1.0, 0.0, 1.0, 2.0)

    results = {}
    for alpha in alphas:
        out_ids = generate_steered(model, src, direction, alpha, temperature, top_k)
        ids = [x for x in out_ids[0].tolist() if x > 4]  # drop special tokens (ids 0–4)
        text = tgt_tokenizer.decode(ids).strip()
        results[alpha] = text
        print(f"  alpha={alpha:+.1f} → {text}")
    return results
309
+
310
+
311
def plot_pca_space(
    hidden_matrix: np.ndarray,
    output_lengths: List[int],
    pca: object,
    diversity_pc: int,
    save_path: Optional[str] = None,
):
    """
    Visualise the hidden-state concept space.

    Left panel: examples scattered in PC0/PC1, coloured by output length.
    Right panel: cumulative explained variance with the diversity PC marked.
    Saves to ``save_path`` when given, otherwise shows the figure.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib.")
        return

    coords = pca.transform(hidden_matrix)  # [N, n_pc]
    length_arr = np.array(output_lengths)
    var_ratio = pca.explained_variance_ratio_

    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: PC0 vs PC1 coloured by output length
    scatter = left_ax.scatter(coords[:, 0], coords[:, 1],
                              c=length_arr, cmap='viridis', alpha=0.6, s=15)
    plt.colorbar(scatter, ax=left_ax, label="Output length (chars)")
    left_ax.set_xlabel(f"PC0 ({var_ratio[0]*100:.1f}%)", fontsize=10)
    left_ax.set_ylabel(f"PC1 ({var_ratio[1]*100:.1f}%)", fontsize=10)
    left_ax.set_title("Concept space (PC0 vs PC1)", fontsize=11)

    # Right: cumulative explained variance, diversity PC highlighted
    cumulative = np.cumsum(var_ratio) * 100
    right_ax.plot(range(1, len(cumulative) + 1), cumulative,
                  linewidth=1.5, color='steelblue')
    right_ax.axvline(diversity_pc, color='coral', linestyle='--',
                     label=f"Diversity PC={diversity_pc}")
    right_ax.set_xlabel("Number of PCs", fontsize=10)
    right_ax.set_ylabel("Cumulative variance (%)", fontsize=10)
    right_ax.set_title("PCA explained variance", fontsize=11)
    right_ax.legend()
    right_ax.set_ylim(0, 102)

    plt.tight_layout()
    if save_path:
        import os
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
362
+
363
+
364
+ def _sample(probs):
365
+ B, L, V = probs.shape
366
+ flat = probs.view(B * L, V).clamp(min=1e-9)
367
+ flat = flat / flat.sum(dim=-1, keepdim=True)
368
+ return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
config_T16.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Ablation config: T=16 diffusion steps.

Key switches are overridable from the shell (MODEL_TYPE, INCLUDE_NEG,
DIFFUSION_STEPS, INFERENCE_NUM_STEPS, TRAIN_DEVICE) so one driver script
can run every ablation. Stale commented-out config variants were removed.
"""
import os

import torch


def _get_env_int(name, default):
    """Read an int environment variable, falling back to ``default``."""
    value = os.environ.get(name)
    return int(value) if value is not None else default


def _get_env_str(name, default):
    """Read a string environment variable, falling back to ``default``."""
    return os.environ.get(name, default)


# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
# NOTE: only the exact string "True" enables negatives ("true"/"1" do not).
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)

CONFIG = {
    "model_type": MODEL,

    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    "diffusion": {
        "mask_token_id": 0,  # [MASK] = ID 0, fixed by tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    "model": {
        "src_vocab_size": 16000,   # Roman/IAST BPE vocab
        "tgt_vocab_size": 16000,   # Devanagari BPE vocab
        "d_model": 1024,
        "n_heads": 8,
        "d_ff": 4096,              # 4 × d_model
        "n_layers": 8,
        "dropout": 0.2,
        "max_seq_len": 80,
        "diffusion_steps": DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    "training": {
        "epochs": 20,
        "batch_size": 32,
        "accum_steps": 2,          # effective batch = 64
        "lr": 7e-5,
        "label_smoothing": 0.1,    # reduces overconfidence
        "patience": 4,             # early stop after 4 non-improving epochs
        "l1_lambda": 1e-7,         # very light L1
        "device": TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    "inference": {
        "num_steps": INFERENCE_STEPS,
        "temperature": 0.7,        # slightly lower = more confident output
        "top_k": 40,
        "repetition_penalty": 1.2,
        "diversity_penalty": 0.5,  # global-mean penalty is conservative
    },
}
config_T32.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Ablation config: T=32 diffusion steps.

Key switches are overridable from the shell (MODEL_TYPE, INCLUDE_NEG,
DIFFUSION_STEPS, INFERENCE_NUM_STEPS, TRAIN_DEVICE) so one driver script
can run every ablation. Stale commented-out config variants were removed.
"""
import os

import torch


def _get_env_int(name, default):
    """Read an int environment variable, falling back to ``default``."""
    value = os.environ.get(name)
    return int(value) if value is not None else default


def _get_env_str(name, default):
    """Read a string environment variable, falling back to ``default``."""
    return os.environ.get(name, default)


# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
# NOTE: only the exact string "True" enables negatives ("true"/"1" do not).
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)

CONFIG = {
    "model_type": MODEL,

    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    "diffusion": {
        "mask_token_id": 0,  # [MASK] = ID 0, fixed by tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    "model": {
        "src_vocab_size": 16000,   # Roman/IAST BPE vocab
        "tgt_vocab_size": 16000,   # Devanagari BPE vocab
        "d_model": 1024,
        "n_heads": 8,
        "d_ff": 4096,              # 4 × d_model
        "n_layers": 8,
        "dropout": 0.2,
        "max_seq_len": 80,
        "diffusion_steps": DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    "training": {
        "epochs": 20,
        "batch_size": 32,
        "accum_steps": 2,          # effective batch = 64
        "lr": 7e-5,
        "label_smoothing": 0.1,    # reduces overconfidence
        "patience": 4,             # early stop after 4 non-improving epochs
        "l1_lambda": 1e-7,         # very light L1
        "device": TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    "inference": {
        "num_steps": INFERENCE_STEPS,
        "temperature": 0.7,        # slightly lower = more confident output
        "top_k": 40,
        "repetition_penalty": 1.2,
        "diversity_penalty": 0.5,  # global-mean penalty is conservative
    },
}
config_T4.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Ablation config: T=4 diffusion steps.

Key switches are overridable from the shell (MODEL_TYPE, INCLUDE_NEG,
DIFFUSION_STEPS, INFERENCE_NUM_STEPS, TRAIN_DEVICE) so one driver script
can run every ablation. Stale commented-out config variants were removed.
"""
import os

import torch


def _get_env_int(name, default):
    """Read an int environment variable, falling back to ``default``."""
    value = os.environ.get(name)
    return int(value) if value is not None else default


def _get_env_str(name, default):
    """Read a string environment variable, falling back to ``default``."""
    return os.environ.get(name, default)


# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
# NOTE: only the exact string "True" enables negatives ("true"/"1" do not).
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)

CONFIG = {
    "model_type": MODEL,

    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    "diffusion": {
        "mask_token_id": 0,  # [MASK] = ID 0, fixed by tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    "model": {
        "src_vocab_size": 16000,   # Roman/IAST BPE vocab
        "tgt_vocab_size": 16000,   # Devanagari BPE vocab
        "d_model": 1024,
        "n_heads": 8,
        "d_ff": 4096,              # 4 × d_model
        "n_layers": 8,
        "dropout": 0.2,
        "max_seq_len": 80,
        "diffusion_steps": DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    "training": {
        "epochs": 20,
        "batch_size": 32,
        "accum_steps": 2,          # effective batch = 64
        "lr": 7e-5,
        "label_smoothing": 0.1,    # reduces overconfidence
        "patience": 4,             # early stop after 4 non-improving epochs
        "l1_lambda": 1e-7,         # very light L1
        "device": TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    "inference": {
        "num_steps": INFERENCE_STEPS,
        "temperature": 0.7,        # slightly lower = more confident output
        "top_k": 40,
        "repetition_penalty": 1.2,
        "diversity_penalty": 0.5,  # global-mean penalty is conservative
    },
}
config_T64.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Ablation config: T=64 diffusion steps.

Key switches are overridable from the shell (MODEL_TYPE, INCLUDE_NEG,
DIFFUSION_STEPS, INFERENCE_NUM_STEPS, TRAIN_DEVICE) so one driver script
can run every ablation. Stale commented-out config variants were removed.
"""
import os

import torch


def _get_env_int(name, default):
    """Read an int environment variable, falling back to ``default``."""
    value = os.environ.get(name)
    return int(value) if value is not None else default


def _get_env_str(name, default):
    """Read a string environment variable, falling back to ``default``."""
    return os.environ.get(name, default)


# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
# NOTE: only the exact string "True" enables negatives ("true"/"1" do not).
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)

CONFIG = {
    "model_type": MODEL,

    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    "diffusion": {
        "mask_token_id": 0,  # [MASK] = ID 0, fixed by tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    "model": {
        "src_vocab_size": 16000,   # Roman/IAST BPE vocab
        "tgt_vocab_size": 16000,   # Devanagari BPE vocab
        "d_model": 1024,
        "n_heads": 8,
        "d_ff": 4096,              # 4 × d_model
        "n_layers": 8,
        "dropout": 0.2,
        "max_seq_len": 80,
        "diffusion_steps": DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    "training": {
        "epochs": 20,
        "batch_size": 32,
        "accum_steps": 2,          # effective batch = 64
        "lr": 7e-5,
        "label_smoothing": 0.1,    # reduces overconfidence
        "patience": 4,             # early stop after 4 non-improving epochs
        "l1_lambda": 1e-7,         # very light L1
        "device": TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    "inference": {
        "num_steps": INFERENCE_STEPS,
        "temperature": 0.7,        # slightly lower = more confident output
        "top_k": 40,
        "repetition_penalty": 1.2,
        "diversity_penalty": 0.5,  # global-mean penalty is conservative
    },
}
config_T8.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Ablation config: T=8 diffusion steps.

Key switches are overridable from the shell (MODEL_TYPE, INCLUDE_NEG,
DIFFUSION_STEPS, INFERENCE_NUM_STEPS, TRAIN_DEVICE) so one driver script
can run every ablation. Stale commented-out config variants were removed.
"""
import os

import torch


def _get_env_int(name, default):
    """Read an int environment variable, falling back to ``default``."""
    value = os.environ.get(name)
    return int(value) if value is not None else default


def _get_env_str(name, default):
    """Read a string environment variable, falling back to ``default``."""
    return os.environ.get(name, default)


# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
# NOTE: only the exact string "True" enables negatives ("true"/"1" do not).
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)

CONFIG = {
    "model_type": MODEL,

    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    "diffusion": {
        "mask_token_id": 0,  # [MASK] = ID 0, fixed by tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    "model": {
        "src_vocab_size": 16000,   # Roman/IAST BPE vocab
        "tgt_vocab_size": 16000,   # Devanagari BPE vocab
        "d_model": 1024,
        "n_heads": 8,
        "d_ff": 4096,              # 4 × d_model
        "n_layers": 8,
        "dropout": 0.2,
        "max_seq_len": 80,
        "diffusion_steps": DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    "training": {
        "epochs": 20,
        "batch_size": 32,
        "accum_steps": 2,          # effective batch = 64
        "lr": 7e-5,
        "label_smoothing": 0.1,    # reduces overconfidence
        "patience": 4,             # early stop after 4 non-improving epochs
        "l1_lambda": 1e-7,         # very light L1
        "device": TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    "inference": {
        "num_steps": INFERENCE_STEPS,
        "temperature": 0.7,        # slightly lower = more confident output
        "top_k": 40,
        "repetition_penalty": 1.2,
        "diversity_penalty": 0.5,  # global-mean penalty is conservative
    },
}
d3pm_model_cross_attention.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ d3pm_model_cross_attention.py — Cross-Script + Generation-Fixed
3
+ =================================================================
4
+ INPUT : quote_text tokens (Roman script, src_vocab_size)
5
+ OUTPUT : quote_devanagari tokens (Devanagari script, tgt_vocab_size)
6
+
7
+ src_embed uses src_vocab_size (Roman BPE)
8
+ tgt_embed uses tgt_vocab_size (Devanagari BPE)
9
+ head outputs tgt_vocab_size (predict Devanagari tokens)
10
+ Weight tying: head <-> tgt_embed only (NOT src_embed)
11
+
12
+ Generation bugs fixed:
13
+ BUG 1 - tgt_pad_mask suppressed during inference
14
+ BUG 2 - q_sample skipped at t=0
15
+ BUG 3 - time embedding before hint_gate
16
+ BUG 4 - diversity penalty uses global mean not var
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusion.scheduler import OptimizedCosineScheduler
24
+ from diffusion.forward_process import AbsorbingForwardProcess
25
+
26
+
27
class SinusoidalPositionalEncoding(nn.Module):
    """Additive fixed sin/cos positional encoding (Vaswani et al. style).

    Precomputes a (1, max_len, d_model) table once and adds its first L
    rows to the input on every forward pass.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        freq_scale = -torch.log(torch.tensor(10000.0)) / d_model
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * freq_scale)
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # buffer (not a parameter): moves with .to(device), excluded from grads
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        # x: (B, L, d_model) — add the first L positional rows
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]
42
+
43
+
44
class SanskritEmbeddings(nn.Module):
    """Token embedding plus fixed sinusoidal positions for one script/vocab."""

    def __init__(self, vocab_size, d_model, max_seq_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = SinusoidalPositionalEncoding(d_model, max_seq_len)
        # alias kept because other modules tie weights via `.token_embedding`
        self.token_embedding = self.token_emb

    def forward(self, tokens):
        embedded = self.token_emb(tokens)
        return self.pos_enc(embedded)
52
+
53
+
54
class MultiHeadAttention(nn.Module):
    """Standard scaled dot-product multi-head attention.

    `mask` is a (B, Lk) boolean pad mask; True positions are excluded
    from attention (filled with -inf before the softmax).
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, batch, length):
        # (B, L, d_model) -> (B, n_heads, L, head_dim)
        return t.view(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        batch, q_len, _ = q.size()
        k_len = k.size(1)
        queries = self._split_heads(self.q_proj(q), batch, q_len)
        keys = self._split_heads(self.k_proj(k), batch, k_len)
        values = self._split_heads(self.v_proj(v), batch, k_len)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # broadcast (B, Lk) -> (B, 1, 1, Lk) over heads and query positions
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, values)
        context = context.transpose(1, 2).contiguous().view(batch, q_len, self.d_model)
        return self.out_proj(context)
79
+
80
+
81
class EncoderBlock(nn.Module):
    """Post-norm transformer encoder layer: self-attention then GELU FFN."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, pad_mask=None):
        attn_out = self.mha(x, x, x, mask=pad_mask)
        x = self.norm1(x + attn_out)
        ff_out = self.ff(x)
        return self.norm2(x + ff_out)
92
+
93
+
94
class DecoderBlock(nn.Module):
    """Post-norm decoder layer: self-attn -> cross-attn over memory -> FFN."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, memory, tgt_pad_mask=None, src_pad_mask=None):
        x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
        # queries come from the decoder; keys/values from encoder memory
        x = self.norm2(x + self.cross_attn(x, memory, memory, mask=src_pad_mask))
        return self.norm3(x + self.ff(x))
108
+
109
+
110
class D3PMCrossAttention(nn.Module):
    """Cross-script D3PM: Roman/IAST source -> Devanagari target.

    The encoder reads the Roman source; the decoder denoises a corrupted
    Devanagari sequence conditioned on encoder memory via cross-attention
    (absorbing/[MASK] diffusion). The output head is weight-tied to
    tgt_embed only — src_embed stays independent (different script/vocab).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']
        d = cfg['model']['d_model']
        nhead = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        drop = cfg['model']['dropout']
        seqlen = cfg['model']['max_seq_len']
        nlayer = cfg['model']['n_layers']
        # fall back to a shared 'vocab_size' for legacy configs
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])

        # Separate embeddings: Roman src, Devanagari tgt
        self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)

        self.scheduler = OptimizedCosineScheduler(cfg)
        self.forward_process = AbsorbingForwardProcess(self.scheduler)

        self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])

        # timestep-conditioning MLP and gated self-conditioning ("hint") path
        self.time_mlp = nn.Sequential(nn.Linear(1, d//4), nn.SiLU(), nn.Linear(d//4, d))
        self.hint_gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())

        # Output head: predict Devanagari tokens, tied to tgt_embed
        self.head = nn.Linear(d, tgt_vocab, bias=False)
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t, x0_hint=None, inference_mode=False):
        """Predict logits over the clean target.

        Note: `tgt` is the CLEAN target during training — q_sample corrupts
        it internally at timestep `t`. Returns (logits, None), where logits
        has shape (B, L, tgt_vocab).
        """
        PAD = 1  # pad token id fixed by the tokenizer
        src_pad_mask = (src == PAD)
        # BUG 1 FIX: no tgt mask during inference
        tgt_pad_mask = None if inference_mode else (tgt == PAD)

        # Encode Roman source
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)

        # BUG 2 FIX: skip q_sample at final step t=0
        if inference_mode and (t == 0).all():
            x_t_ids = tgt
        else:
            _, x_t_ids = self.forward_process.q_sample(tgt, t)

        x = self.tgt_embed(x_t_ids)

        # BUG 3 FIX: time embedding BEFORE hint gate
        t_norm = t.float() / self.scheduler.num_timesteps
        t_emb = self.time_mlp(t_norm.unsqueeze(-1))
        x = x + t_emb.unsqueeze(1)

        if x0_hint is not None:
            hint_emb = self.tgt_embed(x0_hint)
            gate = self.hint_gate(x)  # time-aware gate
            x = x + gate * hint_emb

        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)

        return self.head(x), None

    @torch.no_grad()
    def generate(self, src, num_steps=None, temperature=0.8, top_k=50,
                 repetition_penalty=1.2, diversity_penalty=0.0):
        """Iterative reverse-diffusion decoding.

        Starts from an all-[MASK] estimate and refines it over the timestep
        schedule; intermediate steps sample (temperature + top-k), the final
        step takes the argmax.
        """
        if src.dim() == 1:
            src = src.unsqueeze(0)  # allow a single unbatched sequence
        device = src.device
        B, L = src.shape
        T = self.scheduler.num_timesteps
        steps = num_steps or T
        step_size = max(1, T // steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)  # always finish at t=0

        mask_id = self.mask_token_id
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint = None

        self.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)
                logits, _ = self.forward(src, x0_est, t, x0_hint=hint, inference_mode=True)
                if repetition_penalty != 1.0:
                    logits = _apply_repetition_penalty(logits, x0_est, repetition_penalty)
                if diversity_penalty > 0.0:
                    logits = _apply_diversity_penalty_fixed(logits, diversity_penalty)  # BUG 4 FIX
                logits = logits / max(temperature, 1e-5)
                if top_k > 0:
                    logits = _top_k_filter(logits, top_k)
                probs = F.softmax(logits, dim=-1)
                x0_est = torch.argmax(probs, dim=-1) if is_last else _batch_multinomial(probs)
                hint = x0_est  # self-conditioning for the next step
        return x0_est
210
+
211
+
212
class BaselineCrossAttention(nn.Module):
    """Non-diffusion baseline: standard encoder-decoder with greedy
    autoregressive decoding. Mirrors D3PMCrossAttention's embedding and
    weight-tying layout so configs/checkpoints stay comparable.
    """

    def __init__(self, cfg):
        super().__init__()
        d = cfg['model']['d_model']; nhead = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']; drop = cfg['model']['dropout']
        seqlen = cfg['model']['max_seq_len']; nlayer = cfg['model']['n_layers']
        # fall back to a shared 'vocab_size' for legacy configs
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
        self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)])
        # head tied to the target (Devanagari) embedding only
        self.head = nn.Linear(d, tgt_vocab, bias=False)
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t=None, x0_hint=None):
        # t / x0_hint accepted (and ignored) for API parity with the D3PM model
        PAD = 1
        memory = self.src_embed(src)
        for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==PAD))
        x = self.tgt_embed(tgt)
        for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=(tgt==PAD), src_pad_mask=(src==PAD))
        # 1-tuple for uniformity with the D3PM (logits, aux) return shape
        return (self.head(x),)

    @torch.no_grad()
    def generate(self, src, max_len=None, start_token_id=2, **kwargs):
        """Greedy left-to-right decoding; **kwargs absorbs D3PM-only args."""
        if max_len is None: max_len = src.size(1)
        B, device = src.size(0), src.device
        memory = self.src_embed(src)
        for b in self.encoder_blocks: memory = b(memory, pad_mask=(src==1))
        ys = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
        for _ in range(max_len):
            x = self.tgt_embed(ys)
            for b in self.decoder_blocks: x = b(x, memory, tgt_pad_mask=None, src_pad_mask=(src==1))
            ys = torch.cat([ys, torch.argmax(self.head(x)[:,-1,:], dim=-1, keepdim=True)], dim=1)
        # drop the BOS column, cap at max_len tokens
        return ys[:, 1:max_len+1]
247
+
248
+
249
+ # helpers
250
+ def _top_k_filter(logits, k):
251
+ B, L, V = logits.shape
252
+ if k >= V: return logits
253
+ topk_vals, _ = torch.topk(logits, k, dim=-1)
254
+ return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
255
+
256
+ def _batch_multinomial(probs):
257
+ B, L, V = probs.shape
258
+ flat = probs.view(B*L, V) + 1e-9
259
+ return torch.multinomial(flat/flat.sum(-1,keepdim=True), 1).squeeze(-1).view(B, L)
260
+
261
+ def _apply_repetition_penalty(logits, prev_tokens, penalty):
262
+ for b in range(logits.shape[0]):
263
+ for tid in set(prev_tokens[b].tolist()):
264
+ if tid > 4: logits[b, :, tid] = logits[b, :, tid] / penalty
265
+ return logits
266
+
267
+ def _apply_diversity_penalty(logits, penalty): # legacy wrong version
268
+ return logits + penalty * logits.var(dim=-1, keepdim=True)
269
+
270
+ def _apply_diversity_penalty_fixed(logits, penalty): # correct version
271
+ return logits - penalty * logits.mean(dim=1, keepdim=True)
d3pm_model_encoder_decoder.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from diffusion.scheduler import OptimizedCosineScheduler
4
+ from diffusion.forward_process import AbsorbingForwardProcess
5
+ # Import shared classes to guarantee identical architectures
6
+ from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
7
class DecoderBlock(nn.Module):
    """Post-norm decoder layer (ReLU FFN variant) with cross-attention.

    NOTE: the cross-attention call applies no source pad mask — this
    mirrors the layout of the saved checkpoints for this model family,
    so it must not be "fixed" without retraining/reconverting them.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)  # ← restored
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)  # ← restored (for cross-attn residual)

    def forward(self, x, memory, tgt_pad_mask=None):
        # 1. Masked self-attention on target
        x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
        # 2. Cross-attention: queries from decoder, keys/values from encoder memory
        x = self.norm2(x + self.cross_attn(x, memory, memory))
        # 3. Feed-forward
        return self.norm3(x + self.ff(x))
30
+
31
+
32
class DecoderBlockNoCrossAttn(nn.Module):
    """Self-attention-only decoder layer.

    Kept for reference — NOT used by D3PMEncoderDecoder.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x, tgt_pad_mask=None, causal_mask=None):
        # merge pad + causal masks (logical OR) when both are supplied
        provided = [m for m in (tgt_pad_mask, causal_mask) if m is not None]
        merged = None
        if provided:
            merged = provided[0]
            for extra in provided[1:]:
                merged = merged | extra
        x = self.norm1(x + self.self_attn(x, x, x, mask=merged))
        return self.norm2(x + self.ff(x))
53
+
54
+
55
+ # ============================================================
56
+ # 1. D3PM Encoder-Decoder Model
57
+ # ============================================================
58
class D3PMEncoderDecoder(nn.Module):
    """D3PM encoder-decoder with cross-attention and self-conditioning.

    Same task as D3PMCrossAttention (Roman -> Devanagari absorbing
    diffusion), but self-conditioning blends the previous estimate into
    still-masked positions stochastically instead of via a learned gate.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']

        # fall back to a shared 'vocab_size' for legacy configs
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
        d_model = cfg['model']['d_model']
        n_heads = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        dropout = cfg['model']['dropout']
        n_layers = cfg['model']['n_layers']
        max_len = cfg['model']['max_seq_len']

        self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)

        self.scheduler = OptimizedCosineScheduler(cfg)
        self.forward_process = AbsorbingForwardProcess(self.scheduler)

        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        # DecoderBlock now has cross-attention — matches saved checkpoint
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

        # timestep-conditioning MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model // 4), nn.SiLU(),
            nn.Linear(d_model // 4, d_model),
        )
        # output head tied to the target (Devanagari) embedding
        self.head = nn.Linear(d_model, tgt_vocab)
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t, x0_hint=None):
        """Corrupt `tgt` internally at timestep `t`, then predict clean logits.

        Returns (logits, None); logits shape (B, L, tgt_vocab).
        """
        src_pad_mask = (src == 1)  # pad id fixed at 1 by the tokenizer
        tgt_pad_mask = (tgt == 1)

        # Encode source (Roman IAST)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)

        # Corrupt target with forward diffusion
        _, x_t_ids = self.forward_process.q_sample(tgt, t)

        # Optionally blend in x0_hint (self-conditioning):
        # each still-masked position takes the hint token with p=0.5
        if x0_hint is not None:
            hint_prob = 0.5
            blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
            still_mask = (x_t_ids == self.mask_token_id)
            x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)

        x = self.tgt_embed(x_t_ids)
        # NOTE(review): `t` is fed to time_mlp un-normalised here, unlike
        # D3PMCrossAttention which divides by num_timesteps — confirm intended
        t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
        x = x + t_emb.expand(-1, tgt.shape[1], -1)

        # Decode with cross-attention over encoder memory
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)

        return self.head(x), None

    @torch.no_grad()
    def generate(
        self,
        src,
        num_steps = None,
        temperature = 0.75,
        top_k = 50,
        repetition_penalty = 1.15,
        diversity_penalty = 0.0,
    ):
        """
        Iterative D3PM reverse diffusion — same signature as
        D3PMCrossAttention.generate() so SanskritModel.generate() works
        identically for both model types.
        """
        device = src.device
        B, L = src.shape[0], self.cfg['model']['max_seq_len']
        T = num_steps or self.scheduler.num_timesteps
        mask_id = self.mask_token_id
        pad_id = 1

        # start from an all-[MASK] canvas of fixed target length
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)

        for step in range(T - 1, -1, -1):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            hint = x0_est.clone()

            logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)

            # Repetition penalty (specials with id <= pad_id exempt)
            if repetition_penalty != 1.0:
                for b in range(B):
                    for tok in set(x0_est[b].tolist()):
                        if tok > pad_id:
                            logits[b, :, tok] /= repetition_penalty

            # Diversity penalty (suppress common tokens)
            if diversity_penalty > 0.0:
                logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)

            # Temperature + top-k sampling
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = torch.softmax(logits, dim=-1)
            # Only update positions that are still masked
            still = (x0_est == mask_id)
            sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
            x0_est = torch.where(still, sample, x0_est)

        return x0_est
176
+
177
+
178
+ # ============================================================
179
+ # 2. Baseline Encoder-Decoder Model (unchanged)
180
+ # ============================================================
181
class BaselineEncoderDecoder(nn.Module):
    """Plain (non-diffusion) encoder-decoder baseline.

    Uses a single shared 'vocab_size' for both sides and greedy
    autoregressive decoding; the head is weight-tied to tgt_embed.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.src_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
                                            cfg['model']['max_seq_len'])
        self.tgt_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
                                            cfg['model']['max_seq_len'])
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
                         cfg['model']['d_ff'], cfg['model']['dropout'])
            for _ in range(cfg['model']['n_layers'])
        ])
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
                         cfg['model']['d_ff'], cfg['model']['dropout'])
            for _ in range(cfg['model']['n_layers'])
        ])
        # head tied to the target embedding
        self.head = nn.Linear(cfg['model']['d_model'], cfg['model']['vocab_size'])
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt):
        src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)  # pad id = 1
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        x = self.tgt_embed(tgt)
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
        return self.head(x)

    @torch.no_grad()
    def generate(self, src, max_len=80, start_token_id=2):
        """Greedy autoregressive decoding up to max_len tokens (BOS dropped)."""
        batch_size, device = src.size(0), src.device
        src_pad_mask = (src == 1)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        ys = torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_token_id
        for _ in range(max_len):
            x = self.tgt_embed(ys)
            for block in self.decoder_blocks:
                x = block(x, memory, tgt_pad_mask=None)
            logits = self.head(x)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
        return ys[:, 1:]
dataset.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ dataset.py — Cross-Script Translation Fix
3
+ ==========================================
4
+ INPUT : quote_text (Roman/IAST transliteration of Sanskrit)
5
+ TARGET : quote_devanagari (Devanagari script)
6
+
7
+ This is the CORRECT task: the model learns to transliterate / translate
8
+ Roman Sanskrit → Devanagari, which is a meaningful, learnable mapping
9
+ (far better than devanagari→devanagari reconstruction which teaches nothing).
10
+
11
+ KEY CHANGES from original:
12
+ 1. _input_field = 'quote_text' (was 'quote_devanagari')
13
+ 2. _target_field = 'quote_devanagari' (unchanged)
14
+ 3. Separate source/target tokenizers — Roman and Devanagari have
15
+ completely different character sets; a shared BPE vocab forces the
16
+ model to learn both scripts in one embedding table, which wastes
17
+ capacity and confuses the attention mechanism.
18
+ 4. Negative example generation fixed — reversal now operates on
19
+ DEVANAGARI target only (not accidentally on Roman source).
20
+ 5. curriculum_sort uses target length (Devanagari) for difficulty proxy.
21
+ """
22
+
23
+ from datasets import load_dataset
24
+ from torch.utils.data import Dataset
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import random
28
+
29
+
30
class OptimizedSanskritDataset(Dataset):
    """Roman (IAST) -> Devanagari translation dataset.

    Each item yields padded `input_ids` (source script) and `target_ids`
    (Devanagari), the raw texts, and — when enabled in the config — a
    chunk-reversed negative target for contrastive training.
    """

    def __init__(self, split='train', tokenizer=None, max_len=80, cfg=None,
                 src_tokenizer=None, tgt_tokenizer=None):
        """
        Args:
            tokenizer : shared tokenizer (legacy — used if src/tgt not provided)
            src_tokenizer : tokenizer for quote_text (Roman script)
            tgt_tokenizer : tokenizer for quote_devanagari (Devanagari script)
                            If None, falls back to shared `tokenizer`.
        """
        from config import CONFIG
        self.cfg = cfg or CONFIG
        self.max_len = max_len
        self.pad_id = 1  # PAD id fixed by the tokenizer
        self.mask_id = self.cfg['diffusion']['mask_token_id']
        self.include_negatives = self.cfg['data']['include_negative_examples']

        # ── Tokenizer setup ───────────────────────────────────────────
        # Support both legacy (shared) and new (separate src/tgt) tokenizers
        self.src_tokenizer = src_tokenizer or tokenizer
        self.tgt_tokenizer = tgt_tokenizer or tokenizer

        if self.src_tokenizer is None:
            raise ValueError("Provide at least one tokenizer.")

        print(f"📥 Loading '{split}' split …")
        raw = load_dataset("paws/sanskrit-verses-gretil", split=split)
        cols = raw.column_names

        # ── Field selection ───────────────────────────────────────────
        if 'quote_text' in cols and 'quote_devanagari' in cols:
            # CORRECT setup: Roman input → Devanagari output
            self._input_field = 'quote_text'
            self._target_field = 'quote_devanagari'
            print(" Format: quote_text (Roman) → quote_devanagari (Devanagari) ✓")
        elif 'sentence1' in cols and 'sentence2' in cols:
            # PAWS paraphrase pairs fallback
            self._input_field = 'sentence1'
            self._target_field = 'sentence2'
            print(" Format: PAWS sentence pairs ✓")
        else:
            # Last resort: same field both sides
            self._input_field = 'quote_devanagari'
            self._target_field = 'quote_devanagari'
            print(" ⚠️ Format: Devanagari→Devanagari (suboptimal — no quote_text found)")

        # ── Filter empty rows ─────────────────────────────────────────
        # Some rows have empty quote_text — skip them
        raw = raw.filter(
            lambda ex: (
                bool(ex[self._input_field].strip()) and
                bool(ex[self._target_field].strip())
            )
        )
        print(f" After empty-filter: {len(raw)} samples")

        self.dataset = raw

        if split == 'train':
            self.dataset = self._curriculum_sort()

        print(f"✅ {len(self.dataset)} samples loaded.")

    # ── Encoding ──────────────────────────────────────────────────────

    def _encode_src(self, text):
        """Encode source (Roman) text: truncate to max_len, right-pad with PAD."""
        ids = self.src_tokenizer.encode(text)[:self.max_len]
        t = torch.tensor(ids, dtype=torch.long)
        t = F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)
        return t

    def _encode_tgt(self, text):
        """Encode target (Devanagari) text: truncate to max_len, right-pad with PAD."""
        ids = self.tgt_tokenizer.encode(text)[:self.max_len]
        t = torch.tensor(ids, dtype=torch.long)
        t = F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)
        return t

    # ── Curriculum ────────────────────────────────────────────────────

    def _curriculum_sort(self):
        """Short, common Devanagari targets first → long, rare targets last."""
        scores = []
        for s in self.dataset:
            text = s[self._target_field]
            length = len(text.split())
            # higher unique-char ratio ≈ rarer text; it discounts the length term
            rarity_score = len(set(text)) / max(1, len(text))
            scores.append(length * (1 - rarity_score))
        order = sorted(range(len(self.dataset)), key=lambda i: scores[i])
        return self.dataset.select(order)

    # ── Item ──────────────────────────────────────────────────────────

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]

        src_text = sample[self._input_field].strip()
        tgt_text = sample[self._target_field].strip()

        input_ids = self._encode_src(src_text)   # Roman encoded with src_tokenizer
        target_ids = self._encode_tgt(tgt_text)  # Devanagari encoded with tgt_tokenizer

        out = {
            'input_ids': input_ids,
            'target_ids': target_ids,
            'input_text': src_text,
            'target_text': tgt_text,
        }

        if self.include_negatives:
            neg_ids = target_ids.clone()
            # Reverse a random chunk of the DEVANAGARI target
            non_pad = (neg_ids != self.pad_id).sum().item()
            if non_pad > 4:
                i1, i2 = sorted(random.sample(range(non_pad), 2))
                neg_ids[i1:i2] = torch.flip(neg_ids[i1:i2], dims=[0])
            out['negative_target_ids'] = neg_ids

        return out
+ return out
forward_process.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ forward_process.py — Verified Correct (no changes needed)
3
+ ===========================================================
4
+ Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly
5
+ so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
6
+ """
7
+ import torch
8
+
9
class AbsorbingForwardProcess:
    """Absorbing-state (mask) forward diffusion.

    Each non-PAD token independently becomes [MASK] with probability
    1 - alpha(t); PAD positions are never corrupted. With alpha(t)=1
    (the scheduler's t=0 value) nothing is masked, so x_t == x_0.
    """

    def __init__(self, scheduler, mask_id=0, pad_id=1):
        self.scheduler = scheduler
        self.mask_id = mask_id
        self.pad_id = pad_id

    def q_sample(self, x_0, t):
        """Corrupt x_0 at per-sample timesteps t; returns (x_0, x_t)."""
        keep_prob = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
        noise = torch.rand(x_0.shape, device=x_0.device)
        corrupted = x_0.clone()
        corrupted[noise > keep_prob] = self.mask_id
        # PAD stays PAD regardless of the draw
        corrupted[x_0 == self.pad_id] = self.pad_id
        return x_0, corrupted
+ return x_0, x_t
inference.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py
3
+ ============
4
+ Correct D3PM inference for Sanskrit paraphrase generation.
5
+
6
+ The model's forward() takes CLEAN tgt and noises it internally.
7
+ So inference passes x0_estimate (starting all-[MASK]) as tgt each step,
8
+ letting the model noise it and then predict a cleaner version.
9
+
10
+ Also includes: robust checkpoint loading (auto-detects architecture
11
+ from saved weights — no CONFIG mismatch crashes).
12
+ """
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import os, sys
17
+ from tqdm import tqdm
18
+ from torch.utils.data import DataLoader, Subset
19
+
20
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
21
+ from config import CONFIG
22
+
23
+
24
+ # ── Checkpoint loader ─────────────────────────────────────────────────
25
+
26
def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """
    Auto-detect architecture from checkpoint weight shapes, then load.

    Never fails due to CONFIG vs checkpoint mismatch: d_model, vocab_size,
    d_ff, n_layers, max_seq_len and n_heads are inferred from the saved
    tensors before the model is built.

    FIX: the checkpoint was previously torch.load()-ed TWICE (once for
    shape detection, once for load_state_dict), doubling disk I/O and
    peak memory — it is now loaded once on CPU and reused
    (load_state_dict copies CPU tensors into device parameters).

    Args:
        ckpt_path: path to a saved state dict (e.g. best_model.pt).
        base_cfg : CONFIG-style dict; a deep copy is patched and returned.
        device   : target device for the instantiated model.

    Returns:
        (model, cfg) — the loaded SanskritModel and the patched config.
    """
    import copy
    from model.sanskrit_model import SanskritModel

    cfg = copy.deepcopy(base_cfg)
    state = torch.load(ckpt_path, map_location='cpu')

    # d_model + vocab_size from the source embedding matrix
    ek = 'model.src_embed.token_emb.weight'
    if ek in state:
        vocab, d = state[ek].shape
        cfg['model']['vocab_size'] = vocab
        cfg['model']['d_model'] = d
        cfg['model']['d_ff'] = d * 4

    # n_layers from the highest encoder block index
    ids = {int(k.split('.')[2]) for k in state if k.startswith('model.encoder_blocks.')}
    if ids:
        cfg['model']['n_layers'] = max(ids) + 1

    # max_seq_len from the positional-encoding buffer
    pk = 'model.src_embed.pos_enc.pe'
    if pk in state:
        cfg['model']['max_seq_len'] = state[pk].shape[1]

    # n_heads: keep the configured value if it divides d_model, else pick one
    d = cfg['model']['d_model']
    h = cfg['model'].get('n_heads', 6)
    if d % h != 0:
        h = next(x for x in [8, 6, 4, 2, 1] if d % x == 0)
    cfg['model']['n_heads'] = h

    print(f"🔍 Detected: d_model={cfg['model']['d_model']}, "
          f"n_layers={cfg['model']['n_layers']}, "
          f"max_seq_len={cfg['model']['max_seq_len']}, "
          f"n_heads={cfg['model']['n_heads']}")

    model = SanskritModel(cfg).to(device)
    # strict=False: older checkpoints may lack hint_gate weights
    missing, unexpected = model.load_state_dict(state, strict=False)

    # hint_gate may be absent in older checkpoints — initialise safely
    allowed = {'model.hint_gate.0.weight', 'model.hint_gate.0.bias'}
    real_missing = [k for k in missing if k not in allowed]
    if real_missing:
        print(f"⚠️ Missing keys: {real_missing[:3]} …")
    if unexpected:
        print(f"⚠️ Unexpected keys: {unexpected[:3]} …")
    if hasattr(model.model, 'hint_gate') and 'model.hint_gate.0.weight' in missing:
        with torch.no_grad():
            w = model.model.hint_gate[0].weight
            torch.nn.init.zeros_(model.model.hint_gate[0].bias)
            torch.nn.init.eye_(w) if w.shape[0] == w.shape[1] \
                else torch.nn.init.xavier_uniform_(w)
        print("ℹ️ hint_gate initialised to identity (not in checkpoint).")

    print("✅ Model loaded.")
    return model, cfg
89
+
90
+
91
# ── Core inference function ───────────────────────────────────────────

def run_inference(model, input_ids, cfg):
    """
    Iterative D3PM refinement.

    The clean estimate x0_est starts as all [MASK]. At each timestep the
    model noises x0_est internally and predicts a cleaner version, which
    becomes both the next estimate and the hint for the following step.
    """
    inf_cfg = cfg['inference']
    device = input_ids.device
    batch, seq_len = input_ids.shape

    core = model.model
    total_t = core.scheduler.num_timesteps
    n_steps = inf_cfg['num_steps']  # expected to equal T (set in config)
    stride = max(1, total_t // n_steps)
    schedule = list(range(total_t - 1, -1, -stride))
    if schedule[-1] != 0:
        schedule.append(0)  # always finish at t = 0

    x0_est = torch.full(
        (batch, seq_len), core.mask_token_id, dtype=torch.long, device=device
    )
    hint = None

    model.eval()
    with torch.no_grad():
        for idx, tv in enumerate(schedule):
            t = torch.full((batch,), tv, dtype=torch.long, device=device)
            final_step = idx == len(schedule) - 1

            logits, _ = model(input_ids, x0_est, t, x0_hint=hint)

            # Optional penalties (helpers imported lazily, only when enabled)
            if inf_cfg['repetition_penalty'] != 1.0:
                from model.d3pm_model_cross_attention import _apply_repetition_penalty
                logits = _apply_repetition_penalty(
                    logits, x0_est, inf_cfg['repetition_penalty']
                )
            if inf_cfg['diversity_penalty'] > 0.0:
                from model.d3pm_model_cross_attention import _apply_diversity_penalty
                logits = _apply_diversity_penalty(logits, inf_cfg['diversity_penalty'])

            logits = logits / max(inf_cfg['temperature'], 1e-5)
            if inf_cfg['top_k'] > 0:
                from model.d3pm_model_cross_attention import _top_k_filter
                logits = _top_k_filter(logits, inf_cfg['top_k'])

            probs = F.softmax(logits, dim=-1)

            # Greedy on the final step, sample otherwise.
            if final_step:
                x0_est = probs.argmax(dim=-1)
            else:
                from model.d3pm_model_cross_attention import _batch_multinomial
                x0_est = _batch_multinomial(probs)

            hint = x0_est

    return x0_est
151
+
152
+
153
# ── Interactive demo ──────────────────────────────────────────────────

def interactive_demo():
    """REPL: read a verse from stdin, print the model's paraphrase."""
    from model.tokenizer import SanskritTokenizer

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])

    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    ckpt = f"results/{model_name}_neg_{has_neg}/best_model.pt"
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt} — train first.")

    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    tokenizer = SanskritTokenizer(cfg['model']['vocab_size'])
    PAD_ID = tokenizer.tokenizer.token_to_id('[PAD]') or 1
    MASK_ID = cfg['diffusion']['mask_token_id']

    rule = "=" * 55
    print("\n" + rule)
    print("Sanskrit D3PM Paraphrase — type verse, get paraphrase")
    print(rule + "\n")

    while True:
        try:
            text = input("INPUT > ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not text or text.lower() in ('quit', 'exit', 'q'):
            break

        # Encode, truncate to the model's window, run the diffusion loop.
        ids = torch.tensor(
            [tokenizer.encode(text)[:cfg['model']['max_seq_len']]],
            dtype=torch.long, device=device
        )
        out = run_inference(model, ids, cfg)
        # Strip special tokens before decoding.
        clean = [tok for tok in out[0].tolist() if tok not in (MASK_ID, PAD_ID)]
        print(f"PARAPHRASE → {tokenizer.decode(clean).strip()}\n")
194
+
195
+
196
# ── Batch evaluation ──────────────────────────────────────────────────

def batch_evaluate(sample_size=500):
    """
    Generate paraphrases for up to `sample_size` test examples, compute
    corpus BLEU and BERTScore (each best-effort — a missing package leaves
    the score at 0.0), and write a report to <exp_dir>/evaluation_results.txt.

    Returns:
        (all_preds, all_refs) — generated and reference strings.
    """
    from data.dataset import OptimizedSanskritDataset
    from model.tokenizer import SanskritTokenizer

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])

    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    exp_dir = f"results/{model_name}_neg_{has_neg}"
    ckpt = f"{exp_dir}/best_model.pt"
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt}")

    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    tokenizer = SanskritTokenizer(cfg['model']['vocab_size'])
    PAD_ID = tokenizer.tokenizer.token_to_id('[PAD]') or 1
    MASK_ID = cfg['diffusion']['mask_token_id']

    def collate(batch):
        # Stack token tensors; keep the raw strings for the report.
        return {
            'input_ids': torch.stack([b['input_ids'].long() for b in batch]),
            'target_text': [b['target_text'] for b in batch],
            'input_text': [b['input_text'] for b in batch],
        }

    dataset = OptimizedSanskritDataset('test', tokenizer, cfg['model']['max_seq_len'], cfg)
    indices = list(range(min(sample_size, len(dataset))))
    loader = DataLoader(
        Subset(dataset, indices),
        batch_size=cfg['training']['batch_size'],
        shuffle=False, collate_fn=collate
    )

    all_preds, all_refs, all_inputs = [], [], []
    print(f"⏳ Generating {len(indices)} paraphrases …")

    for batch in tqdm(loader):
        out = run_inference(model, batch['input_ids'].to(device), cfg)
        for row, ref, src in zip(out, batch['target_text'], batch['input_text']):
            clean = [x for x in row.tolist() if x not in (MASK_ID, PAD_ID)]
            all_preds.append(tokenizer.decode(clean).strip())
            all_refs.append(ref.strip())
            all_inputs.append(src.strip())

    # Metrics (best-effort)
    bleu_score, bert_f1 = 0.0, 0.0
    try:
        from nltk.translate.bleu_score import corpus_bleu
        bleu_score = corpus_bleu(
            [[r.split()] for r in all_refs],
            [p.split() for p in all_preds]
        )
    except Exception:
        pass

    try:
        import evaluate as hf_eval
        res = hf_eval.load('bertscore').compute(
            predictions=all_preds, references=all_refs, lang='hi'
        )
        bert_f1 = sum(res['f1']) / len(res['f1'])
    except Exception:
        pass

    # Write the evaluation report
    out_path = f"{exp_dir}/evaluation_results.txt"
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(f"Model : {model_name}\n")
        f.write(f"Negatives: {has_neg}\n")
        f.write(f"Steps : {cfg['inference']['num_steps']}\n")
        f.write(f"Temp : {cfg['inference']['temperature']}\n")
        f.write(f"RepPen : {cfg['inference']['repetition_penalty']}\n")
        f.write(f"DivPen : {cfg['inference']['diversity_penalty']}\n")
        f.write(f"BLEU : {bleu_score:.4f}\n")
        f.write(f"BERTScore: {bert_f1:.4f}\n\n")
        f.write("=== SAMPLES ===\n")
        for i in range(min(20, len(all_preds))):
            f.write(f"IN : {all_inputs[i]}\n")
            f.write(f"REF : {all_refs[i]}\n")
            f.write(f"PRED: {all_preds[i]}\n")
            f.write("-" * 60 + "\n")

    print(f"\n✅ Results → {out_path}")
    print(f"📊 BLEU: {bleu_score:.4f} | BERTScore: {bert_f1:.4f}")
    return all_preds, all_refs
288
+
289
+
290
if __name__ == '__main__':
    import argparse

    # CLI: --mode demo (interactive REPL) or --mode eval (batch metrics).
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['demo', 'eval'], default='demo')
    parser.add_argument('--samples', type=int, default=500)
    cli = parser.parse_args()

    if cli.mode == 'demo':
        interactive_demo()
    else:
        batch_evaluate(cli.samples)
kv_cache_benchmark.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/kv_cache_benchmark.py
3
+ ================================
4
+ Task 1: Benchmark KV cache vs standard generate().
5
+
6
+ Measures:
7
+ - Wall-clock time for generate() vs generate_cached()
8
+ - Encoder time as % of total generation time (before/after)
9
+ - Speedup ratio at src_len = 16, 32, 64 tokens
10
+
11
+ How it works:
12
+ Standard generate():
13
+ For each of T=128 steps:
14
+ src → encoder → memory → decoder → logits (encoder runs 128 times)
15
+
16
+ generate_cached():
17
+ src → encoder → memory (once)
18
+ For each of T=128 steps:
19
+ cached_memory → decoder → logits (encoder runs 1 time)
20
+
21
+ Expected speedup:
22
+ If encoder = 30% of per-step time:
23
+ Saved = 127/128 * 30% ≈ 29.7% of total time
24
+ If encoder = 50% of per-step time:
25
+ Saved ≈ 49.6% of total time
26
+
27
+ Usage:
28
+ python -m analysis.kv_cache_benchmark
29
+ or:
30
+ from analysis.kv_cache_benchmark import run_benchmark
31
+ results = run_benchmark(model, src_tokenizer, device)
32
+ """
33
+
34
+ import torch
35
+ import time
36
+ import numpy as np
37
+ from typing import Dict, List
38
+
39
+
40
+ def _make_src(src_len: int, src_vocab: int, device: torch.device, batch_size: int = 1):
41
+ """Create a random source tensor of given length."""
42
+ # Random real tokens (ids 5..src_vocab-1), padded to src_len
43
+ ids = torch.randint(5, src_vocab, (batch_size, src_len), device=device)
44
+ return ids
45
+
46
+
47
+ def _time_fn(fn, n_warmup: int = 2, n_runs: int = 5) -> float:
48
+ """
49
+ Time a zero-argument callable.
50
+ Returns mean wall-clock seconds over n_runs after n_warmup warmup calls.
51
+ """
52
+ # Warmup
53
+ for _ in range(n_warmup):
54
+ fn()
55
+ if torch.cuda.is_available():
56
+ torch.cuda.synchronize()
57
+ elif torch.backends.mps.is_available():
58
+ torch.mps.synchronize()
59
+
60
+ times = []
61
+ for _ in range(n_runs):
62
+ start = time.perf_counter()
63
+ fn()
64
+ if torch.cuda.is_available():
65
+ torch.cuda.synchronize()
66
+ elif torch.backends.mps.is_available():
67
+ torch.mps.synchronize()
68
+ times.append(time.perf_counter() - start)
69
+
70
+ return float(np.mean(times))
71
+
72
+
73
def benchmark_encoder_cost(
    model,
    src: torch.Tensor,
) -> Dict[str, float]:
    """
    Break one forward pass into encoder vs decoder cost.

    Returns a dict with:
        encoder_s   : seconds for one encoder call
        decoder_s   : seconds for one cached decoder step
        full_step_s : seconds for one full (encoder + decoder) forward
        encoder_pct : encoder_s as a percentage of full_step_s
    """
    inner = model.model
    if not hasattr(inner, 'encode_source'):
        raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")

    device = src.device
    batch = src.shape[0]
    T = inner.scheduler.num_timesteps  # total diffusion steps (not used for timing)
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # All-[MASK] target at t=0 — a representative decoder input.
    x0_est = torch.full((batch, tgt_len), mask_id, dtype=torch.long, device=device)
    t = torch.zeros(batch, dtype=torch.long, device=device)

    # Encoder alone.
    encoder_s = _time_fn(lambda: inner.encode_source(src))

    # Cached decoder step (memory computed once, outside the timer).
    memory, src_pad_mask = inner.encode_source(src)
    decoder_s = _time_fn(
        lambda: inner.forward_cached(memory, src_pad_mask, x0_est, t,
                                     inference_mode=True)
    )

    # Full non-cached step = encoder + decoder together.
    full_s = _time_fn(
        lambda: inner.forward(src, x0_est, t, inference_mode=True)
    )

    return {
        "encoder_s": encoder_s,
        "decoder_s": decoder_s,
        "full_step_s": full_s,
        "encoder_pct": 100.0 * encoder_s / max(full_s, 1e-9),
    }
123
+
124
+
125
def run_benchmark(
    model,
    src_tokenizer,
    device: torch.device,
    src_lens: List[int] = None,
    n_runs: int = 5,
) -> Dict:
    """
    Full benchmark: compare generate() vs generate_cached() at multiple src lengths.

    Args:
        model         : SanskritModel (D3PMCrossAttention)
        src_tokenizer : SanskritSourceTokenizer
        device        : torch.device
        src_lens      : source lengths to benchmark (default [16, 32, 64])
        n_runs        : number of timing runs per condition

    Returns:
        results dict with timing and speedup for each src_len
    """
    # Fix: avoid a mutable default argument; default resolved here instead.
    if src_lens is None:
        src_lens = [16, 32, 64]

    inner = model.model
    if not hasattr(inner, 'generate_cached'):
        raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")

    # Source vocab size read from the embedding table.
    src_vocab = inner.src_embed.token_emb.weight.shape[0]
    results = {}

    print("\n" + "=" * 65)
    print(" KV CACHE BENCHMARK")
    print("=" * 65)
    print(f" {'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
          f"{'speedup':>8} {'encoder%':>9}")
    print("-" * 65)

    for src_len in src_lens:
        src = _make_src(src_len, src_vocab, device)

        # Encoder cost breakdown (encoder % of one full forward pass).
        enc_cost = benchmark_encoder_cost(model, src)

        # Standard generate() — encoder runs T times.
        def run_standard():
            return inner.generate(src, temperature=0.8, top_k=40)

        # generate_cached() — encoder runs once.
        def run_cached():
            return inner.generate_cached(src, temperature=0.8, top_k=40)

        t_standard = _time_fn(run_standard, n_warmup=1, n_runs=n_runs)
        t_cached = _time_fn(run_cached, n_warmup=1, n_runs=n_runs)
        speedup = t_standard / max(t_cached, 1e-9)

        results[src_len] = {
            "standard_s": t_standard,
            "cached_s": t_cached,
            "speedup": speedup,
            "encoder_pct": enc_cost["encoder_pct"],
        }

        print(f" {src_len:>8} {t_standard:>12.3f} {t_cached:>10.3f} "
              f"{speedup:>7.2f}x {enc_cost['encoder_pct']:>8.1f}%")

    print("=" * 65)
    print(f"\n Encoder cost = % of one full forward pass")
    print(f" Speedup = standard_time / cached_time")
    print(f" Expected: speedup ≈ 1 / (1 - encoder_pct/100 * (T-1)/T)")

    return results
193
+
194
+
195
def print_summary(results: Dict):
    """Print a human-readable summary of benchmark results."""
    print("\n SUMMARY")
    print(" -------")
    for src_len, row in results.items():
        # Fraction of total time eliminated by caching.
        saved_pct = (1.0 - 1.0 / row["speedup"]) * 100
        print(f" src_len={src_len}: {row['speedup']:.2f}x speedup "
              f"({saved_pct:.1f}% time saved, "
              f"encoder was {row['encoder_pct']:.1f}% of total)")
204
+
205
+
206
+ if __name__ == "__main__":
207
+ import sys, os
208
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
209
+ from config import CONFIG
210
+ from inference import load_model
211
+ from models.tokenizer import SanskritSourceTokenizer
212
+
213
+ cfg = CONFIG
214
+ device = torch.device(cfg['training']['device'])
215
+
216
+ model_name = cfg['model_type']
217
+ has_neg = cfg['data']['include_negative_examples']
218
+ ckpt = f"results7/{model_name}_neg_{has_neg}/best_model.pt"
219
+
220
+ if not os.path.exists(ckpt):
221
+ print(f"No checkpoint at {ckpt}. Train first.")
222
+ sys.exit(1)
223
+
224
+ model, cfg = load_model(ckpt, cfg, device)
225
+ model.eval()
226
+
227
+ src_tokenizer = SanskritSourceTokenizer(
228
+ vocab_size = cfg['model'].get('src_vocab_size', 500),
229
+ max_len = cfg['model']['max_seq_len'],
230
+ )
231
+
232
+ results = run_benchmark(model, src_tokenizer, device)
233
+ print_summary(results)
quality_classifier.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/quality_classifier.py
3
+ ================================
4
+ Task 5: Classifier-Free Guidance for Paraphrase Quality Control
5
+
6
+ Two steps — only Step 2 requires training a SMALL model (not the main D3PM):
7
+
8
+ STEP 1 — Collect training data (no training):
9
+ Run existing model on val set, record (hidden_state, CER) pairs.
10
+ Hidden states come from model.model._last_hidden after forward_cached().
11
+ CER score = quality label (lower CER = higher quality).
12
+
13
+ STEP 2 — Train quality classifier:
14
+ Small 2-layer MLP: d_model → 64 → 1
15
+ Input: pooled decoder hidden state [B, d_model]
16
+ Output: predicted quality score in [0, 1] (1 = high quality)
17
+ Loss: MSE against normalized CER labels
18
+ Training time: ~5-10 minutes on CPU for 10k examples
19
+
20
+ STEP 3 — Guided inference (no retraining):
21
+ At each diffusion step, use classifier gradient to shift logits:
22
+ guided_logits = logits + λ * ∂(quality_score)/∂(logits)
23
+ Higher λ → model biased toward high-quality outputs
24
+ λ=0 → standard generation (no guidance)
25
+
26
+ Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains.
27
+ """
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ import numpy as np
33
+ import os
34
+ import json
35
+ from typing import List, Dict, Optional, Tuple
36
+
37
+
38
+ # ── Quality classifier architecture ──────────────────────────────────
39
+
40
+ class QualityClassifier(nn.Module):
41
+ """
42
+ Lightweight MLP that predicts transliteration quality from decoder
43
+ hidden states.
44
+
45
+ Architecture:
46
+ d_model → 128 → 64 → 1 → Sigmoid
47
+
48
+ Input: mean-pooled decoder hidden state [B, d_model]
49
+ Output: quality score [B, 1] ∈ [0, 1] (1 = high quality)
50
+
51
+ ~10k parameters. Trains in minutes on CPU.
52
+ """
53
+ def __init__(self, d_model: int):
54
+ super().__init__()
55
+ self.net = nn.Sequential(
56
+ nn.Linear(d_model, 128),
57
+ nn.ReLU(),
58
+ nn.Dropout(0.1),
59
+ nn.Linear(128, 64),
60
+ nn.ReLU(),
61
+ nn.Linear(64, 1),
62
+ nn.Sigmoid(),
63
+ )
64
+ self.d_model = d_model
65
+
66
+ def forward(self, hidden: torch.Tensor) -> torch.Tensor:
67
+ """
68
+ Args:
69
+ hidden : [B, tgt_len, d_model] OR [B, d_model] (already pooled)
70
+
71
+ Returns:
72
+ score : [B, 1] quality score in [0, 1]
73
+ """
74
+ if hidden.dim() == 3:
75
+ # Pool over sequence length
76
+ hidden = hidden.mean(dim=1) # [B, d_model]
77
+ return self.net(hidden) # [B, 1]
78
+
79
+
80
+ # ── Training data collection ──────────────────────────────────────────
81
+
82
+ @torch.no_grad()
83
+ def collect_quality_data(
84
+ model,
85
+ src_list: List[torch.Tensor],
86
+ ref_list: List[str],
87
+ tgt_tokenizer,
88
+ t_capture: int = 0,
89
+ temperature: float = 0.8,
90
+ top_k: int = 40,
91
+ max_samples: int = 5000,
92
+ ) -> Tuple[np.ndarray, np.ndarray]:
93
+ """
94
+ Collect (hidden_state, quality_score) pairs for classifier training.
95
+
96
+ For each sample:
97
+ 1. Run generate_cached() on src
98
+ 2. Capture decoder hidden state at t=t_capture
99
+ 3. Compute CER between output and reference
100
+ 4. Quality = 1 - CER (normalize to [0,1])
101
+
102
+ Args:
103
+ model : SanskritModel
104
+ src_list : list of [1, src_len] tensors
105
+ ref_list : list of reference Devanagari strings
106
+ tgt_tokenizer : SanskritTargetTokenizer
107
+ t_capture : which step to capture hidden states (0 = final)
108
+ max_samples : cap number of training examples
109
+
110
+ Returns:
111
+ hidden_matrix : np.ndarray [N, d_model]
112
+ quality_scores: np.ndarray [N] values in [0, 1]
113
+ """
114
+ inner = model.model
115
+ T = inner.scheduler.num_timesteps
116
+ device = next(inner.parameters()).device
117
+
118
+ hidden_list = []
119
+ quality_list = []
120
+ n = min(len(src_list), max_samples)
121
+
122
+ def cer(pred, ref):
123
+ if not ref:
124
+ return 1.0
125
+ def ed(s1, s2):
126
+ m, n = len(s1), len(s2)
127
+ dp = list(range(n + 1))
128
+ for i in range(1, m + 1):
129
+ prev, dp[0] = dp[0], i
130
+ for j in range(1, n + 1):
131
+ temp = dp[j]
132
+ dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
133
+ prev = temp
134
+ return dp[n]
135
+ return ed(pred, ref) / max(len(ref), 1)
136
+
137
+ print(f"Collecting quality data from {n} examples...")
138
+ for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
139
+ if i % 200 == 0:
140
+ print(f" {i}/{n}")
141
+
142
+ if src.dim() == 1:
143
+ src = src.unsqueeze(0)
144
+ src = src.to(device)
145
+
146
+ B = src.shape[0]
147
+ tgt_len = inner.max_seq_len
148
+ mask_id = inner.mask_token_id
149
+
150
+ memory, src_pad_mask = inner.encode_source(src)
151
+ x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
152
+ hint = None
153
+ h_cap = None
154
+
155
+ for t_val in range(T - 1, -1, -1):
156
+ t = torch.full((B,), t_val, dtype=torch.long, device=device)
157
+ is_last = (t_val == 0)
158
+
159
+ logits, _ = inner.forward_cached(
160
+ memory, src_pad_mask, x0_est, t,
161
+ x0_hint=hint, inference_mode=True,
162
+ )
163
+
164
+ if t_val == t_capture and hasattr(inner, '_last_hidden'):
165
+ h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu() # [d_model]
166
+
167
+ logits = logits / max(temperature, 1e-8)
168
+ if top_k > 0:
169
+ V = logits.shape[-1]
170
+ if top_k < V:
171
+ vals, _ = torch.topk(logits, top_k, dim=-1)
172
+ logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
173
+
174
+ probs = F.softmax(logits, dim=-1)
175
+ x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
176
+ hint = x0_est
177
+
178
+ if h_cap is None:
179
+ continue
180
+
181
+ ids = [x for x in x0_est[0].tolist() if x > 4]
182
+ pred = tgt_tokenizer.decode(ids).strip()
183
+ q = max(0.0, 1.0 - cer(pred, ref)) # quality = 1 - CER
184
+
185
+ hidden_list.append(h_cap.numpy())
186
+ quality_list.append(q)
187
+
188
+ print(f"Collected {len(hidden_list)} quality examples.")
189
+ print(f"Quality stats: mean={np.mean(quality_list):.3f} "
190
+ f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")
191
+
192
+ return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
193
+
194
+
195
+ def _sample(probs):
196
+ B, L, V = probs.shape
197
+ flat = probs.view(B * L, V).clamp(min=1e-9)
198
+ flat = flat / flat.sum(dim=-1, keepdim=True)
199
+ return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
200
+
201
+
202
+ # ── Training ──────────────────────────────────────────────────────────
203
+
204
+ def train_quality_classifier(
205
+ hidden_matrix: np.ndarray,
206
+ quality_scores: np.ndarray,
207
+ d_model: int,
208
+ epochs: int = 30,
209
+ batch_size: int = 64,
210
+ lr: float = 1e-3,
211
+ val_frac: float = 0.1,
212
+ save_path: Optional[str] = None,
213
+ ) -> QualityClassifier:
214
+ """
215
+ Train QualityClassifier on collected (hidden, quality) pairs.
216
+
217
+ Args:
218
+ hidden_matrix : [N, d_model] from collect_quality_data()
219
+ quality_scores : [N] quality labels in [0, 1]
220
+ d_model : hidden dimension
221
+ epochs : training epochs
222
+ save_path : if given, save trained classifier weights here
223
+
224
+ Returns:
225
+ trained QualityClassifier
226
+ """
227
+ device = torch.device("cpu") # classifier is tiny, CPU is fine
228
+
229
+ X = torch.tensor(hidden_matrix, dtype=torch.float32)
230
+ y = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1)
231
+
232
+ N = len(X)
233
+ n_val = max(1, int(N * val_frac))
234
+ idx = torch.randperm(N)
235
+ val_idx = idx[:n_val]
236
+ train_idx = idx[n_val:]
237
+
238
+ X_train, y_train = X[train_idx], y[train_idx]
239
+ X_val, y_val = X[val_idx], y[val_idx]
240
+
241
+ clf = QualityClassifier(d_model).to(device)
242
+ optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
243
+
244
+ print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
245
+ print(f"Train: {len(X_train)} Val: {len(X_val)}")
246
+
247
+ best_val_loss = float('inf')
248
+ best_state = None
249
+
250
+ for epoch in range(epochs):
251
+ clf.train()
252
+ perm = torch.randperm(len(X_train))
253
+ train_loss = 0.0
254
+ n_batches = 0
255
+
256
+ for start in range(0, len(X_train), batch_size):
257
+ batch_idx = perm[start:start + batch_size]
258
+ xb, yb = X_train[batch_idx], y_train[batch_idx]
259
+ pred = clf(xb)
260
+ loss = F.mse_loss(pred, yb)
261
+ optimizer.zero_grad()
262
+ loss.backward()
263
+ optimizer.step()
264
+ train_loss += loss.item()
265
+ n_batches += 1
266
+
267
+ clf.eval()
268
+ with torch.no_grad():
269
+ val_pred = clf(X_val)
270
+ val_loss = F.mse_loss(val_pred, y_val).item()
271
+
272
+ if epoch % 5 == 0 or epoch == epochs - 1:
273
+ print(f" Ep {epoch+1:3d} train={train_loss/n_batches:.4f} val={val_loss:.4f}")
274
+
275
+ if val_loss < best_val_loss:
276
+ best_val_loss = val_loss
277
+ best_state = {k: v.clone() for k, v in clf.state_dict().items()}
278
+
279
+ if best_state:
280
+ clf.load_state_dict(best_state)
281
+ print(f" Best val loss: {best_val_loss:.4f}")
282
+
283
+ if save_path:
284
+ os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
285
+ torch.save(clf.state_dict(), save_path)
286
+ print(f" Classifier saved: {save_path}")
287
+
288
+ return clf
289
+
290
+
291
+ # ── Guided inference ──────────────────────────────────────────────────
292
+
293
+ def generate_guided(
294
+ model,
295
+ src: torch.Tensor,
296
+ classifier: QualityClassifier,
297
+ guidance_scale: float = 1.0,
298
+ temperature: float = 0.8,
299
+ top_k: int = 40,
300
+ ) -> torch.Tensor:
301
+ """
302
+ Classifier-guided generation.
303
+
304
+ At each diffusion step:
305
+ 1. Run forward_cached() → logits, hidden states
306
+ 2. Compute classifier gradient: ∂(quality_score) / ∂(hidden)
307
+ 3. Project gradient back to logit space (approximate)
308
+ 4. guided_logits = logits + λ * gradient_signal
309
+ 5. Sample from guided_logits
310
+
311
+ guidance_scale λ:
312
+ 0.0 → no guidance (standard generation)
313
+ 0.5 → weak guidance
314
+ 1.0 → moderate guidance (recommended starting point)
315
+ 2.0 → strong guidance (may reduce diversity)
316
+ 3.0 → very strong (may collapse to repetitive output)
317
+
318
+ Args:
319
+ model : SanskritModel (frozen)
320
+ src : [1, src_len] IAST token ids
321
+ classifier : trained QualityClassifier
322
+ guidance_scale : λ — guidance strength
323
+
324
+ Returns:
325
+ x0_est : [1, tgt_len] generated token ids
326
+ """
327
+ inner = model.model
328
+ T = inner.scheduler.num_timesteps
329
+ device = next(inner.parameters()).device
330
+ clf_device = next(classifier.parameters()).device
331
+
332
+ if src.dim() == 1:
333
+ src = src.unsqueeze(0)
334
+ src = src.to(device)
335
+
336
+ B = src.shape[0]
337
+ tgt_len = inner.max_seq_len
338
+ mask_id = inner.mask_token_id
339
+
340
+ memory, src_pad_mask = inner.encode_source(src)
341
+ x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
342
+ hint = None
343
+
344
+ inner.eval()
345
+ classifier.eval()
346
+
347
+ for t_val in range(T - 1, -1, -1):
348
+ t = torch.full((B,), t_val, dtype=torch.long, device=device)
349
+ is_last = (t_val == 0)
350
+
351
+ if guidance_scale > 0.0:
352
+ # Need gradients for classifier guidance
353
+ with torch.enable_grad():
354
+ # Run forward_cached and get hidden states
355
+ PAD = 1
356
+ if t_val > 0:
357
+ _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
358
+ else:
359
+ x_t_ids = x0_est
360
+
361
+ x = inner.tgt_embed(x_t_ids)
362
+ t_norm = t.float() / T
363
+ t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
364
+ x = x + t_emb.unsqueeze(1)
365
+
366
+ if hint is not None:
367
+ hint_emb = inner.tgt_embed(hint)
368
+ gate = inner.hint_gate(x)
369
+ x = x + gate * hint_emb
370
+
371
+ for block in inner.decoder_blocks:
372
+ x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
373
+
374
+ # hidden: [B, tgt_len, d_model] — detach from graph for clf
375
+ hidden = x.detach().requires_grad_(True).to(clf_device)
376
+
377
+ # Classifier quality score
378
+ quality = classifier(hidden) # [B, 1]
379
+ quality.sum().backward()
380
+
381
+ # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model]
382
+ grad = hidden.grad.to(device) # [B, tgt_len, d_model]
383
+
384
+ # Project gradient to logit space via output head weight
385
+ # logit_grad ≈ grad @ head.weight [B, tgt_len, tgt_vocab]
386
+ logit_grad = grad @ inner.head.weight.T
387
+
388
+ # Compute standard logits (no gradient needed)
389
+ with torch.no_grad():
390
+ logits = inner.head(x)
391
+
392
+ # Apply guidance
393
+ logits = logits + guidance_scale * logit_grad
394
+
395
+ else:
396
+ with torch.no_grad():
397
+ logits, _ = inner.forward_cached(
398
+ memory, src_pad_mask, x0_est, t,
399
+ x0_hint=hint, inference_mode=True,
400
+ )
401
+
402
+ with torch.no_grad():
403
+ logits = logits / max(temperature, 1e-8)
404
+ if top_k > 0:
405
+ V = logits.shape[-1]
406
+ if top_k < V:
407
+ vals, _ = torch.topk(logits, top_k, dim=-1)
408
+ logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
409
+
410
+ probs = F.softmax(logits, dim=-1)
411
+ x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
412
+ hint = x0_est
413
+
414
+ return x0_est
415
+
416
+
417
+ def _sample_no_grad(probs):
418
+ B, L, V = probs.shape
419
+ flat = probs.view(B * L, V).clamp(min=1e-9)
420
+ flat = flat / flat.sum(dim=-1, keepdim=True)
421
+ return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
422
+
423
+
424
+ # ── Guidance scale sweep ──────────────────────────────────────────────
425
+
426
+ def sweep_guidance_scales(
427
+ model,
428
+ classifier: QualityClassifier,
429
+ src_list: List[torch.Tensor],
430
+ ref_list: List[str],
431
+ tgt_tokenizer,
432
+ scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
433
+ n_samples: int = 50,
434
+ device: torch.device = None,
435
+ output_dir: str = "analysis/outputs",
436
+ ) -> Dict:
437
+ """
438
+ Evaluate CER at each guidance scale.
439
+ Produces quality-diversity tradeoff plot.
440
+ """
441
+ def cer(pred, ref):
442
+ if not ref:
443
+ return 1.0
444
+ def ed(s1, s2):
445
+ m, n = len(s1), len(s2)
446
+ dp = list(range(n + 1))
447
+ for i in range(1, m + 1):
448
+ prev, dp[0] = dp[0], i
449
+ for j in range(1, n + 1):
450
+ temp = dp[j]
451
+ dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
452
+ prev = temp
453
+ return dp[n]
454
+ return ed(pred, ref) / max(len(ref), 1)
455
+
456
+ device = device or next(model.parameters()).device
457
+ results = {}
458
+ n = min(n_samples, len(src_list))
459
+
460
+ print("\nGuidance scale sweep...")
461
+ for scale in scales:
462
+ cer_list = []
463
+ output_set = []
464
+ for src, ref in zip(src_list[:n], ref_list[:n]):
465
+ if src.dim() == 1:
466
+ src = src.unsqueeze(0)
467
+ out = generate_guided(model, src.to(device), classifier,
468
+ guidance_scale=scale)
469
+ ids = [x for x in out[0].tolist() if x > 4]
470
+ pred = tgt_tokenizer.decode(ids).strip()
471
+ cer_list.append(cer(pred, ref))
472
+ output_set.append(pred)
473
+
474
+ mean_cer = float(np.mean(cer_list))
475
+
476
+ # Self-diversity: unique outputs / total (proxy for diversity)
477
+ unique_frac = len(set(output_set)) / max(len(output_set), 1)
478
+
479
+ results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
480
+ print(f" λ={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}")
481
+
482
+ # Plot
483
+ os.makedirs(output_dir, exist_ok=True)
484
+ try:
485
+ import matplotlib.pyplot as plt
486
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
487
+
488
+ sc_list = sorted(results.keys())
489
+ cers = [results[s]["mean_cer"] for s in sc_list]
490
+ diversities = [results[s]["diversity"] for s in sc_list]
491
+
492
+ ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
493
+ ax1.set_xlabel("Guidance scale λ", fontsize=10)
494
+ ax1.set_ylabel("CER (↓ better)", fontsize=10)
495
+ ax1.set_title("Quality vs guidance scale", fontsize=10)
496
+
497
+ ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
498
+ ax2.set_xlabel("Guidance scale λ", fontsize=10)
499
+ ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
500
+ ax2.set_title("Diversity vs guidance scale", fontsize=10)
501
+
502
+ plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
503
+ plt.tight_layout()
504
+ path = os.path.join(output_dir, "guidance_scale_sweep.png")
505
+ plt.savefig(path, dpi=150, bbox_inches='tight')
506
+ plt.close()
507
+ print(f" Saved: {path}")
508
+ except ImportError:
509
+ pass
510
+
511
+ with open(os.path.join(output_dir, "guidance_results.json"), "w") as f:
512
+ json.dump({str(k): v for k, v in results.items()}, f, indent=2)
513
+
514
+ return results
reverse_process.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ reverse_process.py — Fixed
3
+ ===========================
4
+ Two bugs fixed from the original:
5
+
6
+ BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to model.
7
+ The model does q_sample(tgt, t) internally — so x_t got double-noised.
8
+ Fix: pass x0_estimate (current clean guess) as tgt. Model noises it correctly.
9
+
10
+ BUG 2: apply_diversity_penalty used logits.var(dim=-1) — this adds the
11
+ variance of each position's own distribution back to itself, which is
12
+ mathematically meaningless and just injects noise.
13
+ Fix: penalize tokens that are uniformly high-probability across ALL positions
14
+ (global common tokens). This genuinely promotes diversity.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+
21
class ReverseDiffusion:
    """Reverse (denoising) process for a discrete D3PM-style diffusion model.

    Provides:
      * ``p_sample_step`` — one reverse step returning top-k candidates.
      * ``generate_beam`` — beam-search style decoding (experimental).
      * ``generate`` — D3PM iterative refinement: the clean-estimate
        ``x0_est`` is passed as ``tgt`` each step and the model noises it
        internally via q_sample, then predicts a cleaner x0.

    Improvement over the previous revision: ``generate()`` used to run
    ``inspect.signature(model.forward)`` inside the per-timestep loop;
    that is loop-invariant and is now computed once before the loop.
    """

    def __init__(self, scheduler):
        # Noise scheduler; must expose `num_timesteps`.
        self.scheduler = scheduler

    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width=3,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3
    ):
        """
        Single reverse step with temperature + penalties.

        Returns a list of ``beam_width`` tuples ``(tokens, score)`` where
        ``tokens`` is the k-th most likely token at every position and
        ``score`` is the summed log-probability of that choice.
        """
        with torch.no_grad():
            # ---- Shape safety: promote 1-D inputs to batched form ----
            if x_t.dim() == 1:
                x_t = x_t.unsqueeze(0)
            if condition.dim() == 1:
                condition = condition.unsqueeze(0)
            if t.dim() == 0:
                t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # ---- Model forward ----
            logits, _ = model(condition, x_t, t)

            # ---- Temperature scaling ----
            logits = logits / temperature

            # ---- Repetition penalty ----
            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(
                    logits, x_t, repetition_penalty
                )

            # ---- Diversity penalty ----
            if diversity_penalty > 0:
                logits = apply_diversity_penalty(
                    logits, diversity_penalty
                )

            probs = F.softmax(logits, dim=-1)

            # ---- Top-k beam expansion ----
            topk_probs, topk_ids = torch.topk(
                probs, beam_width, dim=-1
            )

            candidates = []
            for k in range(beam_width):
                next_tokens = topk_ids[:, :, k]
                score = torch.log(
                    topk_probs[:, :, k] + 1e-9
                ).sum()
                candidates.append((next_tokens, score))

            return candidates

    def generate_beam(
        self,
        model,
        condition,
        beam_width=3,
        num_steps=None,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3
    ):
        """
        Beam-search reverse diffusion with temperature.

        NOTE(review): passing the noisy x_t as `tgt` was the historical
        double-noising bug; here the beams carry the current token estimate
        and the model noises it internally.
        """
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # Better initialization: start every position from [MASK].
        x_init = torch.full(
            (B, L),
            fill_value=model.mask_token_id,
            dtype=torch.long,
            device=device
        )

        beams = [(x_init, 0.0)]

        for step in reversed(range(num_steps)):
            new_beams = []
            for x_t, score in beams:
                t_tensor = torch.full(
                    (B,),
                    step,
                    dtype=torch.long,
                    device=device
                )
                candidates = self.p_sample_step(
                    model,
                    x_t,
                    t_tensor,
                    condition,
                    beam_width,
                    temperature,
                    repetition_penalty,
                    diversity_penalty
                )
                for tokens, new_score in candidates:
                    new_beams.append(
                        (tokens, score + new_score)
                    )

            # ---- Keep top beams ----
            new_beams = sorted(
                new_beams,
                key=lambda x: x[1],
                reverse=True
            )
            beams = new_beams[:beam_width]

        best_tokens, best_score = beams[0]
        return best_tokens

    def generate(
        self,
        model,
        condition,
        num_steps=None,
        temperature=0.8,
        top_k=50,
        repetition_penalty=1.2,
        diversity_penalty=0.0,
    ):
        """
        Correct D3PM iterative refinement.

        x0_est starts as all [MASK].
        Each step: forward(src=condition, tgt=x0_est, t)
            → model applies q_sample(x0_est, t) internally
            → predicts cleaner x0
            → x0_est updated

        diversity_penalty: reduces probability of tokens that are
        globally dominant across all sequence positions.
        """
        import inspect

        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        T = self.scheduler.num_timesteps
        step_size = max(1, T // num_steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)

        mask_id = model.mask_token_id
        # Start: know nothing → all MASK is our initial clean estimate.
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint = None

        # Loop-invariant: whether the model accepts a self-conditioning
        # hint. Previously re-inspected every timestep inside the loop.
        supports_hint = 'x0_hint' in inspect.signature(model.forward).parameters

        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)

                # KEY: pass x0_est as tgt — model noises it internally.
                if supports_hint:
                    outputs = model(condition, x0_est, t, x0_hint=hint)
                else:
                    outputs = model(condition, x0_est, t)

                logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # Repetition penalty: down-weight tokens already in sequence.
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)

                # Diversity penalty: reduce globally dominant tokens.
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                # Temperature + top-k.
                logits = logits / max(temperature, 1e-5)
                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                # Stochastic at every step except the last (argmax at t=0).
                if is_last:
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)

                hint = x0_est

        return x0_est
250
+
251
+
252
+ # ── Penalty functions ─────────────────────────────────────────────────
253
+
254
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    """
    Down-weight tokens that already appear in the current sequence.
    Prevents degenerate repetition loops.

    penalty=1.0 → no effect
    penalty=1.2 → mild suppression of repeated tokens
    penalty=2.0 → strong suppression

    BUG FIX: the old code always divided the logit by `penalty`. For a
    NEGATIVE logit, division moves it toward zero — i.e. it *raised* the
    repeated token's probability, the opposite of the intent. The standard
    CTRL-style penalty divides positive logits and multiplies negative
    ones, so both directions suppress the token.
    """
    B, _, _ = logits.shape
    for b in range(B):
        for token_id in set(prev_tokens[b].tolist()):
            if token_id > 4:  # don't penalize special tokens
                vals = logits[b, :, token_id]
                logits[b, :, token_id] = torch.where(
                    vals > 0, vals / penalty, vals * penalty
                )
    return logits
268
+
269
+
270
def apply_diversity_penalty(logits, penalty=0.5):
    """
    Suppress tokens that are dominant across every sequence position.

    Subtracts ``penalty`` times the per-batch mean logit (taken over the
    sequence dimension) from each position. Tokens that score uniformly
    high everywhere are pushed down, so rarer tokens become competitive
    and output diversity increases.

    penalty=0.0 → no diversity enforcement
    penalty=0.5 → moderate diversity
    penalty=1.0 → strong diversity (may hurt coherence)
    """
    # Per-vocabulary-entry mean over the sequence axis: [B, 1, V].
    seq_mean = logits.mean(dim=1, keepdim=True)
    # Scaled subtraction suppresses globally common tokens.
    return logits - penalty * seq_mean
287
+
288
+
289
def top_k_filter(logits, k):
    """Keep only the k highest logits per position; mask the rest to -inf."""
    B, L, V = logits.shape
    if k >= V:
        # Nothing to filter — every token survives.
        return logits
    # k-th largest value per position acts as the cutoff threshold.
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_value, float('-inf'))
296
+
297
+
298
def batch_multinomial(probs):
    """Draw one token per position from a (B, L, V) probability tensor."""
    B, L, V = probs.shape
    # Flatten positions so a single multinomial call samples all of them;
    # the epsilon guards against all-zero rows, then rows are renormalized.
    flattened = probs.reshape(B * L, V) + 1e-9
    flattened = flattened / flattened.sum(dim=-1, keepdim=True)
    samples = torch.multinomial(flattened, 1)
    return samples.squeeze(-1).reshape(B, L)
reverse_process1.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
class ReverseDiffusion:
    """
    Stable reverse diffusion with:
    - Beam search
    - Self conditioning
    - Temperature sampling
    - Repetition penalty
    - Diversity penalty

    Fixes vs. the previous revision (matching the corrected siblings
    reverse_process.py / reverse_process2.py):
    * apply_repetition_penalty is now sign-aware (dividing a negative
      logit used to *raise* the repeated token's probability) and skips
      special tokens (id <= 4), as the sibling implementations do.
    * apply_diversity_penalty no longer adds ``logits.var(dim=-1)`` back
      to the logits — that injected each position's own variance as noise.
      It now suppresses globally dominant tokens instead. No behavior
      change at the default ``diversity_penalty = 0.0``.
    """

    def __init__(self, scheduler):
        # Noise scheduler; must expose `num_timesteps`.
        self.scheduler = scheduler

        # Sampling hyper-parameters; callers tweak these attributes.
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.length_penalty = 1.0

    # ------------------------------------------------
    # penalties
    # ------------------------------------------------

    def apply_repetition_penalty(self, logits, tokens):
        """Down-weight tokens already present in `tokens` (per batch row)."""
        B = logits.shape[0]
        for b in range(B):
            for token_id in set(tokens[b].tolist()):
                if token_id <= 4:  # never penalize PAD/BOS/EOS/UNK/MASK
                    continue
                vals = logits[b, :, token_id]
                # Sign-aware: both branches reduce the token's probability.
                logits[b, :, token_id] = torch.where(
                    vals > 0,
                    vals / self.repetition_penalty,
                    vals * self.repetition_penalty,
                )
        return logits

    def apply_diversity_penalty(self, logits):
        """Suppress tokens that are uniformly high across all positions."""
        if self.diversity_penalty == 0:
            return logits
        # FIX: the old `logits + penalty * logits.var(dim=-1)` added each
        # position's own distribution variance back to itself — pure noise.
        global_mean = logits.mean(dim=1, keepdim=True)  # (B, 1, V)
        return logits - self.diversity_penalty * global_mean

    # ------------------------------------------------
    # single reverse step
    # ------------------------------------------------

    def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
        """One reverse step: forward pass, penalties, softmax, top-k beams."""
        with torch.no_grad():
            logits, hidden = model(condition, x_t, t, self_cond)

            logits = logits / self.temperature
            logits = self.apply_repetition_penalty(logits, x_t)
            logits = self.apply_diversity_penalty(logits)

            probs = F.softmax(logits, dim=-1)
            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)

            candidates = []
            for k in range(beam_width):
                tokens = topk_ids[:, :, k]
                score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
                candidates.append((tokens, score))
            return candidates

    # ------------------------------------------------
    # beam reverse diffusion
    # ------------------------------------------------

    def generate_beam(self, model, condition, beam_width=3, num_steps=None):
        """Beam-search reverse diffusion with self-conditioning."""
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # ------------------------------------------------
        # BETTER LATENT INITIALIZATION
        # Start from the condition with ~50% of positions masked.
        # ------------------------------------------------
        x_init = condition.clone()
        mask = torch.rand_like(x_init.float()) < 0.5
        x_init[mask] = model.mask_token_id

        beams = [(x_init, 0.0)]
        self_cond = None

        for step in reversed(range(num_steps)):
            new_beams = []
            for x_t, score in beams:
                t_tensor = torch.full(
                    (B,),
                    step,
                    dtype=torch.long,
                    device=device
                )
                candidates = self.p_sample_step(
                    model,
                    x_t,
                    t_tensor,
                    condition,
                    self_cond,
                    beam_width
                )
                for tokens, new_score in candidates:
                    # Length-normalized cumulative log-probability.
                    length_norm = tokens.shape[1] ** self.length_penalty
                    final_score = (score + new_score) / length_norm
                    new_beams.append((tokens, final_score))

            new_beams = sorted(
                new_beams,
                key=lambda x: x[1],
                reverse=True
            )
            beams = new_beams[:beam_width]

            # Self-conditioning: feed the current best hypothesis forward.
            self_cond = beams[0][0]

        best_tokens, best_score = beams[0]
        return best_tokens
reverse_process2.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ reverse_process.py — Final Correct Version
3
+ =============================================
4
+
5
+ KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
6
+ in inference.py, which is what produced BERTScore 0.75 at validation.
7
+
8
+ CRITICAL BUG IN PREVIOUS VERSION:
9
+ We passed inference_mode=True to the model, but the model was NEVER
10
+ called with inference_mode=True during training or validation.
11
+ run_inference() (the validated path) does:
12
+ model(input_ids, x0_est, t, x0_hint=hint)
13
+ → inference_mode defaults to False.
14
+
15
+ With inference_mode=True the model does two things differently:
16
+ 1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD)
17
+ 2. Skips q_sample at t=0 (training always called q_sample)
18
+ The model was never trained to handle these conditions → garbage output.
19
+
20
+ Fix: do NOT pass inference_mode. Let it default to False, exactly
21
+ as run_inference() did.
22
+
23
+ BUGS FIXED (vs original reverse_process.py)
24
+ --------------------------------------------
25
+ BUG 1 generate_beam() used for D3PM → all-Ṛ repetition.
26
+ Use generate() (iterative refinement) from app1.py instead.
27
+ BUG 2 apply_diversity_penalty used logits.var() → noise injection.
28
+ Fixed to logits - penalty * logits.mean(dim=1) — global suppression.
29
+ BUG 3 x0_hint (self-conditioning) never passed to model.
30
+ Fixed: generate() passes x0_hint=hint every step.
31
+ BUG 4 params not forwarded from generate_beam() to p_sample_step().
32
+ Fixed in generate_beam() (kept for reference, not for production use).
33
+ """
34
+
35
+ import torch
36
+ import torch.nn.functional as F
37
+
38
+
39
class ReverseDiffusion:
    """Reverse-process sampler.

    ``generate()`` mirrors run_inference() in inference.py (the validated
    path — see the module docstring); ``generate_beam()`` is retained only
    for experimentation and must not be used for production D3PM decoding.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler

        # Attribute-style defaults for backward compat with any code that
        # sets e.g. `reverse_diffusion.temperature = 0.9`. Explicit kwargs
        # passed to generate()/generate_beam() always win.
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.top_k = 50

    # ------------------------------------------------------------------ #
    # generate — CORRECT D3PM iterative refinement                       #
    # Exact equivalent of run_inference() in inference.py                #
    # ------------------------------------------------------------------ #
    def generate(
        self,
        model,
        condition,
        num_steps=None,
        temperature=None,
        top_k=None,
        repetition_penalty=None,
        diversity_penalty=None,
    ):
        """D3PM iterative refinement, identical to run_inference().

        x0_est starts as all [MASK]; each step the model is called as
        model(src, x0_est, t, x0_hint=hint) — crucially WITHOUT
        inference_mode, which must stay at its False default to match how
        the model was trained and validated. Every step except the last
        samples stochastically; the final (t=0) step takes the argmax.
        """
        # Resolve sampling knobs: explicit kwarg > instance attribute.
        if temperature is None:
            temperature = self.temperature
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if diversity_penalty is None:
            diversity_penalty = self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        batch, seq_len = condition.shape

        # Evenly strided timestep schedule, always ending at t=0.
        total_t = self.scheduler.num_timesteps
        stride = max(1, total_t // num_steps)
        schedule = list(range(total_t - 1, -1, -stride))
        if schedule[-1] != 0:
            schedule.append(0)

        x0_est = torch.full((batch, seq_len), model.mask_token_id,
                            dtype=torch.long, device=device)
        hint = None

        model.eval()
        with torch.no_grad():
            last_idx = len(schedule) - 1
            for idx, t_val in enumerate(schedule):
                t = torch.full((batch,), t_val, dtype=torch.long, device=device)

                # ── CRITICAL: do NOT pass inference_mode ──────────────
                # It defaults to False inside the model, exactly matching
                # run_inference(); True changes tgt_pad_mask / q_sample
                # behaviour the model was never trained for.
                logits, _ = model(condition, x0_est, t, x0_hint=hint)

                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(
                        logits, x0_est, repetition_penalty
                    )
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                logits = logits / max(temperature, 1e-5)
                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                # Stochastic everywhere except the final (t=0) step.
                if idx == last_idx:
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)

                hint = x0_est

        return x0_est  # (B, L)

    # ------------------------------------------------------------------ #
    # p_sample_step — used by generate_beam (not for production)         #
    # ------------------------------------------------------------------ #
    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width=3,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3,
        x0_hint=None,
    ):
        """One reverse step returning top-k (tokens, log-prob sum) pairs."""
        with torch.no_grad():
            # Promote 1-D / scalar inputs to batched form.
            if x_t.dim() == 1:
                x_t = x_t.unsqueeze(0)
            if condition.dim() == 1:
                condition = condition.unsqueeze(0)
            if t.dim() == 0:
                t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # No inference_mode — matches training convention.
            logits, _ = model(condition, x_t, t, x0_hint=x0_hint)

            logits = logits / max(temperature, 1e-5)

            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
            if diversity_penalty > 0.0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)
            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)

            return [
                (topk_ids[:, :, k], torch.log(topk_probs[:, :, k] + 1e-9).sum())
                for k in range(beam_width)
            ]

    # ------------------------------------------------------------------ #
    # generate_beam — kept for reference; NOT the correct D3PM method    #
    # ------------------------------------------------------------------ #
    def generate_beam(
        self,
        model,
        condition,
        beam_width=3,
        num_steps=None,
        temperature=None,
        repetition_penalty=None,
        diversity_penalty=None,
    ):
        """
        WARNING: do NOT call this from app1.py for D3PM generation.
        generate_beam() forces every position to the same top-k token
        → all-Ṛ / all-rud repetition. Use generate() instead.
        Kept only for experimental reference.
        """
        if temperature is None:
            temperature = self.temperature
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if diversity_penalty is None:
            diversity_penalty = self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        batch, seq_len = condition.shape

        # All-[MASK] start.
        x_init = torch.full((batch, seq_len), fill_value=model.mask_token_id,
                            dtype=torch.long, device=device)
        beams = [(x_init, 0.0)]
        best_hint = None

        for step in reversed(range(num_steps)):
            t_tensor = torch.full((batch,), step, dtype=torch.long, device=device)
            expanded = []
            for x_t, score in beams:
                candidates = self.p_sample_step(
                    model, x_t, t_tensor, condition,
                    beam_width=beam_width,
                    temperature=temperature,
                    repetition_penalty=repetition_penalty,
                    diversity_penalty=diversity_penalty,
                    x0_hint=best_hint,
                )
                for tokens, new_score in candidates:
                    expanded.append((tokens, score + new_score.item()))

            expanded.sort(key=lambda pair: pair[1], reverse=True)
            beams = expanded[:beam_width]
            best_hint = beams[0][0]

        return beams[0][0]  # (B, L)
242
+
243
+
244
+ # ── Penalty helpers ────────────────────────────────────────────────────────
245
+
246
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    """Down-weight tokens already present in the sequence.

    BUG FIX: the old code always divided by `penalty`; dividing a NEGATIVE
    logit moves it toward zero and *raises* the repeated token's
    probability. CTRL-style fix: divide positive logits, multiply negative
    ones, so both directions suppress the token. Special tokens (id <= 4)
    are never penalized.
    """
    for b in range(logits.shape[0]):
        for token_id in set(prev_tokens[b].tolist()):
            if token_id > 4:
                vals = logits[b, :, token_id]
                logits[b, :, token_id] = torch.where(
                    vals > 0, vals / penalty, vals * penalty
                )
    return logits
253
+
254
+
255
def apply_diversity_penalty(logits, penalty=0.3):
    """
    Diversity penalty: suppress globally dominant tokens by subtracting a
    scaled mean of the logits taken over the sequence dimension:
        logits -= penalty * mean(logits, dim=1)
    """
    seq_mean = logits.mean(dim=1, keepdim=True)  # [B, 1, V]
    return logits - penalty * seq_mean
262
+
263
+
264
def top_k_filter(logits, k):
    """Mask everything below the k-th largest logit (per position) to -inf."""
    B, L, V = logits.shape
    if k >= V:
        return logits  # nothing to filter
    cutoff = torch.topk(logits, k, dim=-1).values[..., -1].unsqueeze(-1)
    return logits.masked_fill(logits < cutoff, float('-inf'))
269
+
270
+
271
def batch_multinomial(probs):
    """Sample one token per position from a (B, L, V) probability tensor."""
    B, L, V = probs.shape
    # One multinomial call over the flattened positions; the epsilon keeps
    # degenerate all-zero rows valid, and rows are renormalized after it.
    rows = probs.reshape(B * L, V) + 1e-9
    rows = rows / rows.sum(dim=-1, keepdim=True)
    picks = torch.multinomial(rows, 1)
    return picks.squeeze(-1).reshape(B, L)
run_analysis.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/run_analysis.py
3
+ =========================
4
+ Entry point for all 5 tasks.
5
+
6
+ Tasks:
7
+ Task 1 — KV Cache benchmark (no retraining)
8
+ Task 2 — Attention viz + drift (no retraining)
9
+ Task 3 — Concept vectors + PCA steer (no retraining)
10
+ Task 4 — Step ablation (REQUIRES retraining for each T)
11
+ Task 5 — Classifier-free guidance (trains small 10k-param classifier)
12
+
13
+ Usage:
14
+ python analysis/run_analysis.py --task 1
15
+ python analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ"
16
+ python analysis/run_analysis.py --task 3
17
+ python analysis/run_analysis.py --task 4 --phase generate_configs
18
+ python analysis/run_analysis.py --task 4 --phase analyze
19
+ python analysis/run_analysis.py --task 5
20
+ python analysis/run_analysis.py --task all --input "satyameva jayate"
21
+
22
+ Output files: analysis/outputs/
23
+ """
24
+
25
+ import torch
26
+ import os, sys, argparse, json
27
+ import numpy as np
28
+
29
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
30
+ from config import CONFIG
31
+ from inference import load_model
32
+ from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
33
+
34
+ OUTPUT_DIR = "analysis/outputs"
35
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
36
+
37
+
38
+ # ── Shared loader ─────────────────────────────────────────────────────
39
+
40
def load_everything(cfg, device):
    """Load the trained checkpoint plus source/target tokenizers.

    Returns (model, src_tokenizer, tgt_tokenizer, cfg) with the model in
    eval mode. Raises FileNotFoundError when no checkpoint exists at the
    expected results7/<model>_neg_<flag>/best_model.pt path.
    """
    arch = cfg['model_type']
    neg = cfg['data']['include_negative_examples']
    ckpt = f"results7/{arch}_neg_{neg}/best_model.pt"
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt}. Train first.")

    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    # Both tokenizers share the model's maximum sequence length.
    max_len = cfg['model']['max_seq_len']
    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg['model'].get('src_vocab_size', 500),
        max_len=max_len)
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg['model'].get('tgt_vocab_size', 500),
        max_len=max_len)
    return model, src_tok, tgt_tok, cfg
55
+
56
+
57
def load_val_data(cfg, src_tok, tgt_tok, n=500):
    """Load the validation split as (src_tensors, ref_strings, input_strings).

    Rebuilds the same 80/20 split used in training (random_state=42) and
    returns at most `n` validation examples. Each src tensor has shape
    (1, L) — presumably ready for single-example batching; confirm against
    the callers.

    FIX: removed the unused local import
    `from torch.utils.data import Subset`.
    """
    from Data.data import OptimizedSanskritDataset
    from sklearn.model_selection import train_test_split

    dataset = OptimizedSanskritDataset(
        'train', max_len=cfg['model']['max_seq_len'],
        cfg=cfg, src_tokenizer=src_tok, tgt_tokenizer=tgt_tok)
    total = min(cfg['data']['dataset_size'], len(dataset))
    # Same deterministic split as training so "validation" means the same set.
    _, val_idx = train_test_split(list(range(total)), train_size=0.8, random_state=42)
    val_idx = val_idx[:n]

    src_list, ref_list, inp_list = [], [], []
    for i in val_idx:
        item = dataset[i]
        src_list.append(item['input_ids'].unsqueeze(0))  # (1, L)
        ref_list.append(item['target_text'])
        inp_list.append(item['input_text'])
    return src_list, ref_list, inp_list
77
+
78
+
79
+ # ── Task 1 ────────────────────────────────────────────────────────────
80
+
81
def run_task1(model, src_tok, device):
    """Task 1: benchmark KV-cache decoding and write a summary table."""
    banner = "=" * 65
    print("\n" + banner)
    print(" TASK 1 — KV Cache Benchmark")
    print(banner)
    # KV-cache benchmarking only applies to the cross-attention variant.
    if not hasattr(model.model, 'generate_cached'):
        print(" SKIP: not D3PMCrossAttention.")
        return
    from analysis.kv_cache_benchmark import run_benchmark, print_summary
    results = run_benchmark(model, src_tok, device, src_lens=[16, 32, 64])
    print_summary(results)

    out_path = os.path.join(OUTPUT_DIR, "task1_kv_cache.txt")
    # Assemble the full report, then write it in one go.
    lines = ["TASK 1 — KV CACHE BENCHMARK\n" + "=" * 40 + "\n\n"]
    lines.append(f"{'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
                 f"{'speedup':>8} {'encoder%':>9}\n")
    for src_len, row in results.items():
        lines.append(f"{src_len:>8} {row['standard_s']:>12.3f} {row['cached_s']:>10.3f} "
                     f"{row['speedup']:>7.2f}x {row['encoder_pct']:>8.1f}%\n")
    with open(out_path, "w") as f:
        f.writelines(lines)
    print(f" Saved: {out_path}")
+
101
+
102
+ # ── Task 2 ────────────────────────────────────────────────────────────
103
+
104
def run_task2(model, src_tok, tgt_tok, device, input_text):
    """Task 2: attention heatmaps/evolution plots plus semantic-drift curves."""
    banner = "=" * 65
    print("\n" + banner)
    print(" TASK 2 — Attention Visualization + Semantic Drift")
    print(banner)
    print(f" Input: {input_text}")
    # Requires the cross-attention model variant.
    if not hasattr(model.model, 'encode_source'):
        print(" SKIP: not D3PMCrossAttention.")
        return

    src_ids = src_tok.encode(input_text)
    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
    src_chars = list(input_text.strip())

    from analysis.attention_viz import (AttentionCapture, plot_attn_heatmap,
                                        plot_attn_evolution, plot_all_layers)
    from analysis.semantic_drift import (capture_intermediate_outputs,
                                         compute_drift, compute_token_stability,
                                         plot_drift_curve)

    # Attention capture
    print(" Capturing attention weights...")
    attn_cap = AttentionCapture(model)
    step_weights = attn_cap.capture(src_tensor, capture_every=10)

    with torch.no_grad():
        out_ids = model.generate_cached(src_tensor)
    tgt_ids = [x for x in out_ids[0].tolist() if x > 4]
    tgt_text = tgt_tok.decode(tgt_ids).strip()
    tgt_chars = list(tgt_text)
    print(f" Output: {tgt_text}")

    # Plot the noisiest captured step and the final step side by side.
    first_t = max(step_weights.keys())
    plot_attn_heatmap(step_weights, t_val=first_t, layer=0,
                      src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
                      save_path=os.path.join(OUTPUT_DIR, f"task2_attn_t{first_t}.png"),
                      title=f"Attention t={first_t} (noisy) Layer 0")
    plot_attn_heatmap(step_weights, t_val=0, layer=0,
                      src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
                      save_path=os.path.join(OUTPUT_DIR, "task2_attn_t0.png"),
                      title="Attention t=0 (final) Layer 0")
    plot_all_layers(step_weights, t_val=0,
                    src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
                    save_path=os.path.join(OUTPUT_DIR, "task2_all_layers_t0.png"))
    if src_chars and tgt_chars:
        plot_attn_evolution(step_weights, src_token_idx=0, tgt_token_idx=0,
                            layer=0, src_token_str=src_chars[0],
                            tgt_token_str=tgt_chars[0],
                            save_path=os.path.join(OUTPUT_DIR, "task2_attn_evolution.png"))

    # Semantic drift
    print(" Computing semantic drift...")
    step_outputs, final_out = capture_intermediate_outputs(
        model, src_tensor, tgt_tok, capture_every=5)
    drift = compute_drift(step_outputs, final_out)
    stability = compute_token_stability(step_outputs, final_out, tgt_tok)
    plot_drift_curve(drift, src_text=input_text,
                     save_path=os.path.join(OUTPUT_DIR, "task2_semantic_drift.png"))

    print(f" Lock-in timestep: t={drift['lock_in_t']}")
    print(f" Mean position lock-in: t={stability['mean_lock_t']:.1f} ± {stability['std_lock_t']:.1f}")

    report_path = os.path.join(OUTPUT_DIR, "task2_report.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("TASK 2 — ATTENTION + DRIFT REPORT\n" + "=" * 50 + "\n\n")
        f.write(f"Input : {input_text}\nOutput : {final_out}\n\n")
        f.write(f"Lock-in t : {drift['lock_in_t']}\n")
        f.write(f"Mean pos lock-in : {stability['mean_lock_t']:.1f} ± {stability['std_lock_t']:.1f}\n\n")
        f.write("Step → Output → CER-to-final\n" + "-" * 60 + "\n")
        for tv, cer in zip(drift["t_vals"], drift["cer_to_final"]):
            f.write(f" t={tv:4d} | {step_outputs.get(tv,'')[:40]:40s} | {cer:.4f}\n")
    print(f" Report: {report_path}")
174
+
175
+
176
+ # ── Task 3 ────────────────────────────────────────────────────────────
177
+
178
def run_task3(model, src_tok, tgt_tok, device, src_list, ref_list):
    """Task 3: concept vectors + PCA steering.

    Fits PCA on hidden states collected from validation examples, finds the
    principal component most correlated with output length (the "diversity"
    direction), plots the concept space, and generates a steering spectrum
    for the first example. Writes task3_concept_space.png,
    task3_diversity_direction.npy and task3_report.txt to OUTPUT_DIR.

    Args:
        model    : SanskritModel wrapper (must wrap D3PMCrossAttention)
        src_tok  : source (IAST) tokenizer
        tgt_tok  : target (Devanagari) tokenizer
        device   : torch device for generation
        src_list : list of [1, src_len] source id tensors (validation set)
        ref_list : matching reference strings (unused here, kept for symmetry
                   with the other task runners)
    """
    print("\n" + "="*65)
    print(" TASK 3 — Concept Vectors + PCA Steering")
    print("="*65)
    # Steering needs the cached-encoder API of the cross-attention model.
    if not hasattr(model.model, 'encode_source'):
        print(" SKIP: not D3PMCrossAttention.")
        return

    from analysis.concept_vectors import (collect_hidden_states, fit_pca,
        find_diversity_direction, generate_diversity_spectrum, plot_pca_space)

    # Collect hidden states from val set
    n = min(500, len(src_list))
    print(f" Collecting hidden states from {n} examples...")
    hidden, _ = collect_hidden_states(
        model, src_list[:n], t_capture=0, max_samples=n)

    # Compute output lengths for diversity direction
    lengths = []
    for src in src_list[:n]:
        with torch.no_grad():
            # FIX: generate_cached lives on the inner model, not on the
            # SanskritModel wrapper (which only exposes forward/generate).
            out = model.model.generate_cached(src.to(device))
            # ids <= 4 are reserved/special tokens — drop before decoding
            # (same filtering convention as the other analysis scripts).
            ids = [x for x in out[0].tolist() if x > 4]
            lengths.append(len(tgt_tok.decode(ids)))

    # Fit PCA + find diversity direction (cap components for small n)
    pca = fit_pca(hidden, n_components=min(50, n-1))
    direction, best_pc, corr = find_diversity_direction(hidden, lengths, pca)

    # Plot concept space
    plot_pca_space(hidden, lengths, pca, best_pc,
        save_path=os.path.join(OUTPUT_DIR, "task3_concept_space.png"))

    # Generate diversity spectrum for first example
    print("\n Diversity spectrum for first example:")
    src0 = src_list[0]
    inp0 = src_tok.decode([x for x in src0[0].tolist() if x > 4])
    print(f" Input: {inp0}")
    spectrum = generate_diversity_spectrum(
        model, src0.to(device), direction, tgt_tok,
        alphas=[-2.0, -1.0, 0.0, 1.0, 2.0])

    # Save diversity direction + results
    np.save(os.path.join(OUTPUT_DIR, "task3_diversity_direction.npy"), direction)

    report = os.path.join(OUTPUT_DIR, "task3_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 3 — CONCEPT VECTORS + PCA STEERING\n" + "="*50 + "\n\n")
        f.write(f"PCA: {pca.n_components_} components, "
                f"{pca.explained_variance_ratio_.sum()*100:.1f}% variance\n")
        f.write(f"Diversity PC: {best_pc} (|r|={corr:.3f} with output length)\n\n")
        f.write("Diversity spectrum:\n")
        for alpha, text in sorted(spectrum.items()):
            f.write(f" alpha={alpha:+.1f} → {text}\n")
    print(f" Report: {report}")
233
+
234
+
235
+ # ── Task 4 ────────────────────────────────────────────────────────────
236
+
237
def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
              src_list, ref_list):
    """Task 4: diffusion-step ablation.

    phase="generate_configs" writes one training config per T value
    (training itself happens offline via the emitted shell script);
    phase="analyze" evaluates the already-trained T-ablation checkpoints.
    An adversarial robustness test on the current model follows either
    phase — it needs no retraining.

    NOTE(review): the dump's indentation is ambiguous — the adversarial
    section is placed at function level here per its "always runs" comment;
    confirm against the original file.
    """
    print("\n" + "="*65)
    print(f" TASK 4 — Step Ablation (phase={phase})")
    print("="*65)

    from analysis.step_ablation import (generate_ablation_configs,
        run_ablation_analysis, plot_ablation_3d, run_adversarial_test)

    if phase == "generate_configs":
        print(" Generating ablation configs...")
        generate_ablation_configs(output_dir="ablation_configs")
        print("\n NEXT STEPS:")
        print(" 1. bash ablation_configs/train_all.sh")
        print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")

    elif phase == "analyze":
        # Check which models exist
        existing = [T for T in [4, 8, 16, 32, 64]
                    if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
        if not existing:
            print(" No ablation models found at ablation_results/T*/best_model.pt")
            print(" Run: python analysis/run_analysis.py --task 4 --phase generate_configs")
            print(" Then: bash ablation_configs/train_all.sh")
            return

        print(f" Found models for T={existing}")
        # Evaluate at most 200 validation pairs per checkpoint.
        results = run_ablation_analysis(
            ablation_dir="ablation_results", base_cfg=cfg,
            src_list=src_list[:200], ref_list=ref_list[:200],
            tgt_tokenizer=tgt_tok, device=device,
            output_dir=OUTPUT_DIR)
        plot_ablation_3d(results,
            save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))

    # Adversarial robustness always runs on existing model (no retraining)
    print("\n Running adversarial robustness test...")
    # Decode the first 50 sources back to text; ids <= 4 are reserved tokens.
    inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
                 for s in src_list[:50]]
    run_adversarial_test(
        model, src_tok, tgt_tok,
        test_inputs=inp_texts, test_refs=ref_list[:50],
        device=device, output_dir=OUTPUT_DIR)
280
+
281
+
282
+ # ── Task 5 ────────────────────────────────────────────────────────────
283
+
284
def run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list):
    """Task 5: classifier-free guidance sweep.

    Trains (or loads a cached) quality classifier on hidden states, sweeps
    the guidance scale λ over a fixed grid, and reports CER/diversity per
    scale. Both the training data (.npz) and the classifier weights (.pt)
    are cached in OUTPUT_DIR so reruns are fast.
    """
    print("\n" + "="*65)
    print(" TASK 5 — Classifier-Free Guidance")
    print("="*65)
    # Guidance requires the cached-encoder API of the cross-attention model.
    if not hasattr(model.model, 'encode_source'):
        print(" SKIP: not D3PMCrossAttention.")
        return

    from analysis.quality_classifier import (
        QualityClassifier, collect_quality_data,
        train_quality_classifier, sweep_guidance_scales)

    clf_path = os.path.join(OUTPUT_DIR, "task5_quality_classifier.pt")
    d_model = cfg['model']['d_model']

    # Step 1: collect or load training data
    data_path = os.path.join(OUTPUT_DIR, "task5_quality_data.npz")
    if os.path.exists(data_path):
        print(" Loading cached quality data...")
        data = np.load(data_path)
        hidden = data["hidden"]
        quality = data["quality"]
    else:
        print(" Collecting quality data (this takes a few minutes)...")
        n = min(2000, len(src_list))
        hidden, quality = collect_quality_data(
            model, src_list[:n], ref_list[:n], tgt_tok,
            t_capture=0, max_samples=n)
        # Cache so subsequent runs skip the expensive collection pass.
        np.savez(data_path, hidden=hidden, quality=quality)
        print(f" Saved quality data: {data_path}")

    # Step 2: train or load classifier
    if os.path.exists(clf_path):
        print(f" Loading cached classifier: {clf_path}")
        clf = QualityClassifier(d_model)
        clf.load_state_dict(torch.load(clf_path, map_location='cpu'))
        clf.eval()
    else:
        print(" Training quality classifier...")
        clf = train_quality_classifier(
            hidden, quality, d_model=d_model,
            epochs=30, batch_size=64, lr=1e-3,
            save_path=clf_path)
        clf.eval()

    # Step 3: guidance scale sweep
    print("\n Guidance scale sweep (λ ∈ {0.0, 0.5, 1.0, 1.5, 2.0, 3.0})...")
    n_sweep = min(50, len(src_list))
    results = sweep_guidance_scales(
        model, clf, src_list[:n_sweep], ref_list[:n_sweep],
        tgt_tok, scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
        n_samples=n_sweep, device=device, output_dir=OUTPUT_DIR)

    # Find optimal scale: the one with the lowest mean CER.
    best_scale = min(results, key=lambda s: results[s]["mean_cer"])
    print(f"\n Optimal guidance scale: λ={best_scale:.1f} "
          f"CER={results[best_scale]['mean_cer']:.4f}")

    report = os.path.join(OUTPUT_DIR, "task5_report.txt")
    with open(report, "w") as f:
        f.write("TASK 5 — CLASSIFIER-FREE GUIDANCE\n" + "="*50 + "\n\n")
        f.write(f"Classifier params: {sum(p.numel() for p in clf.parameters())}\n")
        f.write(f"Training samples : {len(hidden)}\n\n")
        f.write("Guidance scale sweep:\n")
        f.write(f" {'λ':>6} {'CER':>8} {'diversity':>10}\n")
        f.write(" " + "-"*28 + "\n")
        for s in sorted(results.keys()):
            r = results[s]
            marker = " ← optimal" if s == best_scale else ""
            f.write(f" {s:>6.1f} {r['mean_cer']:>8.4f} {r['diversity']:>10.3f}{marker}\n")
    print(f" Report: {report}")
355
+
356
+
357
+ # ── Main ──────────────────────────────────────────────────────────────
358
+
359
def main():
    """CLI entry point: run one or all analysis tasks against the trained model.

    Flags:
        --task  : which task to run ("1".."5" or "all")
        --input : IAST input text used by Task 2
        --phase : Task 4 sub-phase (generate_configs before training,
                  analyze after the ablation models exist)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task",
        choices=["1","2","3","4","5","all"], default="all")
    parser.add_argument("--input",
        default="dharmo rakṣati rakṣitaḥ",
        help="IAST input text for Task 2")
    parser.add_argument("--phase",
        choices=["generate_configs", "analyze"], default="analyze",
        help="Task 4 phase: generate_configs (before training) or analyze (after)")
    args = parser.parse_args()

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])

    print("Loading model and tokenizers...")
    model, src_tok, tgt_tok, cfg = load_everything(cfg, device)

    # Load val data for tasks that need it (Tasks 3, 4, 5)
    needs_data = args.task in ("3", "4", "5", "all")
    if needs_data:
        print("Loading validation data...")
        src_list, ref_list, inp_list = load_val_data(cfg, src_tok, tgt_tok, n=500)
    else:
        # Tasks 1 and 2 operate on a single input; no validation set needed.
        src_list, ref_list, inp_list = [], [], []

    tasks = (["1","2","3","4","5"] if args.task == "all"
             else [args.task])

    # Dispatch each requested task to its runner.
    for task in tasks:
        if task == "1":
            run_task1(model, src_tok, device)
        elif task == "2":
            run_task2(model, src_tok, tgt_tok, device, args.input)
        elif task == "3":
            run_task3(model, src_tok, tgt_tok, device, src_list, ref_list)
        elif task == "4":
            run_task4(args.phase, model, src_tok, tgt_tok, device, cfg,
                      src_list, ref_list)
        elif task == "5":
            run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list)

    print(f"\n{'='*65}")
    print(f" All outputs saved to: {OUTPUT_DIR}/")
    print("="*65)


if __name__ == "__main__":
    main()
sanskrit_model.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ sanskrit_model.py — Fixed
3
+ ===========================
4
+ Added inference_mode parameter to forward() so reverse_process.py can
5
+ pass inference_mode=True without a TypeError.
6
+
7
+ The wrapper introspects each inner model's signature and only passes
8
+ kwargs that model actually accepts — safe across all four architectures.
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import inspect
14
+
15
+
16
class SanskritModel(nn.Module):
    """Factory + dispatch wrapper around the four D3PM architectures.

    Builds the concrete model from ``cfg['model_type']`` and forwards calls
    to it, passing only the keyword arguments that model actually accepts —
    so one wrapper works across all four architectures without TypeErrors.
    """

    def __init__(self, cfg):
        super().__init__()
        model_type = cfg['model_type']

        if model_type == 'd3pm_cross_attention':
            from model.d3pm_model_cross_attention import D3PMCrossAttention
            self.model = D3PMCrossAttention(cfg)

        elif model_type == 'd3pm_encoder_decoder':
            from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
            self.model = D3PMEncoderDecoder(cfg)

        elif model_type == 'baseline_cross_attention':
            from model.d3pm_model_cross_attention import BaselineCrossAttention
            self.model = BaselineCrossAttention(cfg)

        elif model_type == 'baseline_encoder_decoder':
            from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
            self.model = BaselineEncoderDecoder(cfg)

        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    @staticmethod
    def _has_var_kwargs(params):
        """True if a signature declares a ``**kwargs`` catch-all parameter."""
        return any(p.kind is inspect.Parameter.VAR_KEYWORD
                   for p in params.values())

    def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
        """
        Forward pass. Introspects the inner model's signature so only
        supported kwargs are passed — works with all four architectures.

        FIX: kwargs are now also forwarded when the inner model declares a
        ``**kwargs`` catch-all (previously they were silently dropped, so
        ``inference_mode``/``x0_hint`` never reached such models).
        """
        params = inspect.signature(self.model.forward).parameters
        var_kw = self._has_var_kwargs(params)
        kwargs = {}
        if 'x0_hint' in params or var_kw:
            kwargs['x0_hint'] = x0_hint
        if 'inference_mode' in params or var_kw:
            kwargs['inference_mode'] = inference_mode

        # `t` is positional: pass it only when explicitly declared — a
        # **kwargs catch-all cannot safely absorb a positional argument.
        if 't' in params:
            return self.model(input_ids, target_ids, t, **kwargs)
        return self.model(input_ids, target_ids, **kwargs)

    @torch.no_grad()
    def generate(self, src, **kwargs):
        """Delegate generation, filtering kwargs to what the inner model accepts."""
        params = inspect.signature(self.model.generate).parameters
        if self._has_var_kwargs(params):
            # Inner generate absorbs arbitrary kwargs — forward everything.
            return self.model.generate(src, **kwargs)
        filtered = {k: v for k, v in kwargs.items() if k in params}
        return self.model.generate(src, **filtered)
scheduler.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ scheduler.py — Fixed & Upgraded
3
+ ==================================
4
+ Changes:
5
+ 1. T=64 (was 16). More timesteps = richer denoising curriculum per epoch.
6
+ 2. alpha at t=0 is EXACTLY 1.0 — fixes Bug 2 (final-step re-noise).
7
+ 3. sample_timestep samples [0, T-1] including t=0, so model trains on
8
+ fully-clean inputs (learns the identity at t=0 explicitly).
9
+ """
10
+ import torch, math
11
+
12
class OptimizedCosineScheduler:
    """Cosine noise schedule for discrete diffusion.

    Guarantees ᾱ(0) == 1.0 exactly (no re-noising at the final denoising
    step) and clamps ᾱ(T-1) to at most 0.001 so the first reverse step
    starts from near-total corruption. Training timesteps are sampled
    uniformly over [0, T-1] *including* t=0, so the model explicitly learns
    the identity mapping on clean inputs.
    """

    def __init__(self, cfg, device=None):
        self.num_timesteps = cfg['model']['diffusion_steps']
        self.mask_token_id = cfg['diffusion']['mask_token_id']
        self.device = device if device is not None else torch.device('cpu')
        self.alphas_cumprod = self._build_schedule().to(self.device)

    def _build_schedule(self):
        """Return the ᾱ table for t = 0..T-1 (Nichol–Dhariwal cosine form)."""
        T = self.num_timesteps
        steps = torch.arange(T + 1, dtype=torch.float32)
        f = torch.cos((steps / T + 0.008) / 1.008 * math.pi / 2).pow(2)
        bar = (f / f[0])[1:]                  # keep entries for t = 0..T-1
        bar[0] = 1.0                          # exact identity at t=0
        bar[-1] = bar[-1].clamp(max=0.001)    # near-full noise at t=T-1
        return bar

    def sample_timestep(self, batch_size):
        """Draw `batch_size` timesteps uniformly from {0, ..., T-1}."""
        return torch.randint(0, self.num_timesteps, (batch_size,))

    def get_alpha(self, t):
        """Look up ᾱ(t); accepts integer tensors on any device."""
        idx = t.to(self.alphas_cumprod.device).long()
        return self.alphas_cumprod[idx]
semantic_drift.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/semantic_drift.py
3
+ ===========================
4
+ Task 2: Semantic drift metric — how much does the intermediate generation
5
+ diverge from the final output as we walk through diffusion steps T → 0?
6
+
7
+ Metric: CER between x0_estimate at each step vs the final x0 at t=0.
8
+
9
+ A well-trained model should show:
10
+ - High drift at t=T-1 (near-random initial estimate)
11
+ - Rapid decrease in drift around t=T//2 (model finds the right structure)
12
+ - Near-zero drift at t=10 (output is stable, only fine corrections remain)
13
+
14
+ If drift stays high until t=5 then suddenly collapses → model is doing all
15
+ its work in the last few steps → consider reducing T.
16
+
17
+ Also measures:
18
+ - Token stability: fraction of positions that don't change between steps
19
+ - Lock-in time: first step where each position "commits" to its final token
20
+
21
+ No retraining required. Uses generate_cached() with intermediate snapshots.
22
+ """
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import numpy as np
27
+ from typing import List, Dict, Optional, Tuple
28
+
29
+
30
def compute_cer_between(pred: str, ref: str) -> float:
    """Character error rate of `pred` against `ref`.

    Defined as Levenshtein(pred, ref) / len(ref). An empty reference
    yields 0.0 when the prediction is also empty, otherwise 1.0.
    """
    if not ref:
        return 1.0 if pred else 0.0

    # Two-row Levenshtein DP: `previous` holds row i-1, `current` row i.
    previous = list(range(len(ref) + 1))
    for i, pc in enumerate(pred, start=1):
        current = [i]
        for j, rc in enumerate(ref, start=1):
            if pc == rc:
                current.append(previous[j - 1])
            else:
                current.append(1 + min(previous[j - 1], previous[j], current[j - 1]))
        previous = current

    return previous[-1] / len(ref)
47
+
48
+
49
@torch.no_grad()
def capture_intermediate_outputs(
    model,
    src: torch.Tensor,
    tgt_tokenizer,
    capture_every: int = 5,
    temperature: float = 0.8,
    top_k: int = 40,
) -> Tuple[Dict[int, str], str]:
    """
    Run generation while recording the decoded x0_estimate at every
    `capture_every` diffusion steps.

    Args:
        model : SanskritModel (D3PMCrossAttention)
        src : [1, src_len] IAST token ids (single sample)
        tgt_tokenizer : SanskritTargetTokenizer for decoding intermediate outputs
        capture_every : record every N steps
        temperature : sampling temperature
        top_k : top-k filter

    Returns:
        step_outputs : dict mapping t_val → decoded Devanagari string at that step
        final_output : decoded string at t=0 (final result)
    """
    if src.dim() == 1:
        src = src.unsqueeze(0)  # promote to a batch of one: [1, src_len]

    inner = model.model
    T = inner.scheduler.num_timesteps
    device = src.device

    # Encode source once (KV cache) — reused across all T reverse steps.
    memory, src_pad_mask = inner.encode_source(src)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # Start from the all-MASK target; x0_est is refined every reverse step.
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    step_outputs: Dict[int, str] = {}
    inner.eval()

    # Reverse diffusion: walk t from T-1 down to 0.
    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = (t_val == 0)

        # Previous step's estimate is fed back as the x0 hint.
        logits, _ = inner.forward_cached(
            memory, src_pad_mask, x0_est, t,
            x0_hint=hint, inference_mode=True,
        )

        # Temperature scaling (guarded against division by ~0), then top-k.
        logits = logits / max(temperature, 1e-8)
        if top_k > 0:
            V = logits.shape[-1]
            if top_k < V:
                # Mask everything below the k-th largest logit per position.
                topk_vals, _ = torch.topk(logits, top_k, dim=-1)
                threshold = topk_vals[..., -1].unsqueeze(-1)
                logits = logits.masked_fill(logits < threshold, float('-inf'))

        probs = F.softmax(logits, dim=-1)
        # Greedy at the final step; stochastic sampling at all earlier steps.
        x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est

        # Capture at this step (and always at the last step)
        if (T - 1 - t_val) % capture_every == 0 or is_last:
            # ids <= 4 are reserved/special tokens — drop before decoding
            ids = [x for x in x0_est[0].tolist() if x > 4]
            text = tgt_tokenizer.decode(ids).strip()
            step_outputs[t_val] = text

    final_output = step_outputs.get(0, "")
    return step_outputs, final_output
123
+
124
+
125
+ def _sample(probs):
126
+ B, L, V = probs.shape
127
+ flat = probs.view(B * L, V).clamp(min=1e-9)
128
+ flat = flat / flat.sum(dim=-1, keepdim=True)
129
+ return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
130
+
131
+
132
def compute_drift(
    step_outputs: Dict[int, str],
    final_output: str,
) -> Dict[str, object]:
    """Drift metrics: how far each intermediate output is from the final one.

    Returns dict with:
        t_vals      : captured timesteps in descending order (T-1 → 0)
        cer_to_final: CER of each step's output vs the final output
                      (0.0 = identical to final, 1.0 = completely different)
        lock_in_t   : first timestep from which CER drops and stays <= 0.1
                      — the step at which the output "commits" to its
                      final form (0 if it never stabilizes early)
        final_output: echoed back for convenience
    """
    t_vals = sorted(step_outputs, reverse=True)  # descending: T-1 → 0
    cer_to_final = [compute_cer_between(step_outputs[t], final_output)
                    for t in t_vals]

    # Lock-in: earliest captured step whose entire remaining trajectory
    # stays under the threshold.
    threshold = 0.1
    lock_in_t = 0
    for idx, t in enumerate(t_vals):
        if all(c <= threshold for c in cer_to_final[idx:]):
            lock_in_t = t
            break

    return {
        "t_vals": t_vals,
        "cer_to_final": cer_to_final,
        "lock_in_t": lock_in_t,
        "final_output": final_output,
    }
167
+
168
+
169
def compute_token_stability(
    step_outputs: Dict[int, str],
    final_output: str,
    tgt_tokenizer,
) -> Dict[str, object]:
    """
    Token-level stability: for each position of the final output, the
    diffusion step at which it first matches its final token and stays
    matched for the remainder of the trajectory.

    Args:
        step_outputs : mapping t_val → decoded string captured at that step
        final_output : decoded string at t=0
        tgt_tokenizer: tokenizer providing .encode(str) -> list[int]

    Returns:
        position_lock_times: lock-in timestep (t_val units) per position
        mean_lock_t / std_lock_t: summary statistics over positions

    Fixes vs the original: removed the unused `T` local and a dead
    `else 0` branch, and guarded the empty-input case (previously
    `max()` raised on an empty trajectory).
    """
    t_vals = sorted(step_outputs.keys(), reverse=True)  # T-1 → 0

    final_ids = tgt_tokenizer.encode(final_output)
    L = len(final_ids)

    # Nothing captured or nothing to compare against → empty result.
    if not t_vals or L == 0:
        return {
            "position_lock_times": [],
            "mean_lock_t": 0.0,
            "std_lock_t": 0.0,
        }

    # Encode every captured step; pad all rows (and the final) to one width.
    step_ids = [tgt_tokenizer.encode(step_outputs.get(t, "")) for t in t_vals]
    max_len = max(len(s) for s in step_ids)
    PAD = 1  # 1 = PAD id
    step_ids = [s + [PAD] * (max_len - len(s)) for s in step_ids]
    final_padded = final_ids + [PAD] * (max_len - L)

    step_arr = np.array(step_ids)        # [n_steps, max_len]
    final_arr = np.array(final_padded)   # [max(L, max_len)]

    # For each position: first step index from which the token equals its
    # final value for all remaining steps. The last step (t=0) produced the
    # final output, so a position always locks by the last index.
    position_lock_times = []
    for pos in range(min(L, max_len)):
        col = step_arr[:, pos]
        fin = final_arr[pos]
        locked_idx = len(t_vals) - 1
        for i in range(len(t_vals)):
            if np.all(col[i:] == fin):
                locked_idx = i
                break
        position_lock_times.append(t_vals[locked_idx])

    return {
        "position_lock_times": position_lock_times,
        "mean_lock_t": float(np.mean(position_lock_times)),
        "std_lock_t": float(np.std(position_lock_times)),
    }
223
+
224
+
225
def plot_drift_curve(
    drift_result: Dict,
    src_text: str = "",
    save_path: Optional[str] = None,
):
    """
    Plot CER-to-final vs diffusion step.
    Shows where the model "commits" to the final output.

    Args:
        drift_result: dict from compute_drift() with keys
                      "t_vals", "cer_to_final", "lock_in_t"
        src_text    : optional source string shown (truncated) in the title
        save_path   : if given, saves a PNG there; otherwise shows the plot
    """
    # matplotlib is an optional dependency for the analysis scripts.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib.")
        return

    t_vals = drift_result["t_vals"]
    cers = drift_result["cer_to_final"]
    lock_t = drift_result["lock_in_t"]

    fig, ax = plt.subplots(figsize=(12, 4))
    # x axis is the capture index (0..n-1); tick labels show the t values.
    ax.plot(range(len(t_vals)), cers, linewidth=1.8, color='coral', label='CER to final')
    ax.fill_between(range(len(t_vals)), cers, alpha=0.15, color='coral')

    # Mark lock-in point
    if lock_t in t_vals:
        lock_idx = t_vals.index(lock_t)
        ax.axvline(lock_idx, color='steelblue', linestyle='--', linewidth=1.2,
                   label=f"Lock-in at t={lock_t}")

    # Reference line at the lock-in CER threshold used by compute_drift.
    ax.axhline(0.1, color='gray', linestyle=':', linewidth=1, alpha=0.7)

    # At most ~10 x-ticks, labeled with the actual timestep values.
    n = len(t_vals)
    tick_positions = list(range(0, n, max(1, n // 10)))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([str(t_vals[i]) for i in tick_positions], fontsize=8)
    ax.set_xlabel("Diffusion step t (T-1 → 0)", fontsize=11)
    ax.set_ylabel("CER vs final output", fontsize=11)
    ax.set_ylim(0, 1.05)
    ax.set_xlim(0, n - 1)
    ax.legend(fontsize=10)

    title = f"Semantic drift"
    if src_text:
        title += f" | src: {src_text[:50]}"
    ax.set_title(title, fontsize=11)
    plt.tight_layout()

    if save_path:
        import os
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
step_ablation.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/step_ablation.py
3
+ ==========================
4
+ Task 4: Semantic Robustness — Ablation of Diffusion Steps vs Meaning Preservation
5
+
6
+ Two-phase workflow (retraining IS required for different T values):
7
+
8
+ PHASE 1 — Generate configs + train (run once per T value):
9
+ python analysis/step_ablation.py --phase generate_configs
10
+ # Creates configs: ablation_configs/T4.py, T8.py, T16.py, T32.py, T64.py
11
+ # Then train each: MODEL_TYPE=d3pm_cross_attention python train.py (for each config)
12
+
13
+ PHASE 2 — Analyze trained models (no retraining needed):
14
+ python analysis/step_ablation.py --phase analyze
15
+ # Loads each trained model, generates 200 paraphrases, computes CER
16
+ # Produces 3D plot: X=steps, Y=generation_speed, Z=CER
17
+
18
+ Why retraining is needed:
19
+ A model trained with T=128 learns to denoise from x_t~Uniform[0,128].
20
+ Running it with T=4 means the model only sees t∈{0,1,2,3} — which it
21
+ was never trained on at those scales. Outputs are meaningless.
22
+ You must train a separate model for each T value.
23
+
24
+ Also implements adversarial robustness test (no retraining):
25
+ Takes your existing T=128 model and tests whether corrupted IAST
26
+ inputs (typos, character swaps) cause proportional output degradation.
27
+ """
28
+
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import numpy as np
32
+ import os
33
+ import sys
34
+ import time
35
+ import json
36
+ import copy
37
+ from typing import List, Dict, Optional
38
+
39
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
40
+
41
+
42
+ # ── Phase 1: Config generation ────────────────────────────────────────
43
+
44
+ T_VALUES = [4, 8, 16, 32, 64]
45
+
46
def generate_ablation_configs(base_config_path: str = "config.py",
                              output_dir: str = "ablation_configs"):
    """Write one training config per T in T_VALUES plus a train_all.sh script.

    Each generated config is a copy of the base config with every
    `diffusion_steps`/`num_steps` value of 128 rewritten to the target T
    (both single- and double-quoted key spellings are handled).

    After running this, train each model:
        for T in 4 8 16 32 64; do
            cp ablation_configs/config_T${T}.py config.py
            python train.py
            mv results7/d3pm_cross_attention_neg_False \
                ablation_results/T${T}
        done
    """
    os.makedirs(output_dir, exist_ok=True)

    # Read base config once; each T gets its own patched copy.
    with open(base_config_path, "r") as fh:
        base_src = fh.read()

    for T in T_VALUES:
        patched = base_src
        # Rewrite both keys in both quote styles.
        for key in ("diffusion_steps", "num_steps"):
            for quote in ('"', "'"):
                patched = patched.replace(
                    f'{quote}{key}{quote}: 128',
                    f'{quote}{key}{quote}: {T}')
        out_path = os.path.join(output_dir, f"config_T{T}.py")
        with open(out_path, "w") as fh:
            fh.write(f"# Ablation config: T={T} diffusion steps\n")
            fh.write(patched)
        print(f" Wrote: {out_path}")

    # Emit a helper script that trains every T sequentially.
    shell_script = os.path.join(output_dir, "train_all.sh")
    with open(shell_script, "w") as fh:
        fh.write("#!/bin/bash\n")
        fh.write("# Run this script to train all ablation models\n\n")
        for T in T_VALUES:
            fh.write(f"echo '=== Training T={T} ==='\n")
            fh.write(f"cp {output_dir}/config_T{T}.py config.py\n")
            fh.write("python train.py\n")
            fh.write(f"mkdir -p ablation_results/T{T}\n")
            fh.write(f"cp -r results7/d3pm_cross_attention_neg_False/best_model.pt "
                     f"ablation_results/T{T}/best_model.pt\n")
            fh.write(f"cp -r results7/d3pm_cross_attention_neg_False/train.log "
                     f"ablation_results/T{T}/train.log\n\n")
    os.chmod(shell_script, 0o755)
    print(f"\nTraining script: {shell_script}")
    print(f"Run: bash {shell_script}")
108
+
109
+
110
+ # ── Phase 2: Analysis (after models are trained) ──────────────────────
111
+
112
def compute_cer(pred: str, ref: str) -> float:
    """Character error rate: Levenshtein(pred, ref) / len(ref).

    FIX: an empty reference now yields 0.0 when the prediction is also
    empty (a correct empty prediction was previously scored as total
    error), and 1.0 otherwise — consistent with
    analysis/semantic_drift.compute_cer_between.
    """
    if not ref:
        return 1.0 if pred else 0.0

    def edit_distance(s1: str, s2: str) -> int:
        # Single-row DP Levenshtein distance.
        m, n = len(s1), len(s2)
        dp = list(range(n + 1))
        for i in range(1, m + 1):
            prev, dp[0] = dp[0], i
            for j in range(1, n + 1):
                temp = dp[j]
                dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
                prev = temp
        return dp[n]

    # Guard above ensures ref is non-empty, so plain division is safe.
    return edit_distance(pred, ref) / len(ref)
128
+
129
+
130
def evaluate_model(
    model,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    n_samples: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
) -> Dict:
    """
    Generate n_samples outputs and compute CER + generation speed.

    Returns dict with:
        mean_cer : average CER over samples
        generation_s : total wall-clock seconds for all generations
        speed_per_sample: seconds per sample
        cer_list : per-sample CER values
    """
    device = next(model.parameters()).device
    n = min(n_samples, len(src_list))
    cer_list = []

    # Wall-clock the whole generation loop for the speed metric.
    start = time.perf_counter()
    for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
        if src.dim() == 1:
            src = src.unsqueeze(0)  # promote to a batch of one

        with torch.no_grad():
            # Prefer the KV-cached fast path when the inner model has it.
            if hasattr(model.model, 'generate_cached'):
                out = model.model.generate_cached(
                    src.to(device), temperature=temperature, top_k=top_k
                )
            else:
                out = model.generate(
                    src.to(device), temperature=temperature, top_k=top_k
                )

        # ids <= 4 are reserved/special tokens — drop before decoding
        ids = [x for x in out[0].tolist() if x > 4]
        pred = tgt_tokenizer.decode(ids).strip()
        cer = compute_cer(pred, ref)
        cer_list.append(cer)

    elapsed = time.perf_counter() - start

    return {
        "mean_cer": float(np.mean(cer_list)),
        "std_cer": float(np.std(cer_list)),
        "generation_s": elapsed,
        "speed_per_sample": elapsed / max(n, 1),  # max() guards n == 0
        "cer_list": cer_list,
        "n_samples": n,
    }
182
+
183
+
184
def run_ablation_analysis(
    ablation_dir: str = "ablation_results",
    base_cfg: dict = None,
    src_list: List[torch.Tensor] = None,
    ref_list: List[str] = None,
    tgt_tokenizer = None,
    device: torch.device = None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """
    Load each trained model and evaluate.
    Produces results dict and 3D plot.

    Expects ablation_results/T{N}/best_model.pt for each T in T_VALUES.
    Missing checkpoints are skipped, not treated as errors.
    """
    from inference import load_model

    results = {}
    for T in T_VALUES:
        ckpt = os.path.join(ablation_dir, f"T{T}", "best_model.pt")
        if not os.path.exists(ckpt):
            print(f" SKIP T={T}: no checkpoint at {ckpt}")
            continue

        print(f"\nEvaluating T={T}...")
        # Each ablation model was trained with its own T — patch a copy of
        # the config so load_model rebuilds the matching architecture.
        cfg_T = copy.deepcopy(base_cfg)
        cfg_T['model']['diffusion_steps'] = T
        cfg_T['inference']['num_steps'] = T

        model, cfg_T = load_model(ckpt, cfg_T, device)
        model.eval()

        metrics = evaluate_model(
            model, src_list, ref_list, tgt_tokenizer, n_samples=200
        )
        results[T] = metrics
        print(f" T={T} CER={metrics['mean_cer']:.4f} "
              f"speed={metrics['speed_per_sample']:.3f}s/sample")

        # Drop the model before loading the next checkpoint to free memory.
        del model

    # Save results (per-sample CER lists are dropped to keep the JSON small)
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, "ablation_results.json")
    with open(results_path, "w") as f:
        json.dump({str(k): {kk: vv for kk, vv in v.items() if kk != 'cer_list'}
                   for k, v in results.items()}, f, indent=2)
    print(f"\nResults saved: {results_path}")

    return results
234
+
235
+
236
def plot_ablation_3d(
    results: Dict,
    save_path: Optional[str] = None,
):
    """
    Visualize the ablation results.

    Left panel: 3D scatter with X=diffusion steps, Y=generation speed
    (s/sample), Z=CER. Right panel: 2D CER-vs-T curve; when at least 3
    points are available, the "knee" (largest single-step CER drop) is
    marked with a vertical line.

    Saves the figure to save_path when given, otherwise shows it.
    Returns None; silently returns early if matplotlib is unavailable.
    """
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
    except ImportError:
        print("pip install matplotlib.")
        return

    steps = sorted(results.keys())
    cer_vals = [results[T]["mean_cer"] for T in steps]
    speed_vals = [results[T]["speed_per_sample"] for T in steps]

    fig = plt.figure(figsize=(14, 5))

    # 3D scatter, color-coded by CER (reversed RdYlGn: red = worse).
    ax3d = fig.add_subplot(121, projection='3d')
    ax3d.scatter(steps, speed_vals, cer_vals, c=cer_vals, cmap='RdYlGn_r', s=80)
    for T, s, c in zip(steps, speed_vals, cer_vals):
        ax3d.text(T, s, c, f"T={T}", fontsize=8)
    ax3d.set_xlabel("Diffusion steps T", fontsize=9)
    ax3d.set_ylabel("Speed (s/sample)", fontsize=9)
    ax3d.set_zlabel("CER (↓ better)", fontsize=9)
    ax3d.set_title("T vs speed vs CER", fontsize=10)

    # 2D curve with each point annotated by its CER value.
    ax2d = fig.add_subplot(122)
    ax2d.plot(steps, cer_vals, 'o-', linewidth=1.8, color='coral', markersize=7)
    for T, c in zip(steps, cer_vals):
        ax2d.annotate(f"{c:.3f}", (T, c), textcoords="offset points",
                      xytext=(0, 8), fontsize=8, ha='center')

    # Elbow detection: the knee is the T reached by the largest CER drop.
    if len(steps) >= 3:
        drops = [cer_vals[i] - cer_vals[i + 1] for i in range(len(cer_vals) - 1)]
        knee_T = steps[int(np.argmax(drops)) + 1]
        ax2d.axvline(knee_T, color='steelblue', linestyle='--', linewidth=1.2,
                     label=f"Knee at T={knee_T}")
        ax2d.legend(fontsize=9)

    ax2d.set_xlabel("Diffusion steps T", fontsize=10)
    ax2d.set_ylabel("CER (lower = better)", fontsize=10)
    ax2d.set_title("CER vs diffusion steps", fontsize=10)
    ax2d.set_ylim(0, max(cer_vals) * 1.1)

    plt.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
296
+
297
+
298
+ # ── Adversarial robustness test (no retraining needed) ───────────────
299
+
300
def corrupt_iast(text: str, corruption_rate: float = 0.05) -> str:
    """
    Return *text* with random character-level corruption applied.

    Each corruption is one operation chosen uniformly at random:
      - swap: exchange two adjacent characters
      - delete: remove one character (never empties the string)
      - insert: add a random lowercase IAST-ish character

    The number of corruptions is max(1, int(len(text) * corruption_rate)),
    so even rate=0.0 applies one corruption to a non-empty string
    (matching the original behavior). Typical rates for robustness
    testing are 5% to 20%.

    Args:
        text: input string (IAST transliteration).
        corruption_rate: fraction of characters to corrupt.

    Returns:
        The corrupted string; empty input is returned unchanged.
    """
    import random

    # Guard: random.randint(0, -1) would raise ValueError on empty input.
    if not text:
        return text

    chars = list(text)
    n_corrupt = max(1, int(len(chars) * corruption_rate))

    for _ in range(n_corrupt):
        op = random.choice(['swap', 'delete', 'insert'])
        pos = random.randint(0, len(chars) - 1)

        if op == 'swap' and pos < len(chars) - 1:
            chars[pos], chars[pos + 1] = chars[pos + 1], chars[pos]
        elif op == 'delete' and len(chars) > 1:
            chars.pop(pos)
        elif op == 'insert':
            chars.insert(pos, random.choice('abcdeimnostu'))

    return "".join(chars)
325
+
326
+
327
@torch.no_grad()
def run_adversarial_test(
    model,
    src_tokenizer,
    tgt_tokenizer,
    test_inputs: List[str],
    test_refs: List[str],
    corruption_rates: Optional[List[float]] = None,
    device: torch.device = None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """
    Measure how CER degrades as the IAST input is corrupted.

    Each test input is corrupted with corrupt_iast() at several rates,
    transliterated by the already-trained model, and scored against its
    reference — no retraining required. Results are written to JSON and
    (when matplotlib is available) plotted.

    Args:
        model: wrapper holding the trained model; must expose either
            model.model.generate_cached() or model.generate().
        src_tokenizer: tokenizer for the (corrupted) IAST input.
        tgt_tokenizer: tokenizer used to decode model output.
        test_inputs: IAST input strings.
        test_refs: reference strings, parallel to test_inputs.
        corruption_rates: fractions of characters to corrupt; defaults to
            [0.0, 0.05, 0.10, 0.15, 0.20]. (A None sentinel is used
            instead of a list literal default to avoid the shared
            mutable-default-argument pitfall.)
        device: inference device; inferred from the model if omitted.
        output_dir: where the JSON and PNG artifacts are written.

    Returns:
        dict mapping corruption rate -> mean CER.
    """
    if corruption_rates is None:
        corruption_rates = [0.0, 0.05, 0.10, 0.15, 0.20]
    device = device or next(model.parameters()).device
    results = {}

    print("\nAdversarial robustness test...")
    for rate in corruption_rates:
        cer_list = []
        for text, ref in zip(test_inputs, test_refs):
            corrupted = corrupt_iast(text, rate)
            ids = src_tokenizer.encode(corrupted)
            src = torch.tensor([ids], dtype=torch.long, device=device)

            # Prefer the cached sampler when the wrapped model provides one.
            if hasattr(model.model, 'generate_cached'):
                out = model.model.generate_cached(src)
            else:
                out = model.generate(src)

            # Drop special token ids (0-4) before decoding.
            pred_ids = [x for x in out[0].tolist() if x > 4]
            pred = tgt_tokenizer.decode(pred_ids).strip()
            cer_list.append(compute_cer(pred, ref))

        mean_cer = float(np.mean(cer_list))
        results[rate] = mean_cer
        print(f" corruption={rate*100:.0f}% → CER={mean_cer:.4f}")

    # Save + plot
    os.makedirs(output_dir, exist_ok=True)
    try:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(8, 4))
        rates = [r * 100 for r in corruption_rates]
        cers = [results[r] for r in corruption_rates]
        ax.plot(rates, cers, 'o-', linewidth=1.8, color='steelblue', markersize=7)
        ax.set_xlabel("IAST corruption rate (%)", fontsize=11)
        ax.set_ylabel("CER", fontsize=11)
        ax.set_title("Model robustness to IAST input corruption", fontsize=11)
        ax.set_ylim(0, max(cers) * 1.2)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "adversarial_robustness.png"),
                    dpi=150, bbox_inches='tight')
        plt.close()
        print(f" Saved: {output_dir}/adversarial_robustness.png")
    except ImportError:
        pass

    with open(os.path.join(output_dir, "adversarial_results.json"), "w") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2)

    return results
tokenizer.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tokenizer.py — Dual Tokenizer Fix
3
+ ====================================
4
+ Two separate BPE tokenizers:
5
+
6
+ SanskritSourceTokenizer — trained on quote_text (Roman/IAST script)
7
+ SanskritTargetTokenizer — trained on quote_devanagari (Devanagari script)
8
+
9
+ WHY SEPARATE?
10
+ Roman Sanskrit and Devanagari are fundamentally different character sets.
11
+ Roman uses a-z + diacritics (~60 unique chars), Devanagari uses ā-ह + matras
12
+ (~100+ unique chars). A shared BPE tokenizer wastes half its vocab on
13
+ character combos that never cross scripts, and forces the embedding table
14
+ to encode both scripts in one space — confusing the model's cross-attention.
15
+
16
+ With separate tokenizers:
17
+ - src vocab captures Roman subwords cleanly (ā, ś, ṭ, ṃ etc.)
18
+ - tgt vocab captures Devanagari akshara clusters cleanly (क्ष, त्र, etc.)
19
+ - The model learns a true cross-script mapping in its cross-attention
20
+
21
+ SPECIAL TOKENS (same IDs in both):
22
+ [MASK] = 0 ← required by absorbing diffusion
23
+ [PAD] = 1
24
+ [UNK] = 2
25
+ [CLS] = 3
26
+ [SEP] = 4
27
+ """
28
+
29
+ from tokenizers import Tokenizer
30
+ from tokenizers.models import BPE
31
+ from tokenizers.trainers import BpeTrainer
32
+ from tokenizers.pre_tokenizers import Whitespace
33
+ from datasets import load_dataset
34
+ from pathlib import Path
35
+
36
+
37
+ SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"]
38
+
39
+
40
def _build_bpe(texts, vocab_size):
    """Train and return a BPE tokenizer over an iterable of strings.

    SPECIAL_TOKENS is passed to the trainer in order, so [MASK] (listed
    first) receives id 0 — required by the absorbing-diffusion model.
    """
    bpe = Tokenizer(BPE(unk_token="[UNK]"))
    bpe.pre_tokenizer = Whitespace()
    bpe.train_from_iterator(
        texts,
        BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=SPECIAL_TOKENS,
            min_frequency=2,
        ),
    )
    return bpe
51
+
52
+
53
+ def _validate(tok, name):
54
+ mask_id = tok.token_to_id("[MASK]")
55
+ pad_id = tok.token_to_id("[PAD]")
56
+ assert mask_id == 0, f"{name}: [MASK] must be id=0, got {mask_id}"
57
+ assert pad_id == 1, f"{name}: [PAD] must be id=1, got {pad_id}"
58
+ print(f"✅ {name}: [MASK]=0, [PAD]=1 confirmed. Vocab size={tok.get_vocab_size()}")
59
+
60
+
61
+ # ── Source tokenizer (Roman/IAST Sanskrit) ────────────────────────────
62
+
63
class SanskritSourceTokenizer:
    """
    BPE tokenizer for quote_text — Roman/IAST transliteration of Sanskrit.
    Examples: "dharmo rakṣati rakṣitaḥ", "yatra nāryastu pūjyante"

    Loads a previously trained tokenizer from MODEL_PATH when present,
    otherwise trains one from the dataset and saves it there.
    """
    MODEL_PATH = "sanskrit_src_tokenizer.json"

    def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
        self.vocab_size = vocab_size
        self.max_len = max_len
        # [MASK]=0 by construction (see SPECIAL_TOKENS ordering).
        self.mask_token_id = 0

        if not Path(self.MODEL_PATH).exists():
            print("🎓 Training source tokenizer on quote_text …")
            self._train(vocab_size, n_train_samples)
        else:
            print(f"📖 Loading source tokenizer from {self.MODEL_PATH} …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)

        _validate(self.tokenizer, "SrcTokenizer")

    def _train(self, vocab_size, n_samples):
        # Train on the non-empty Roman texts of the first n_samples rows.
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        subset = dataset.select(range(min(n_samples, len(dataset))))
        texts = [row["quote_text"] for row in subset if row["quote_text"].strip()]
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Source tokenizer trained on {len(texts)} Roman texts.")

    def encode(self, text):
        # Truncate to max_len, then right-pad with [PAD] to exactly max_len.
        encoded = self.tokenizer.encode(text).ids[:self.max_len]
        pad_id = self.tokenizer.token_to_id("[PAD]")
        if len(encoded) < self.max_len:
            encoded = encoded + [pad_id] * (self.max_len - len(encoded))
        return encoded[:self.max_len]

    def decode(self, ids):
        # ids 0-4 are special tokens and are dropped before decoding.
        return self.tokenizer.decode([tok_id for tok_id in ids if tok_id > 4])

    def __len__(self):
        # NOTE(review): reports the *requested* vocab size, not the trained
        # tokenizer's actual size — confirm these match before sizing embeddings.
        return self.vocab_size
105
+
106
+
107
+ # ── Target tokenizer (Devanagari Sanskrit) ───────────────────────────
108
+
109
class SanskritTargetTokenizer:
    """
    BPE tokenizer for quote_devanagari — Devanagari script.
    Examples: "धर्मो रक्षति रक्षितः", "यत्र नार्यस्तु पूज्यन्ते"

    Loads a previously trained tokenizer from MODEL_PATH when present,
    otherwise trains one from the dataset and saves it there. Also
    exposes the minimal HF-tokenizer surface required by BERTScore.
    """
    MODEL_PATH = "sanskrit_tgt_tokenizer.json"

    def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
        self.vocab_size = vocab_size
        self.max_len = max_len
        # [MASK]=0 by construction (see SPECIAL_TOKENS ordering).
        self.mask_token_id = 0

        if not Path(self.MODEL_PATH).exists():
            print("🎓 Training target tokenizer on quote_devanagari …")
            self._train(vocab_size, n_train_samples)
        else:
            print(f"📖 Loading target tokenizer from {self.MODEL_PATH} …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)

        _validate(self.tokenizer, "TgtTokenizer")

    def _train(self, vocab_size, n_samples):
        # Train on the non-empty Devanagari texts of the first n_samples rows.
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        subset = dataset.select(range(min(n_samples, len(dataset))))
        texts = [row["quote_devanagari"] for row in subset
                 if row["quote_devanagari"].strip()]
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Target tokenizer trained on {len(texts)} Devanagari texts.")

    def encode(self, text):
        # Truncate to max_len, then right-pad with [PAD] to exactly max_len.
        encoded = self.tokenizer.encode(text).ids[:self.max_len]
        pad_id = self.tokenizer.token_to_id("[PAD]")
        if len(encoded) < self.max_len:
            encoded = encoded + [pad_id] * (self.max_len - len(encoded))
        return encoded[:self.max_len]

    def decode(self, ids):
        # ids 0-4 are special tokens and are dropped before decoding.
        return self.tokenizer.decode([tok_id for tok_id in ids if tok_id > 4])

    # ── Minimal interface required by BERTScore ──────────────────────
    def build_inputs_with_special_tokens(self, token_ids):
        return list(token_ids)

    def get_vocab(self):
        # Synthetic id->id mapping; BERTScore only needs a dict-like vocab.
        return {str(i): i for i in range(self.vocab_size)}

    def convert_ids_to_tokens(self, ids):
        return [str(i) for i in ids]

    def __len__(self):
        # NOTE(review): reports the *requested* vocab size, not the trained
        # tokenizer's actual size — confirm these match before sizing embeddings.
        return self.vocab_size
161
+
162
+
163
+ # ── Legacy shared tokenizer (kept for backward compat) ───────────────
164
+
165
class SanskritTokenizer:
    """
    LEGACY: single shared BPE tokenizer trained on BOTH scripts
    (quote_text and quote_devanagari). Still works but suboptimal —
    prefer SanskritSourceTokenizer + SanskritTargetTokenizer for the
    quote_text → quote_devanagari task.
    """
    MODEL_PATH = "sanskrit_tokenizer_m4pro.json"

    def __init__(self, vocab_size=16000, max_len=80):
        self.vocab_size = vocab_size
        self.max_len = max_len
        # [MASK]=0 by construction (see SPECIAL_TOKENS ordering).
        self.mask_token_id = 0

        if not Path(self.MODEL_PATH).exists():
            print("🎓 Training shared tokenizer on both scripts …")
            self._train(vocab_size)
        else:
            print("📖 Loading shared tokenizer …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)

        _validate(self.tokenizer, "SharedTokenizer")

    def _train(self, vocab_size):
        # Interleave non-empty Roman and Devanagari texts from the same rows.
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        subset = dataset.select(range(min(50000, len(dataset))))
        texts = []
        for row in subset:
            for field in ("quote_text", "quote_devanagari"):
                if row[field].strip():
                    texts.append(row[field])
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Shared tokenizer trained ({len(texts)} texts).")

    def encode(self, text):
        # Truncate to max_len, then right-pad with [PAD] to exactly max_len.
        encoded = self.tokenizer.encode(text).ids[:self.max_len]
        pad_id = self.tokenizer.token_to_id("[PAD]")
        if len(encoded) < self.max_len:
            encoded = encoded + [pad_id] * (self.max_len - len(encoded))
        return encoded[:self.max_len]

    def decode(self, ids):
        # Guard against a common misuse: passing a batch instead of one sequence.
        if ids and isinstance(ids[0], list):
            raise TypeError("decode() got 2D list — pass a 1D list.")
        return self.tokenizer.decode([tok_id for tok_id in ids if tok_id > 4])

    # ── Minimal interface required by BERTScore ──────────────────────
    def build_inputs_with_special_tokens(self, token_ids):
        return list(token_ids)

    def get_vocab(self):
        return {str(i): i for i in range(self.vocab_size)}

    def convert_ids_to_tokens(self, ids):
        return [str(i) for i in ids]

    def __len__(self):
        return self.vocab_size
train_all.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
set -euo pipefail
# Train all ablation models, one per diffusion-step count T.
# Env overrides: MODEL_TYPE, INCLUDE_NEG, TRAIN_DEVICE.
# Each run writes its checkpoint to ablation_results/T${T}/.

MODEL_TYPE=${MODEL_TYPE:-d3pm_cross_attention}
INCLUDE_NEG=${INCLUDE_NEG:-False}
TRAIN_DEVICE=${TRAIN_DEVICE:-mps}

# Loop instead of five copy-pasted stanzas: same commands, one place to edit.
for T in 4 8 16 32 64; do
    echo "=== Training T=${T} ==="
    mkdir -p "ablation_results/T${T}"
    MODEL_TYPE="$MODEL_TYPE" INCLUDE_NEG="$INCLUDE_NEG" TRAIN_DEVICE="$TRAIN_DEVICE" \
        DIFFUSION_STEPS="$T" INFERENCE_NUM_STEPS="$T" \
        TRAIN_OUTPUT_DIR="ablation_results/T${T}" python train.py
done
28
+