| """ |
| analysis/semantic_drift.py |
| =========================== |
| Task 2: Semantic drift metric β how much does the intermediate generation |
| diverge from the final output as we walk through diffusion steps T β 0? |
| |
| Metric: CER between x0_estimate at each step vs the final x0 at t=0. |
| |
| A well-trained model should show: |
| - High drift at t=T-1 (near-random initial estimate) |
| - Rapid decrease in drift around t=T//2 (model finds the right structure) |
| - Near-zero drift at t=10 (output is stable, only fine corrections remain) |
| |
| If drift stays high until t=5 then suddenly collapses β model is doing all |
| its work in the last few steps β consider reducing T. |
| |
| Also measures: |
| - Token stability: fraction of positions that don't change between steps |
| - Lock-in time: first step where each position "commits" to its final token |
| |
| No retraining required. Uses generate_cached() with intermediate snapshots. |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import List, Dict, Optional, Tuple |
|
|
|
|
def compute_cer_between(pred: str, ref: str) -> float:
    """Character error rate of `pred` vs `ref`: edit distance / len(ref)."""
    if not ref:
        # Empty reference: any non-empty prediction counts as total error.
        return 1.0 if pred else 0.0

    def _levenshtein(a: str, b: str) -> int:
        # Single-row DP.  row[j] = distance between a[:i] and b[:j];
        # `diag` carries the value that was row[j-1] on the previous row.
        row = list(range(len(b) + 1))
        for i, ch_a in enumerate(a, start=1):
            diag, row[0] = row[0], i
            for j, ch_b in enumerate(b, start=1):
                saved = row[j]
                if ch_a == ch_b:
                    row[j] = diag
                else:
                    row[j] = 1 + min(diag, row[j], row[j - 1])
                diag = saved
        return row[-1]

    return _levenshtein(pred, ref) / len(ref)
|
|
|
|
@torch.no_grad()
def capture_intermediate_outputs(
    model,
    src: torch.Tensor,
    tgt_tokenizer,
    capture_every: int = 5,
    temperature: float = 0.8,
    top_k: int = 40,
) -> Tuple[Dict[int, str], str]:
    """
    Run generation while recording the decoded x0_estimate at every
    `capture_every` diffusion steps.

    Args:
        model         : SanskritModel (D3PMCrossAttention)
        src           : [1, src_len] IAST token ids (single sample)
        tgt_tokenizer : SanskritTargetTokenizer for decoding intermediate outputs
        capture_every : record every N steps
        temperature   : sampling temperature (applied before top-k filtering)
        top_k         : top-k logits filter; values <= 0 disable filtering

    Returns:
        step_outputs : dict mapping t_val -> decoded Devanagari string at that step
        final_output : decoded string at t=0 (final result)
    """
    # Accept an unbatched sample: [src_len] -> [1, src_len].
    if src.dim() == 1:
        src = src.unsqueeze(0)

    inner = model.model
    T = inner.scheduler.num_timesteps
    device = src.device

    # Encode the source once; the cached memory is reused at every step.
    memory, src_pad_mask = inner.encode_source(src)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # Start from an all-[MASK] target.  `hint` feeds the previous step's
    # estimate back into the next forward pass (self-conditioning).
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    step_outputs: Dict[int, str] = {}
    inner.eval()

    # Walk the reverse diffusion chain t = T-1 ... 0.
    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = (t_val == 0)

        logits, _ = inner.forward_cached(
            memory, src_pad_mask, x0_est, t,
            x0_hint=hint, inference_mode=True,
        )

        # Temperature scaling (guarded against division by zero), then
        # top-k truncation: everything below the k-th logit goes to -inf.
        logits = logits / max(temperature, 1e-8)
        if top_k > 0:
            V = logits.shape[-1]
            if top_k < V:
                topk_vals, _ = torch.topk(logits, top_k, dim=-1)
                threshold = topk_vals[..., -1].unsqueeze(-1)
                logits = logits.masked_fill(logits < threshold, float('-inf'))

        # Greedy argmax on the final step; stochastic sampling otherwise.
        probs = F.softmax(logits, dim=-1)
        x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est

        # Snapshot every `capture_every` steps (counting down from t=T-1),
        # and always at the final step.
        if (T - 1 - t_val) % capture_every == 0 or is_last:
            # Drops ids <= 4 before decoding -- presumably the special
            # tokens (pad/bos/eos/unk/mask); TODO confirm against the
            # tokenizer's special-token ids.
            ids = [x for x in x0_est[0].tolist() if x > 4]
            text = tgt_tokenizer.decode(ids).strip()
            step_outputs[t_val] = text

    final_output = step_outputs.get(0, "")
    return step_outputs, final_output
|
|
|
|
| def _sample(probs): |
| B, L, V = probs.shape |
| flat = probs.view(B * L, V).clamp(min=1e-9) |
| flat = flat / flat.sum(dim=-1, keepdim=True) |
| return torch.multinomial(flat, 1).squeeze(-1).view(B, L) |
|
|
|
|
def compute_drift(
    step_outputs: Dict[int, str],
    final_output: str,
) -> Dict[str, object]:
    """
    Compute drift metrics comparing each intermediate output to the final.

    Args:
        step_outputs: dict mapping diffusion timestep -> decoded string
        final_output: decoded string at t=0

    Returns dict with:
        t_vals      : list of captured timesteps (T-1 -> 0)
        cer_to_final: CER between each step's output and the final output
                      0.0 = identical to final, 1.0 = completely different
        lock_in_t   : first t_val where CER drops and stays below 0.1
                      (step at which output "commits" to final form);
                      0 when the CER never settles below the threshold
        final_output: echoed back for convenience
    """
    t_vals = sorted(step_outputs.keys(), reverse=True)
    cer_to_final = [
        compute_cer_between(step_outputs[t_val], final_output) for t_val in t_vals
    ]

    # Lock-in step: the earliest captured t such that its CER and every
    # later step's CER stay <= threshold.  A single reverse scan finds the
    # start of that maximal suffix; the original re-checked the whole
    # suffix at every index (accidental O(n^2)).
    threshold = 0.1
    lock_in_t = 0
    for t_val, cer in zip(reversed(t_vals), reversed(cer_to_final)):
        if cer > threshold:
            break
        lock_in_t = t_val

    return {
        "t_vals": t_vals,
        "cer_to_final": cer_to_final,
        "lock_in_t": lock_in_t,
        "final_output": final_output,
    }
|
|
|
|
def compute_token_stability(
    step_outputs: Dict[int, str],
    final_output: str,
    tgt_tokenizer,
) -> Dict[str, object]:
    """
    Token-level stability: for each position, at which diffusion step
    does it first match its final token and stay matched?

    Args:
        step_outputs : dict mapping diffusion timestep -> decoded string
        final_output : decoded string at t=0
        tgt_tokenizer: tokenizer exposing `encode(text) -> list[int]`

    Returns dict with:
        position_lock_times: list of t_val at which each position locks in
        mean_lock_t        : average lock-in timestep across positions
        std_lock_t         : standard deviation of lock-in timesteps
    """
    t_vals = sorted(step_outputs.keys(), reverse=True)

    final_ids = tgt_tokenizer.encode(final_output)
    n_positions = len(final_ids)

    if not t_vals or n_positions == 0:
        # Nothing to measure; avoids np.mean([]) returning NaN.
        return {
            "position_lock_times": [],
            "mean_lock_t": 0.0,
            "std_lock_t": 0.0,
        }

    # One encoded row per captured step, ordered T-1 -> 0.
    step_ids = [tgt_tokenizer.encode(step_outputs[t_val]) for t_val in t_vals]

    # Pad every row -- and the final sequence -- to a common length so the
    # arrays align column-wise.  (The original excluded final_ids from
    # max_len, so a longer final sequence was never padded.)
    # Id 1 is used as padding -- presumably the PAD token; TODO confirm
    # against the tokenizer's special-token ids.
    pad_id = 1
    max_len = max(len(final_ids), max(len(s) for s in step_ids))
    step_ids = [s + [pad_id] * (max_len - len(s)) for s in step_ids]
    final_ids_padded = final_ids + [pad_id] * (max_len - len(final_ids))

    step_arr = np.array(step_ids)
    final_arr = np.array(final_ids_padded)

    # For each real (unpadded) position, find the earliest captured step
    # from which the token equals its final value at every later step.
    position_lock_steps = []
    for pos in range(min(n_positions, max_len)):
        col = step_arr[:, pos]
        fin = final_arr[pos]
        locked_idx = len(t_vals) - 1  # worst case: only matches at the end
        for i in range(len(t_vals)):
            if (col[i:] == fin).all():
                locked_idx = i
                break
        position_lock_steps.append(t_vals[locked_idx])

    return {
        "position_lock_times": position_lock_steps,
        "mean_lock_t": float(np.mean(position_lock_steps)),
        "std_lock_t": float(np.std(position_lock_steps)),
    }
|
|
|
|
def plot_drift_curve(
    drift_result: Dict,
    src_text: str = "",
    save_path: Optional[str] = None,
):
    """
    Plot CER-to-final vs diffusion step.

    Shows where the model "commits" to the final output: the curve should
    fall toward 0, crossing the 0.1 threshold at the dashed lock-in line.

    Args:
        drift_result: output of compute_drift() -- reads "t_vals",
                      "cer_to_final", and "lock_in_t"
        src_text    : optional source string, truncated into the title
        save_path   : if given, save the figure to this path; otherwise
                      show it interactively
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # Plotting is optional; fail soft with an actionable message.
        print("matplotlib is not installed; run `pip install matplotlib` to enable plotting.")
        return

    t_vals = drift_result["t_vals"]
    cers = drift_result["cer_to_final"]
    lock_t = drift_result["lock_in_t"]

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(range(len(t_vals)), cers, linewidth=1.8, color='coral', label='CER to final')
    ax.fill_between(range(len(t_vals)), cers, alpha=0.15, color='coral')

    # Mark the lock-in step, if it is one of the captured steps.
    if lock_t in t_vals:
        lock_idx = t_vals.index(lock_t)
        ax.axvline(lock_idx, color='steelblue', linestyle='--', linewidth=1.2,
                   label=f"Lock-in at t={lock_t}")

    # Reference line at the 0.1 CER threshold used by compute_drift().
    ax.axhline(0.1, color='gray', linestyle=':', linewidth=1, alpha=0.7)

    # Thin the x ticks to roughly 10 labels regardless of capture count.
    n = len(t_vals)
    tick_positions = list(range(0, n, max(1, n // 10)))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([str(t_vals[i]) for i in tick_positions], fontsize=8)
    ax.set_xlabel("Diffusion step t (T-1 \u2192 0)", fontsize=11)
    ax.set_ylabel("CER vs final output", fontsize=11)
    ax.set_ylim(0, 1.05)
    ax.set_xlim(0, n - 1)
    ax.legend(fontsize=10)

    title = "Semantic drift"
    if src_text:
        title += f" | src: {src_text[:50]}"
    ax.set_title(title, fontsize=11)
    plt.tight_layout()

    if save_path:
        import os
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
|