devflow / analysis /quality_classifier.py
bhsinghgrid's picture
Upload folder using huggingface_hub
9d76bba verified
# """
# analysis/quality_classifier.py
# ================================
# Task 5: Classifier Guidance for Paraphrase Quality Control
#
# Three steps — only Step 2 requires training a SMALL model (not the main D3PM):
#
# STEP 1 β€” Collect training data (no training):
# Run existing model on val set, record (hidden_state, CER) pairs.
# Hidden states come from model.model._last_hidden after forward_cached().
# CER score = quality label (lower CER = higher quality).
#
# STEP 2 β€” Train quality classifier:
# Small MLP: d_model → 128 → 64 → 1
# Input: pooled decoder hidden state [B, d_model]
# Output: predicted quality score in [0, 1] (1 = high quality)
# Loss: MSE against normalized CER labels
# Training time: ~5-10 minutes on CPU for 10k examples
#
# STEP 3 β€” Guided inference (no retraining):
# At each diffusion step, use classifier gradient to shift logits:
# guided_logits = logits + λ * ∂(quality_score)/∂(logits)
# Higher λ → model biased toward high-quality outputs
# λ=0 → standard generation (no guidance)
#
# Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains.
# """
#
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# import os
# import json
# from typing import List, Dict, Optional, Tuple
#
#
# # ── Quality classifier architecture ──────────────────────────────────
#
# class QualityClassifier(nn.Module):
# """
# Lightweight MLP that predicts transliteration quality from decoder
# hidden states.
#
# Architecture:
# d_model → 128 → 64 → 1 → Sigmoid
#
# Input: mean-pooled decoder hidden state [B, d_model]
# Output: quality score [B, 1] ∈ [0, 1] (1 = high quality)
#
# ~10k parameters. Trains in minutes on CPU.
# """
# def __init__(self, d_model: int):
# super().__init__()
# self.net = nn.Sequential(
# nn.Linear(d_model, 128),
# nn.ReLU(),
# nn.Dropout(0.1),
# nn.Linear(128, 64),
# nn.ReLU(),
# nn.Linear(64, 1),
# nn.Sigmoid(),
# )
# self.d_model = d_model
#
# def forward(self, hidden: torch.Tensor) -> torch.Tensor:
# """
# Args:
# hidden : [B, tgt_len, d_model] OR [B, d_model] (already pooled)
#
# Returns:
# score : [B, 1] quality score in [0, 1]
# """
# if hidden.dim() == 3:
# # Pool over sequence length
# hidden = hidden.mean(dim=1) # [B, d_model]
# return self.net(hidden) # [B, 1]
#
#
# # ── Training data collection ──────────────────────────────────────────
#
# @torch.no_grad()
# def collect_quality_data(
# model,
# src_list: List[torch.Tensor],
# ref_list: List[str],
# tgt_tokenizer,
# t_capture: int = 0,
# temperature: float = 0.8,
# top_k: int = 40,
# max_samples: int = 5000,
# ) -> Tuple[np.ndarray, np.ndarray]:
# """
# Collect (hidden_state, quality_score) pairs for classifier training.
#
# For each sample:
# 1. Run generate_cached() on src
# 2. Capture decoder hidden state at t=t_capture
# 3. Compute CER between output and reference
# 4. Quality = 1 - CER (normalize to [0,1])
#
# Args:
# model : SanskritModel
# src_list : list of [1, src_len] tensors
# ref_list : list of reference Devanagari strings
# tgt_tokenizer : SanskritTargetTokenizer
# t_capture : which step to capture hidden states (0 = final)
# max_samples : cap number of training examples
#
# Returns:
# hidden_matrix : np.ndarray [N, d_model]
# quality_scores: np.ndarray [N] values in [0, 1]
# """
# inner = model.model
# T = inner.scheduler.num_timesteps
# device = next(inner.parameters()).device
#
# hidden_list = []
# quality_list = []
# n = min(len(src_list), max_samples)
#
# def cer(pred, ref):
# if not ref:
# return 1.0
# def ed(s1, s2):
# m, n = len(s1), len(s2)
# dp = list(range(n + 1))
# for i in range(1, m + 1):
# prev, dp[0] = dp[0], i
# for j in range(1, n + 1):
# temp = dp[j]
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
# prev = temp
# return dp[n]
# return ed(pred, ref) / max(len(ref), 1)
#
# print(f"Collecting quality data from {n} examples...")
# for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
# if i % 200 == 0:
# print(f" {i}/{n}")
#
# if src.dim() == 1:
# src = src.unsqueeze(0)
# src = src.to(device)
#
# B = src.shape[0]
# tgt_len = inner.max_seq_len
# mask_id = inner.mask_token_id
#
# memory, src_pad_mask = inner.encode_source(src)
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
# hint = None
# h_cap = None
#
# for t_val in range(T - 1, -1, -1):
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
# is_last = (t_val == 0)
#
# logits, _ = inner.forward_cached(
# memory, src_pad_mask, x0_est, t,
# x0_hint=hint, inference_mode=True,
# )
#
# if t_val == t_capture and hasattr(inner, '_last_hidden'):
# h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu() # [d_model]
#
# logits = logits / max(temperature, 1e-8)
# if top_k > 0:
# V = logits.shape[-1]
# if top_k < V:
# vals, _ = torch.topk(logits, top_k, dim=-1)
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
#
# probs = F.softmax(logits, dim=-1)
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
# hint = x0_est
#
# if h_cap is None:
# continue
#
# ids = [x for x in x0_est[0].tolist() if x > 4]
# pred = tgt_tokenizer.decode(ids).strip()
# q = max(0.0, 1.0 - cer(pred, ref)) # quality = 1 - CER
#
# hidden_list.append(h_cap.numpy())
# quality_list.append(q)
#
# print(f"Collected {len(hidden_list)} quality examples.")
# print(f"Quality stats: mean={np.mean(quality_list):.3f} "
# f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")
#
# return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
#
#
# def _sample(probs):
# B, L, V = probs.shape
# flat = probs.view(B * L, V).clamp(min=1e-9)
# flat = flat / flat.sum(dim=-1, keepdim=True)
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
#
#
# # ── Training ──────────────────────────────────────────────────────────
#
# def train_quality_classifier(
# hidden_matrix: np.ndarray,
# quality_scores: np.ndarray,
# d_model: int,
# epochs: int = 30,
# batch_size: int = 64,
# lr: float = 1e-3,
# val_frac: float = 0.1,
# save_path: Optional[str] = None,
# ) -> QualityClassifier:
# """
# Train QualityClassifier on collected (hidden, quality) pairs.
#
# Args:
# hidden_matrix : [N, d_model] from collect_quality_data()
# quality_scores : [N] quality labels in [0, 1]
# d_model : hidden dimension
# epochs : training epochs
# save_path : if given, save trained classifier weights here
#
# Returns:
# trained QualityClassifier
# """
# device = torch.device("cpu") # classifier is tiny, CPU is fine
#
# X = torch.tensor(hidden_matrix, dtype=torch.float32)
# y = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1)
#
# N = len(X)
# n_val = max(1, int(N * val_frac))
# idx = torch.randperm(N)
# val_idx = idx[:n_val]
# train_idx = idx[n_val:]
#
# X_train, y_train = X[train_idx], y[train_idx]
# X_val, y_val = X[val_idx], y[val_idx]
#
# clf = QualityClassifier(d_model).to(device)
# optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
#
# print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
# print(f"Train: {len(X_train)} Val: {len(X_val)}")
#
# best_val_loss = float('inf')
# best_state = None
#
# for epoch in range(epochs):
# clf.train()
# perm = torch.randperm(len(X_train))
# train_loss = 0.0
# n_batches = 0
#
# for start in range(0, len(X_train), batch_size):
# batch_idx = perm[start:start + batch_size]
# xb, yb = X_train[batch_idx], y_train[batch_idx]
# pred = clf(xb)
# loss = F.mse_loss(pred, yb)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# train_loss += loss.item()
# n_batches += 1
#
# clf.eval()
# with torch.no_grad():
# val_pred = clf(X_val)
# val_loss = F.mse_loss(val_pred, y_val).item()
#
# if epoch % 5 == 0 or epoch == epochs - 1:
# print(f" Ep {epoch+1:3d} train={train_loss/n_batches:.4f} val={val_loss:.4f}")
#
# if val_loss < best_val_loss:
# best_val_loss = val_loss
# best_state = {k: v.clone() for k, v in clf.state_dict().items()}
#
# if best_state:
# clf.load_state_dict(best_state)
# print(f" Best val loss: {best_val_loss:.4f}")
#
# if save_path:
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
# torch.save(clf.state_dict(), save_path)
# print(f" Classifier saved: {save_path}")
#
# return clf
#
#
# # ── Guided inference ──────────────────────────────────────────────────
#
# def generate_guided(
# model,
# src: torch.Tensor,
# classifier: QualityClassifier,
# guidance_scale: float = 1.0,
# temperature: float = 0.8,
# top_k: int = 40,
# ) -> torch.Tensor:
# """
# Classifier-guided generation.
#
# At each diffusion step:
# 1. Run forward_cached() → logits, hidden states
# 2. Compute classifier gradient: ∂(quality_score) / ∂(hidden)
# 3. Project gradient back to logit space (approximate)
# 4. guided_logits = logits + λ * gradient_signal
# 5. Sample from guided_logits
#
# guidance_scale λ:
# 0.0 → no guidance (standard generation)
# 0.5 → weak guidance
# 1.0 → moderate guidance (recommended starting point)
# 2.0 → strong guidance (may reduce diversity)
# 3.0 → very strong (may collapse to repetitive output)
#
# Args:
# model : SanskritModel (frozen)
# src : [1, src_len] IAST token ids
# classifier : trained QualityClassifier
# guidance_scale : λ — guidance strength
#
# Returns:
# x0_est : [1, tgt_len] generated token ids
# """
# inner = model.model
# T = inner.scheduler.num_timesteps
# device = next(inner.parameters()).device
# clf_device = next(classifier.parameters()).device
#
# if src.dim() == 1:
# src = src.unsqueeze(0)
# src = src.to(device)
#
# B = src.shape[0]
# tgt_len = inner.max_seq_len
# mask_id = inner.mask_token_id
#
# memory, src_pad_mask = inner.encode_source(src)
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
# hint = None
#
# inner.eval()
# classifier.eval()
#
# for t_val in range(T - 1, -1, -1):
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
# is_last = (t_val == 0)
#
# if guidance_scale > 0.0:
# # Need gradients for classifier guidance
# with torch.enable_grad():
# # Run forward_cached and get hidden states
# PAD = 1
# if t_val > 0:
# _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
# else:
# x_t_ids = x0_est
#
# x = inner.tgt_embed(x_t_ids)
# t_norm = t.float() / T
# t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
# x = x + t_emb.unsqueeze(1)
#
# if hint is not None:
# hint_emb = inner.tgt_embed(hint)
# gate = inner.hint_gate(x)
# x = x + gate * hint_emb
#
# for block in inner.decoder_blocks:
# x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
#
# # hidden: [B, tgt_len, d_model] β€” detach from graph for clf
# hidden = x.detach().requires_grad_(True).to(clf_device)
#
# # Classifier quality score
# quality = classifier(hidden) # [B, 1]
# quality.sum().backward()
#
# # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model]
# grad = hidden.grad.to(device) # [B, tgt_len, d_model]
#
# # Project gradient to logit space via output head weight
# # logit_grad β‰ˆ grad @ head.weight [B, tgt_len, tgt_vocab]
# logit_grad = grad @ inner.head.weight.T
#
# # Compute standard logits (no gradient needed)
# with torch.no_grad():
# logits = inner.head(x)
#
# # Apply guidance
# logits = logits + guidance_scale * logit_grad
#
# else:
# with torch.no_grad():
# logits, _ = inner.forward_cached(
# memory, src_pad_mask, x0_est, t,
# x0_hint=hint, inference_mode=True,
# )
#
# with torch.no_grad():
# logits = logits / max(temperature, 1e-8)
# if top_k > 0:
# V = logits.shape[-1]
# if top_k < V:
# vals, _ = torch.topk(logits, top_k, dim=-1)
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
#
# probs = F.softmax(logits, dim=-1)
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
# hint = x0_est
#
# return x0_est
#
#
# def _sample_no_grad(probs):
# B, L, V = probs.shape
# flat = probs.view(B * L, V).clamp(min=1e-9)
# flat = flat / flat.sum(dim=-1, keepdim=True)
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
#
#
# # ── Guidance scale sweep ──────────────────────────────────────────────
#
# def sweep_guidance_scales(
# model,
# classifier: QualityClassifier,
# src_list: List[torch.Tensor],
# ref_list: List[str],
# tgt_tokenizer,
# scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
# n_samples: int = 50,
# device: torch.device = None,
# output_dir: str = "analysis/outputs",
# ) -> Dict:
# """
# Evaluate CER at each guidance scale.
# Produces quality-diversity tradeoff plot.
# """
# def cer(pred, ref):
# if not ref:
# return 1.0
# def ed(s1, s2):
# m, n = len(s1), len(s2)
# dp = list(range(n + 1))
# for i in range(1, m + 1):
# prev, dp[0] = dp[0], i
# for j in range(1, n + 1):
# temp = dp[j]
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
# prev = temp
# return dp[n]
# return ed(pred, ref) / max(len(ref), 1)
#
# device = device or next(model.parameters()).device
# results = {}
# n = min(n_samples, len(src_list))
#
# print("\nGuidance scale sweep...")
# for scale in scales:
# cer_list = []
# output_set = []
# for src, ref in zip(src_list[:n], ref_list[:n]):
# if src.dim() == 1:
# src = src.unsqueeze(0)
# out = generate_guided(model, src.to(device), classifier,
# guidance_scale=scale)
# ids = [x for x in out[0].tolist() if x > 4]
# pred = tgt_tokenizer.decode(ids).strip()
# cer_list.append(cer(pred, ref))
# output_set.append(pred)
#
# mean_cer = float(np.mean(cer_list))
#
# # Self-diversity: unique outputs / total (proxy for diversity)
# unique_frac = len(set(output_set)) / max(len(output_set), 1)
#
# results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
# print(f" Ξ»={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}")
#
# # Plot
# os.makedirs(output_dir, exist_ok=True)
# try:
# import matplotlib.pyplot as plt
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
#
# sc_list = sorted(results.keys())
# cers = [results[s]["mean_cer"] for s in sc_list]
# diversities = [results[s]["diversity"] for s in sc_list]
#
# ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
# ax1.set_xlabel("Guidance scale Ξ»", fontsize=10)
# ax1.set_ylabel("CER (↓ better)", fontsize=10)
# ax1.set_title("Quality vs guidance scale", fontsize=10)
#
# ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
# ax2.set_xlabel("Guidance scale Ξ»", fontsize=10)
# ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
# ax2.set_title("Diversity vs guidance scale", fontsize=10)
#
# plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
# plt.tight_layout()
# path = os.path.join(output_dir, "guidance_scale_sweep.png")
# plt.savefig(path, dpi=150, bbox_inches='tight')
# plt.close()
# print(f" Saved: {path}")
# except ImportError:
# pass
#
# with open(os.path.join(output_dir, "guidance_results.json"), "w") as f:
# json.dump({str(k): v for k, v in results.items()}, f, indent=2)
#
# return results
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict
from itertools import combinations
class QualityClassifier(nn.Module):
    """Small MLP head that turns a hidden-state vector into a score in (0, 1).

    Layer stack, in order:
    Linear(d_model, 128) -> ReLU -> Dropout(0.1) -> Linear(128, 64) -> ReLU
    -> Linear(64, 1) -> Sigmoid.
    """

    def __init__(self, d_model: int):
        super().__init__()
        # Keep the attribute name ``net`` and the layer positions stable:
        # both are baked into the keys of saved state dicts.
        stack = [
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, hidden):
        """Score ``hidden``.

        A 3-D input is first averaged over dim 1 (assumed to be the sequence
        axis); any other rank is passed to the MLP unchanged.
        """
        pooled = hidden.mean(dim=1) if hidden.dim() == 3 else hidden
        return self.net(pooled)
def _cer(pred: str, ref: str) -> float:
m, n = len(pred), len(ref)
if m == 0 and n == 0:
return 0.0
dp = list(range(n + 1))
for i in range(1, m + 1):
prev, dp[0] = dp[0], i
for j in range(1, n + 1):
tmp = dp[j]
dp[j] = prev if pred[i - 1] == ref[j - 1] else 1 + min(prev, dp[j], dp[j - 1])
prev = tmp
return float(dp[n]) / max(1, m, n)
def _sample(probs: torch.Tensor) -> torch.Tensor:
B, L, V = probs.shape
flat = probs.reshape(B * L, V).clamp(min=1e-9)
flat = flat / flat.sum(dim=-1, keepdim=True)
return torch.multinomial(flat, 1).squeeze(-1).reshape(B, L)
@torch.no_grad()
def _decode_pred(tgt_tokenizer, out_ids: torch.Tensor) -> str:
ids = [x for x in out_ids[0].tolist() if x > 4]
return tgt_tokenizer.decode(ids).strip()
def _tokenize_ws(text: str) -> list[str]:
return [t for t in text.split() if t]
def _distinct_n(outputs: List[str], n: int = 2) -> float:
    """Return the distinct-n ratio over all ``outputs`` pooled together.

    Every whitespace-token n-gram from every output is counted; the result is
    (number of unique n-grams) / (total n-grams), or 0.0 when no output is
    long enough to contribute an n-gram.
    """
    unique_ngrams = set()
    total = 0
    for sentence in outputs:
        tokens = _tokenize_ws(sentence)
        # The range is empty when the sentence has fewer than ``n`` tokens.
        for start in range(len(tokens) - n + 1):
            unique_ngrams.add(tuple(tokens[start:start + n]))
            total += 1
    if total == 0:
        return 0.0
    return float(len(unique_ngrams) / total)
def _self_bleu(outputs: List[str], max_pairs: int = 64) -> float:
    """Return the mean pairwise sentence-BLEU among ``outputs``.

    Each non-blank output is whitespace-tokenised; for index pairs (i, j)
    with i < j, output i is scored as hypothesis against output j as the
    single reference, using NLTK's ``method1`` smoothing. When there are more
    than ``max_pairs`` pairs, an evenly spaced subset of ``max_pairs`` of
    them is scored instead.

    Returns 0.0 when NLTK's BLEU cannot be imported, when fewer than two
    non-blank outputs exist, or when no pair was scored.
    """
    try:
        from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    except Exception:
        return 0.0

    tokenized = [_tokenize_ws(text) for text in outputs if text.strip()]
    if len(tokenized) < 2:
        return 0.0

    smoothing = SmoothingFunction().method1
    index_pairs = list(combinations(range(len(tokenized)), 2))
    if len(index_pairs) > max_pairs:
        keep = np.linspace(0, len(index_pairs) - 1, max_pairs, dtype=int)
        index_pairs = [index_pairs[pos] for pos in keep]

    # Blank outputs were filtered out above, so every hypothesis is non-empty.
    scores = [
        float(sentence_bleu([tokenized[ref_i]], tokenized[hyp_i], smoothing_function=smoothing))
        for hyp_i, ref_i in index_pairs
    ]
    if not scores:
        return 0.0
    return float(np.mean(scores))
@torch.no_grad()
def collect_quality_data(
    model,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    t_capture: int = 0,
    max_samples: int = 1000,
) -> tuple[np.ndarray, np.ndarray]:
    """Build (hidden-state, quality) training pairs for the quality classifier.

    For each of the first ``min(max_samples, len(src_list), len(ref_list))``
    examples:

    1. generate an output with ``generate_cached`` (or ``generate`` when the
       inner model has no ``generate_cached``);
    2. decode it and compute a blended quality target
       ``0.70 * (1 - CER) + 0.20 * unique-token ratio + 0.10 * length ratio``,
       clipped to [0, 1];
    3. re-run ``forward_cached`` on the generated ids at timestep
       ``t_capture`` and read ``inner._last_hidden``; the first batch row is
       averaged over dim 0 and stored as the feature vector.

    Examples for which ``_last_hidden`` is missing or ``None`` are skipped.
    The inner model is put into eval mode as a side effect.

    Returns:
        ``(hidden_arr, quality_arr)`` as float32 arrays with one row / entry
        per collected example.

    Raises:
        RuntimeError: if no hidden state could be collected at all.
    """
    inner = model.model
    device = next(inner.parameters()).device
    inner.eval()

    n = min(max_samples, len(src_list), len(ref_list))
    features = []
    targets = []
    print(f"Collecting quality data from {n} examples...")

    for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
        batch = src.unsqueeze(0) if src.dim() == 1 else src
        batch = batch.to(device)

        if hasattr(inner, "generate_cached"):
            out = inner.generate_cached(batch)
        else:
            out = inner.generate(batch)

        pred = _decode_pred(tgt_tokenizer, out)
        pred_tokens = pred.split()
        n_pred = len(pred_tokens)

        cer_q = 1.0 - _cer(pred, ref)
        uniq = len(set(pred_tokens)) / max(1, n_pred)
        len_ratio = min(1.0, n_pred / max(1, len(ref.split())))
        # Blend quality target to avoid all-zero collapse on weak checkpoints.
        quality = 0.70 * cer_q + 0.20 * uniq + 0.10 * len_ratio

        # Second pass purely to populate ``_last_hidden`` at the capture step.
        memory, src_pad = inner.encode_source(batch)
        t = torch.full((1,), int(t_capture), dtype=torch.long, device=device)
        inner.forward_cached(memory, src_pad, out, t, x0_hint=out, inference_mode=True)

        hidden = getattr(inner, "_last_hidden", None)
        if hidden is not None:
            features.append(hidden[0].mean(dim=0).detach().cpu().numpy())
            targets.append(float(np.clip(quality, 0.0, 1.0)))
            if i % 200 == 0:
                print(f" {i}/{n}")

    if not features:
        raise RuntimeError("No hidden states collected for quality classifier.")

    hidden_arr = np.asarray(features, dtype=np.float32)
    quality_arr = np.asarray(targets, dtype=np.float32)
    print(f"Collected {hidden_arr.shape[0]} quality examples.")
    return hidden_arr, quality_arr
def train_quality_classifier(
    hidden: np.ndarray,
    quality: np.ndarray,
    d_model: int,
    epochs: int = 30,
    batch_size: int = 64,
    lr: float = 1e-3,
    save_path: str | None = None,
) -> "QualityClassifier":
    """Fit a ``QualityClassifier`` on (hidden-state, quality) pairs with MSE.

    Training always runs on CPU. A random 90/10 split provides a validation
    set; the weights with the lowest validation loss are restored before the
    classifier is returned (in eval mode).

    Args:
        hidden: feature matrix with one row per example.
        quality: one quality label per row of ``hidden``.
        d_model: input width passed to ``QualityClassifier``.
        epochs: number of passes over the training split.
        batch_size: mini-batch size (must be >= 1).
        lr: Adam learning rate.
        save_path: if given, the best state dict is saved there; missing
            parent directories are created.

    Returns:
        The trained classifier.

    Raises:
        ValueError: on empty input, mismatched lengths or ``batch_size < 1``.
    """
    device = torch.device("cpu")

    features = np.asarray(hidden, dtype=np.float32)
    q = np.asarray(quality, dtype=np.float32).reshape(-1)
    n_total = features.shape[0]
    if n_total == 0:
        raise ValueError("train_quality_classifier needs at least one example.")
    if q.shape[0] != n_total:
        raise ValueError(
            f"hidden has {n_total} rows but quality has {q.shape[0]} labels."
        )
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1.")

    clf = QualityClassifier(d_model).to(device)
    x = torch.tensor(features, dtype=torch.float32, device=device)

    # Standardize target for better gradients when raw spread is tiny.
    q_mu = float(np.mean(q))
    q_sd = float(np.std(q))
    if q_sd < 1e-4:
        # Near-constant labels: add a little noise so the z-score is defined.
        q = q + np.random.normal(0.0, 1e-3, size=q.shape).astype(np.float32)
        q_mu = float(np.mean(q))
        q_sd = float(np.std(q))
    z = np.clip((q - q_mu) / max(q_sd, 1e-6), -3.0, 3.0)
    # The classifier ends in a Sigmoid and can only emit values in (0, 1).
    # Regressing onto raw z-scores in [-3, 3] would make most targets
    # unreachable and saturate the output, so map them affinely onto [0, 1].
    q = (z + 3.0) / 6.0
    y = torch.tensor(q, dtype=torch.float32, device=device).unsqueeze(-1)

    idx = torch.randperm(n_total)
    # Always keep at least one training example (int(0.9 * 1) would be 0).
    split = max(1, int(0.9 * n_total))
    tr, va = idx[:split], idx[split:]
    x_tr, y_tr = x[tr], y[tr]
    x_va, y_va = x[va], y[va]
    n_train = x_tr.shape[0]

    opt = torch.optim.Adam(clf.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    best_val = float("inf")
    best_state = None

    print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
    print(f"Train: {x_tr.shape[0]} Val: {x_va.shape[0]}")

    for ep in range(1, epochs + 1):
        clf.train()
        # Reshuffle every epoch so mini-batch composition is not fixed.
        order = torch.randperm(n_train)
        ep_losses = []
        for start in range(0, n_train, batch_size):
            batch = order[start:start + batch_size]
            loss = loss_fn(clf(x_tr[batch]), y_tr[batch])
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            ep_losses.append(float(loss.item()))
        tr_loss = float(np.mean(ep_losses)) if ep_losses else 0.0

        clf.eval()
        with torch.no_grad():
            # With an empty validation split, fall back to the training loss.
            va_loss = float(loss_fn(clf(x_va), y_va).item()) if x_va.shape[0] else tr_loss

        if va_loss < best_val:
            best_val = va_loss
            best_state = {k: v.detach().cpu().clone() for k, v in clf.state_dict().items()}

        if ep == 1 or ep % 5 == 0 or ep == epochs:
            print(f" Ep {ep:>3d} train={tr_loss:.4f} val={va_loss:.4f}")

    if best_state is not None:
        clf.load_state_dict(best_state)
    clf.eval()
    print(f" Best val loss: {best_val:.4f}")

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        torch.save(clf.state_dict(), save_path)
        print(f" Classifier saved: {save_path}")
    return clf
def generate_guided(
    model,
    src: torch.Tensor,
    classifier: "QualityClassifier",
    guidance_scale: float = 1.0,
    temperature: float = 0.8,
    top_k: int = 40,
):
    """Generate token ids with optional classifier guidance.

    Starting from an all-mask sequence, the loop walks timesteps
    ``T-1 .. 0``. At every step the frozen model produces logits via
    ``forward_cached``; when ``guidance_scale > 0`` and the model exposes
    ``_last_hidden``, the gradient of the classifier score with respect to
    that hidden state is normalised per position, projected through
    ``inner.head.weight`` into logit space, clamped to [-6, 6] and added to
    the logits with weight ``1.5 * guidance_scale``. Logits are then
    temperature-scaled, top-k filtered and sampled (arg-maxed on the final
    step).

    Args:
        model: wrapper whose ``.model`` attribute is the inner network.
        src: source ids; a 1-D tensor is treated as a batch of one.
        classifier: trained quality classifier; may live on a different
            device than the model.
        guidance_scale: guidance strength; ``0`` disables guidance.
        temperature: softmax temperature (floored at 1e-8).
        top_k: keep only the ``top_k`` largest logits per position when
            ``0 < top_k < vocab``.

    Returns:
        Tensor of generated ids, one row per source row.
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device
    # The classifier may sit on another device (it is trained on CPU), so
    # remember where its weights live and what dtype they use.
    clf_param = next(iter(classifier.parameters()), None)

    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)
    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # Switch to eval before encoding so dropout is off for the source pass,
    # and do not keep an autograd graph for the frozen model.
    inner.eval()
    classifier.eval()
    with torch.no_grad():
        memory, src_pad_mask = inner.encode_source(src)

    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = t_val == 0

        with torch.no_grad():
            logits, _ = inner.forward_cached(
                memory, src_pad_mask, x0_est, t, x0_hint=hint, inference_mode=True
            )
        hidden = getattr(inner, "_last_hidden", None)

        if guidance_scale > 0.0 and hidden is not None:
            hidden_leaf = hidden.detach()
            if clf_param is not None:
                hidden_leaf = hidden_leaf.to(device=clf_param.device, dtype=clf_param.dtype)
            # Callers commonly wrap generation in torch.no_grad(); the
            # classifier gradient needs autograd, so re-enable it locally.
            with torch.enable_grad():
                hidden_leaf = hidden_leaf.requires_grad_(True)
                q = classifier(hidden_leaf).sum()
                grad = torch.autograd.grad(q, hidden_leaf)[0]
            with torch.no_grad():
                # Unit-normalise per position so only the direction matters.
                grad = grad / (grad.norm(dim=-1, keepdim=True) + 1e-6)
                head_weight = inner.head.weight
                grad = grad.to(device=head_weight.device, dtype=head_weight.dtype)
                logit_grad = torch.matmul(grad, head_weight.T)
                logits = logits + (1.5 * guidance_scale) * torch.clamp(logit_grad, -6.0, 6.0)

        with torch.no_grad():
            logits = logits / max(float(temperature), 1e-8)
            if 0 < top_k < logits.shape[-1]:
                vals, _ = torch.topk(logits, int(top_k), dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float("-inf"))
            probs = F.softmax(logits, dim=-1)
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est

    return x0_est
def sweep_guidance_scales(
    model,
    classifier: "QualityClassifier",
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    scales: "List[float] | None" = None,
    n_samples: int = 50,
    device=None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """Measure CER and diversity of guided generation at several scales.

    For each guidance scale, the first
    ``min(n_samples, len(src_list), len(ref_list))`` examples are generated
    with ``generate_guided`` and scored. Temperature and top-k are tightened
    as the scale grows (``temp = max(0.55, 0.85 - 0.08 * scale)``,
    ``k = max(12, int(40 - 4 * scale))``).

    Side effects: creates ``output_dir``, writes
    ``task5_guidance_results.json`` there and, on a best-effort basis, the
    plot ``task5_quality_diversity_tradeoff.png``. A plotting failure is
    reported on stdout and does not abort the sweep.

    Args:
        scales: guidance scales to evaluate; defaults to
            ``[0.0, 0.5, 1.0, 1.5, 2.0, 3.0]``.
        device: device for the source tensors; defaults to the device of the
            model's first parameter.

    Returns:
        Dict keyed by ``float(scale)`` with ``mean_cer``, ``diversity``
        (mean of distinct-2 and 1 - self-BLEU), ``sent_unique``,
        ``distinct2`` and ``self_bleu``.
    """
    if scales is None:
        scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
    if device is None:
        device = next(model.parameters()).device

    n = min(n_samples, len(src_list), len(ref_list))
    results = {}

    print("\nGuidance scale sweep...")
    for scale in scales:
        scale_f = float(scale)
        # Higher λ gets slightly sharper decoding and stronger signal.
        temp = max(0.55, 0.85 - 0.08 * scale_f)
        k = max(12, int(40 - 4 * scale_f))

        cer_vals = []
        outputs = []
        for src, ref in zip(src_list[:n], ref_list[:n]):
            out = generate_guided(
                model, src.to(device), classifier,
                guidance_scale=scale_f, temperature=temp, top_k=k
            )
            pred = _decode_pred(tgt_tokenizer, out)
            cer_vals.append(_cer(pred, ref))
            outputs.append(pred)

        mean_cer = float(np.mean(cer_vals)) if cer_vals else 1.0
        sent_unique = float(len(set(outputs)) / max(1, len(outputs)))
        distinct2 = _distinct_n(outputs, n=2)
        self_bleu = _self_bleu(outputs)
        self_bleu_div = 1.0 - self_bleu
        diversity = float(0.5 * distinct2 + 0.5 * self_bleu_div)
        results[scale_f] = {
            "mean_cer": mean_cer,
            "diversity": diversity,
            "sent_unique": sent_unique,
            "distinct2": distinct2,
            "self_bleu": self_bleu,
        }
        print(
            f" λ={scale_f:.1f} CER={mean_cer:.4f} "
            f"div={diversity:.3f} d2={distinct2:.3f} sBLEU={self_bleu:.3f}"
        )

    os.makedirs(output_dir, exist_ok=True)

    # Plotting is optional: a missing matplotlib or a rendering error must
    # not lose the numeric results, but it should not vanish silently either.
    plt = None
    fig = None
    try:
        import matplotlib.pyplot as plt

        xs = sorted(results.keys())
        panels = [
            ("mean_cer", "CER (lower is better)", "Quality vs Guidance"),
            ("diversity", "Composite diversity", "Diversity vs Guidance"),
            ("distinct2", "Distinct-2", "Distinct-2 vs Guidance"),
        ]
        fig, ax = plt.subplots(1, 3, figsize=(13, 4))
        for axis, (key, ylabel, title) in zip(ax, panels):
            axis.plot(xs, [results[x][key] for x in xs], marker="o")
            axis.set_xlabel("Guidance scale λ")
            axis.set_ylabel(ylabel)
            axis.set_title(title)
        fig.tight_layout()
        fig.savefig(
            os.path.join(output_dir, "task5_quality_diversity_tradeoff.png"),
            dpi=150,
            bbox_inches="tight",
        )
    except Exception as exc:
        print(f" Plot skipped ({type(exc).__name__}: {exc})")
    finally:
        # Close the figure even when saving failed so it does not leak.
        if plt is not None and fig is not None:
            plt.close(fig)

    with open(os.path.join(output_dir, "task5_guidance_results.json"), "w", encoding="utf-8") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2)
    return results
def sweep_guidance(
    model,
    classifier,
    src_list,
    ref_list,
    tgt_tokenizer,
    scales=None,
    n_samples=50,
):
    """Run ``sweep_guidance_scales`` and return a trimmed result mapping.

    Outputs are written to ``analysis/outputs`` by the underlying sweep.

    Args:
        scales: guidance scales to evaluate; defaults to
            ``[0.0, 0.5, 1.0, 1.5, 2.0, 3.0]``.
        n_samples: maximum number of examples evaluated per scale.

    Returns:
        ``{float(scale): {"CER": mean_cer, "diversity": diversity}}``.
    """
    # Resolve the default here instead of using a shared mutable list as the
    # parameter default.
    if scales is None:
        scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]

    results = sweep_guidance_scales(
        model=model,
        classifier=classifier,
        src_list=src_list,
        ref_list=ref_list,
        tgt_tokenizer=tgt_tokenizer,
        scales=scales,
        n_samples=n_samples,
        output_dir="analysis/outputs",
    )
    return {
        float(k): {"CER": v["mean_cer"], "diversity": v["diversity"]}
        for k, v in results.items()
    }