Spaces:
Running
Running
| """Validation callback for training loop monitoring. | |
| Periodically generates sample images from the validation set, computes | |
| SSIM and LPIPS against the ground-truth targets, and writes comparison | |
| images, a sample grid, and metrics JSON to disk (a summary is printed). | |
| Designed for use with train_controlnet.py — call at regular intervals | |
| during training to monitor quality without disrupting the training loop. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from landmarkdiff.evaluation import compute_lpips, compute_ssim | |
class ValidationCallback:
    """Validation callback that generates and evaluates samples during training.

    Each call to :meth:`run` denoises a handful of validation samples with
    the supplied ControlNet/UNet/VAE, scores them against the targets with
    SSIM and LPIPS, and writes comparison images plus metrics JSON under
    ``output_dir / "step-<global_step>"``.

    Usage::

        val_cb = ValidationCallback(
            val_dataset=val_dataset,
            output_dir=Path("checkpoints/val"),
            num_samples=8,
            samples_per_procedure=2,
        )

        # In training loop:
        if global_step % val_every == 0:
            val_metrics = val_cb.run(
                controlnet=ema_controlnet,
                vae=vae,
                unet=unet,
                text_embeddings=text_embeddings,
                noise_scheduler=noise_scheduler,
                device=device,
                weight_dtype=weight_dtype,
                global_step=global_step,
            )
    """

    def __init__(
        self,
        val_dataset,
        output_dir: Path,
        num_samples: int = 8,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        samples_per_procedure: int = 2,
    ):
        """Store configuration and pre-compute the procedure index map.

        Args:
            val_dataset: Indexable dataset supporting ``len()``. Items are
                read in :meth:`run` via the ``"conditioning"`` and
                ``"target"`` keys.
            output_dir: Directory for all validation artifacts; created
                (with parents) if missing.
            num_samples: Number of samples used when the dataset exposes no
                procedure metadata; capped at ``len(val_dataset)``. It is
                not applied when stratified selection is active.
            num_inference_steps: Number of DDIM steps per generated sample.
            guidance_scale: Stored on the instance; NOTE(review): no method
                of this class reads it (no classifier-free guidance is done).
            samples_per_procedure: Maximum samples taken per procedure when
                procedure metadata is available.
        """
        self.val_dataset = val_dataset
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.num_samples = min(num_samples, len(val_dataset))
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.samples_per_procedure = samples_per_procedure
        self.history: list[dict] = []

        # Pre-build per-procedure index map for stratified sampling
        self._procedure_indices = self._build_procedure_map()

    def _build_procedure_map(self) -> dict[str, list[int]]:
        """Build a mapping of procedure name to dataset indices.

        Two dataset shapes are recognised:

        * a non-empty ``_sample_procedures`` mapping plus a ``pairs``
          sequence of path-like objects: the procedure is looked up by the
          path stem with ``"_input"`` removed, defaulting to ``"unknown"``;
        * a ``get_procedure(idx)`` method, queried for every index.

        Returns:
            ``{procedure: [indices...]}``. The ``"unknown"`` bucket is
            dropped whenever at least one labelled procedure exists. An
            empty dict means the dataset exposes no procedure metadata.
        """
        from collections import defaultdict

        proc_indices: dict[str, list[int]] = defaultdict(list)
        ds = self.val_dataset

        if hasattr(ds, "_sample_procedures") and ds._sample_procedures:
            for idx, pair_path in enumerate(ds.pairs):
                prefix = pair_path.stem.replace("_input", "")
                proc_indices[ds._sample_procedures.get(prefix, "unknown")].append(idx)
        elif hasattr(ds, "get_procedure"):
            for idx in range(len(ds)):
                proc_indices[ds.get_procedure(idx)].append(idx)

        # Drop "unknown" if we have labeled procedures
        known = {k: v for k, v in proc_indices.items() if k != "unknown"}
        return known if known else dict(proc_indices)

    def _select_per_procedure_indices(self) -> list[tuple[int, str]]:
        """Select sample indices ensuring each procedure is represented.

        Returns:
            List of ``(dataset_index, procedure_name)`` tuples: the first
            ``samples_per_procedure`` indices of every procedure, with
            procedures in sorted order. Falls back to the first
            ``num_samples`` sequential indices (labelled ``"unknown"``)
            when no procedure metadata is available.
        """
        if not self._procedure_indices:
            return [(i, "unknown") for i in range(self.num_samples)]

        return [
            (idx, proc)
            for proc, indices in sorted(self._procedure_indices.items())
            for idx in indices[: self.samples_per_procedure]
        ]

    def run(
        self,
        controlnet: torch.nn.Module,
        vae,
        unet,
        text_embeddings: torch.Tensor,
        noise_scheduler,
        device: torch.device,
        weight_dtype: torch.dtype,
        global_step: int,
    ) -> dict:
        """Run validation: generate samples and compute metrics.

        Side effects: writes ``val_XX_<proc>.png`` comparison strips,
        ``grid.png`` and ``metrics.json`` into ``step-<global_step>/``,
        rewrites ``validation_history.json`` in ``output_dir``, appends to
        ``self.history`` and prints a summary.

        Args:
            controlnet: ControlNet to evaluate. It is switched to eval mode
                for generation and restored to its previous mode afterwards.
            vae: VAE used to encode the target (only to obtain the latent
                shape) and to decode the generated latents.
            unet: Denoising UNet that consumes the ControlNet residuals.
            text_embeddings: Prompt embeddings; only the first row is used.
            noise_scheduler: Training scheduler; its ``config`` seeds a
                ``DDIMScheduler`` used for inference.
            device: Device the sample tensors are moved to.
            weight_dtype: Dtype for model inputs and the autocast context.
            global_step: Current training step (used in paths and metrics).

        Returns:
            Dict with ``step``, ``ssim_mean``, ``ssim_std``, ``lpips_mean``,
            ``lpips_std``, ``time_seconds`` and a ``per_procedure`` mapping
            of ``{proc: {ssim_mean, lpips_mean, n_samples}}``.
        """
        from diffusers import DDIMScheduler

        t0 = time.time()

        step_dir = self.output_dir / f"step-{global_step}"
        step_dir.mkdir(parents=True, exist_ok=True)

        # Set up inference scheduler (DDIM for robustness during validation)
        scheduler = DDIMScheduler.from_config(noise_scheduler.config)
        scheduler.set_timesteps(self.num_inference_steps, device=device)

        ssim_scores: list[float] = []
        lpips_scores: list[float] = []
        generated_images: list[np.ndarray] = []

        # Per-procedure metric accumulators
        proc_ssim: dict[str, list[float]] = {}
        proc_lpips: dict[str, list[float]] = {}

        # Use per-procedure selection instead of sequential indices
        per_proc = self._select_per_procedure_indices()

        # Remember the caller's mode so an eval-mode (e.g. EMA) model is not
        # left in train mode after validation.
        was_training = getattr(controlnet, "training", True)
        controlnet.eval()
        try:
            # Validation must not build an autograd graph: with a trainable
            # ControlNet the graph would be retained across every denoising
            # step and the later ``.numpy()`` call would raise on a tensor
            # that requires grad.
            with torch.no_grad():
                for sample_num, (idx, proc) in enumerate(per_proc):
                    sample = self.val_dataset[idx]
                    # presumably CHW tensors in [0, 1] (see ``* 2 - 1`` and
                    # ``permute(1, 2, 0)`` below) -- confirm against dataset
                    conditioning = (
                        sample["conditioning"].unsqueeze(0).to(device, dtype=weight_dtype)
                    )
                    target = sample["target"].unsqueeze(0).to(device, dtype=weight_dtype)

                    # Encode target for latent shape (VAE needs float32)
                    latents = vae.encode((target * 2 - 1).float()).latent_dist.sample()
                    latents = (latents * vae.config.scaling_factor).to(weight_dtype)

                    # Start from noise
                    noise = torch.randn_like(latents)
                    sample_latents = noise * scheduler.init_noise_sigma

                    encoder_hidden_states = text_embeddings[:1]

                    # Denoising loop with autocast to handle BF16/FP32 dtype
                    # mismatches in timestep embeddings.
                    # NOTE(review): device type is hard-coded to "cuda"
                    # regardless of ``device``.
                    with torch.autocast("cuda", dtype=weight_dtype):
                        for t in scheduler.timesteps:
                            scaled = scheduler.scale_model_input(sample_latents, t)

                            # ControlNet
                            down_samples, mid_sample = controlnet(
                                scaled,
                                t,
                                encoder_hidden_states=encoder_hidden_states,
                                controlnet_cond=conditioning,
                                return_dict=False,
                            )

                            # UNet with ControlNet residuals
                            noise_pred = unet(
                                scaled,
                                t,
                                encoder_hidden_states=encoder_hidden_states,
                                down_block_additional_residuals=down_samples,
                                mid_block_additional_residual=mid_sample,
                            ).sample

                            sample_latents = scheduler.step(
                                noise_pred, t, sample_latents,
                            ).prev_sample

                    # Decode -- cast VAE to float32 temporarily to avoid color
                    # banding and prevent dtype mismatch (latents float32 vs
                    # VAE weights bf16). The original dtype is restored even
                    # if decoding fails.
                    vae_dtype = next(vae.parameters()).dtype
                    vae.to(torch.float32)
                    try:
                        decoded = vae.decode(
                            sample_latents.float() / vae.config.scaling_factor
                        ).sample
                    finally:
                        vae.to(vae_dtype)
                    decoded = ((decoded + 1) / 2).clamp(0, 1)

                    # Convert to HWC uint8 numpy for metrics and saving
                    gen_np = (
                        decoded[0].float().permute(1, 2, 0).cpu().numpy() * 255
                    ).astype(np.uint8)
                    tgt_np = (
                        target[0].float().permute(1, 2, 0).cpu().numpy() * 255
                    ).astype(np.uint8)
                    cond_np = (
                        conditioning[0].float().permute(1, 2, 0).cpu().numpy() * 255
                    ).astype(np.uint8)

                    # BGR for metrics (our metrics expect BGR)
                    gen_bgr = gen_np[:, :, ::-1].copy()
                    tgt_bgr = tgt_np[:, :, ::-1].copy()

                    # Compute metrics
                    ssim_val = compute_ssim(gen_bgr, tgt_bgr)
                    lpips_val = compute_lpips(gen_bgr, tgt_bgr)
                    ssim_scores.append(ssim_val)
                    lpips_scores.append(lpips_val)
                    generated_images.append(gen_np)

                    # Accumulate per-procedure metrics
                    proc_ssim.setdefault(proc, []).append(ssim_val)
                    proc_lpips.setdefault(proc, []).append(lpips_val)

                    # Save comparison: conditioning | generated | target
                    proc_tag = proc.replace(" ", "_")
                    comparison = np.hstack([cond_np, gen_np, tgt_np])
                    Image.fromarray(comparison).save(
                        step_dir / f"val_{sample_num:02d}_{proc_tag}.png"
                    )
        finally:
            controlnet.train(was_training)

        # Aggregate metrics
        metrics: dict = {
            "step": global_step,
            "ssim_mean": float(np.nanmean(ssim_scores)),
            "ssim_std": float(np.nanstd(ssim_scores)),
            "lpips_mean": float(np.nanmean(lpips_scores)),
            "lpips_std": float(np.nanstd(lpips_scores)),
            "time_seconds": round(time.time() - t0, 1),
        }

        # Per-procedure breakdown
        per_procedure: dict[str, dict] = {
            proc: {
                "ssim_mean": float(np.nanmean(proc_ssim[proc])),
                "lpips_mean": float(np.nanmean(proc_lpips[proc])),
                "n_samples": len(proc_ssim[proc]),
            }
            for proc in sorted(proc_ssim)
        }
        metrics["per_procedure"] = per_procedure

        self.history.append(metrics)

        # Save metrics
        with open(step_dir / "metrics.json", "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)

        # Save full history
        with open(self.output_dir / "validation_history.json", "w", encoding="utf-8") as f:
            json.dump(self.history, f, indent=2)

        # Create comparison grid (all samples in one image), four per row;
        # the last row is padded with black tiles.
        if generated_images:
            grid_rows = []
            for i in range(0, len(generated_images), 4):
                row_imgs = generated_images[i:i + 4]
                while len(row_imgs) < 4:
                    row_imgs.append(np.zeros_like(generated_images[0]))
                grid_rows.append(np.hstack(row_imgs))
            grid = np.vstack(grid_rows)
            Image.fromarray(grid).save(step_dir / "grid.png")

        # Log summary with per-procedure breakdown
        proc_summary = " | ".join(
            f"{p}: SSIM={v['ssim_mean']:.3f}"
            for p, v in sorted(per_procedure.items())
        )
        print(
            f"  Validation @ step {global_step}: "
            f"SSIM={metrics['ssim_mean']:.4f}+/-{metrics['ssim_std']:.4f} "
            f"LPIPS={metrics['lpips_mean']:.4f}+/-{metrics['lpips_std']:.4f} "
            f"({metrics['time_seconds']:.1f}s)"
        )
        if proc_summary:
            print(f"    Per-procedure: {proc_summary}")

        return metrics

    def plot_history(self, output_path: str | None = None) -> None:
        """Plot validation metrics over training steps.

        Draws mean SSIM and mean LPIPS from ``self.history`` side by side
        and saves the figure. Does nothing when there is no history or
        matplotlib is not installed.

        Args:
            output_path: Destination image path; defaults to
                ``output_dir / "validation_curves.png"``.
        """
        if not self.history:
            return

        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt
        except ImportError:
            return

        steps = [h["step"] for h in self.history]
        ssim = [h["ssim_mean"] for h in self.history]
        lpips = [h["lpips_mean"] for h in self.history]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        ax1.plot(steps, ssim, "b-o", markersize=4)
        ax1.set_xlabel("Training Step")
        ax1.set_ylabel("SSIM")
        ax1.set_title("Validation SSIM (higher=better)")
        ax1.grid(alpha=0.3)

        ax2.plot(steps, lpips, "r-o", markersize=4)
        ax2.set_xlabel("Training Step")
        ax2.set_ylabel("LPIPS")
        ax2.set_title("Validation LPIPS (lower=better)")
        ax2.grid(alpha=0.3)

        fig.tight_layout()
        path = output_path or str(self.output_dir / "validation_curves.png")
        # Save and close this specific figure rather than whatever pyplot
        # currently considers the "current" one.
        fig.savefig(path, dpi=150, bbox_inches="tight")
        plt.close(fig)