from collections.abc import Mapping
from typing import Any, Dict, Optional

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import WandbLogger
|
|
|
|
class VisualizationCallback(Callback):
    """Log audio, spectrograms and patch masks to Weights & Biases.

    For each of the first ``num_batches`` training batches (by ``batch_idx``),
    up to ``num_samples`` samples are rendered into a ``wandb.Table`` holding
    the audio clip, the plain spectrogram, the spectrogram with masked patches
    darkened, and the spectrogram with unmasked patches darkened.

    Nothing is logged unless ``trainer.logger`` is a ``WandbLogger``.

    Args:
        num_samples: Maximum number of samples visualized per batch.
        num_batches: Number of leading batches to visualize.
    """

    # Fallback used when no sample rate can be found on the datamodule,
    # on ``pl_module.spectrogram`` or in ``pl_module.hparams``.
    DEFAULT_SAMPLE_RATE = 32000

    def __init__(self, num_samples: int = 4, num_batches: int = 2):
        super().__init__()
        self.num_samples = num_samples
        self.num_batches = num_batches
        # Counts logging attempts so later epochs (where batch_idx restarts
        # at 0) do not trigger the visualizations again.
        self.batches_logged = 0

    def on_train_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Visualize the batch if it is one of the first ``num_batches``."""
        if self.batches_logged >= self.num_batches:
            return
        if batch_idx >= self.num_batches:
            return

        self._log_visualizations(trainer, pl_module, batch, batch_idx)
        self.batches_logged += 1

    def _log_visualizations(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        batch: Dict[str, Any],
        batch_idx: int,
    ) -> None:
        """Build and log one visualization table for ``batch``.

        Reads ``batch["waveform"]`` and uses ``pl_module.spectrogram``,
        ``pl_module.patch_embed.patch_embed.patch_size`` and
        ``pl_module.mask_generator``. A single mask is generated and shared
        by every visualized sample.

        Raises:
            KeyError: If ``batch`` has no ``"waveform"`` entry.
        """
        logger = trainer.logger
        if not isinstance(logger, WandbLogger):
            return

        waveform = batch["waveform"][: self.num_samples]
        # The batch may hold fewer than ``self.num_samples`` items (e.g. a
        # small or final batch); iterate over what is actually there.
        sample_count = waveform.shape[0]
        if sample_count == 0:
            return

        sample_rate = self._resolve_sample_rate(trainer, pl_module)

        with torch.no_grad():
            spec = pl_module.spectrogram(waveform.to(pl_module.device))

        # Patch grid derived from the spectrogram's last two axes.
        # NOTE(review): assumes spec is (batch, 1, freq, time) - confirm.
        patch_size = pl_module.patch_embed.patch_embed.patch_size
        grid_size = (
            spec.shape[2] // patch_size[0],
            spec.shape[3] // patch_size[1],
        )

        mask = pl_module.mask_generator(
            1, device=pl_module.device, grid_size=grid_size
        )
        mask = mask.expand(sample_count, -1)

        # Imported lazily so the module can be imported without wandb.
        import wandb

        columns = [
            "Batch Idx",
            "Sample Idx",
            "Audio",
            "Spectrogram",
            "Masked Spectrogram (Context)",
            "Inverse Masked Spectrogram (Targets)",
        ]
        rows = []

        for i in range(sample_count):
            # detach()/float() keep numpy() working for tensors that require
            # grad or use a dtype numpy lacks (e.g. bfloat16).
            audio_np = waveform[i].detach().squeeze().float().cpu().numpy()
            audio = wandb.Audio(
                audio_np, sample_rate=sample_rate, caption=f"B{batch_idx}_S{i}"
            )

            spec_np = spec[i].detach().squeeze().float().cpu().numpy()
            mask_np = mask[i].detach().cpu().numpy()

            img_orig = self._to_wandb_image(
                self._plot_spectrogram(spec_np, patch_size, grid_size),
                f"Spec B{batch_idx}_S{i}",
            )
            img_masked = self._to_wandb_image(
                self._plot_spectrogram_with_mask(
                    spec_np, mask_np, patch_size, grid_size, invert_mask=False
                ),
                f"Masked B{batch_idx}_S{i}",
            )
            img_inv_masked = self._to_wandb_image(
                self._plot_spectrogram_with_mask(
                    spec_np, mask_np, patch_size, grid_size, invert_mask=True
                ),
                f"InvMasked B{batch_idx}_S{i}",
            )

            rows.append(
                [batch_idx, i, audio, img_orig, img_masked, img_inv_masked]
            )

        table = wandb.Table(columns=columns, data=rows)
        logger.experiment.log({f"train/visualizations_batch_{batch_idx}": table})

    @staticmethod
    def _to_wandb_image(fig: plt.Figure, caption: str) -> Any:
        """Wrap ``fig`` in a ``wandb.Image`` and always close the figure."""
        import wandb

        try:
            return wandb.Image(fig, caption=caption)
        finally:
            # Close even when the conversion fails so figures do not pile up.
            plt.close(fig)

    @staticmethod
    def _resolve_sample_rate(trainer: L.Trainer, pl_module: L.LightningModule) -> int:
        """Resolve the sample rate used for audio logging.

        Lookup order, first non-``None`` value wins:

        1. ``trainer.datamodule.target_sample_rate``, then the same key or
           attribute on ``trainer.datamodule.hparams``;
        2. ``pl_module.spectrogram.sample_rate``;
        3. ``pl_module.hparams["net"]["spectrogram"]["sample_rate"]``;
        4. ``DEFAULT_SAMPLE_RATE``.

        Any ``Mapping`` is accepted for the config lookups, not just ``dict``.
        """
        datamodule = getattr(trainer, "datamodule", None)
        if datamodule is not None:
            dm_sr = getattr(datamodule, "target_sample_rate", None)
            if dm_sr is None:
                dm_hparams = getattr(datamodule, "hparams", None)
                if isinstance(dm_hparams, Mapping):
                    dm_sr = dm_hparams.get("target_sample_rate")
                else:
                    dm_sr = getattr(dm_hparams, "target_sample_rate", None)
            if dm_sr is not None:
                return int(dm_sr)

        module_sr = getattr(
            getattr(pl_module, "spectrogram", None), "sample_rate", None
        )
        if module_sr is not None:
            return int(module_sr)

        # Walk hparams -> net -> spectrogram -> sample_rate, giving up as soon
        # as a level is not a mapping.
        node = getattr(pl_module, "hparams", None)
        for key in ("net", "spectrogram", "sample_rate"):
            if not isinstance(node, Mapping):
                node = None
                break
            node = node.get(key)
        if node is not None:
            return int(node)

        return VisualizationCallback.DEFAULT_SAMPLE_RATE

    def _plot_spectrogram(
        self, spec: np.ndarray, patch_size: tuple[int, int], grid_size: tuple[int, int]
    ) -> plt.Figure:
        """Plot ``spec`` with the patch grid and no mask overlay."""
        return self._plot_spectrogram_with_mask(spec, None, patch_size, grid_size)

    def _plot_spectrogram_with_mask(
        self,
        spec: np.ndarray,
        mask: Optional[np.ndarray],
        patch_size: tuple[int, int],
        grid_size: tuple[int, int],
        invert_mask: bool = False,
    ) -> plt.Figure:
        """Plot a spectrogram with dashed patch grid lines and a mask overlay.

        Args:
            spec: 2-D array indexed as (frequency bin, time frame).
            mask: Array with ``grid_size[0] * grid_size[1]`` entries where a
                truthy entry marks a patch to darken, or ``None`` to draw
                only the spectrogram and grid.
            patch_size: Patch height and width in pixels.
            grid_size: Number of patch rows and columns.
            invert_mask: If true, darken the patches whose mask entry is
                falsy instead.

        Returns:
            The created figure. The caller is responsible for closing it.
        """
        grid_rows, grid_cols = grid_size
        patch_h, patch_w = patch_size
        height, width = spec.shape

        # Build the per-pixel mask before any figure exists so a malformed
        # mask cannot leave an orphaned figure behind.
        darken = None
        if mask is not None:
            # Cast first: ``~`` is only a logical NOT on boolean arrays.
            patch_mask = np.asarray(mask).astype(bool).reshape(grid_rows, grid_cols)
            if invert_mask:
                patch_mask = ~patch_mask
            # Expand each grid cell to its patch-sized pixel block; clip in
            # case the grid covers more than the image.
            darken = np.repeat(
                np.repeat(patch_mask, patch_h, axis=0), patch_w, axis=1
            )[:height, :width]

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.imshow(spec, origin="lower", aspect="auto", cmap="viridis")

        # Lines sit on pixel edges, hence the half-pixel shift.
        for y in range(0, height + 1, patch_h):
            ax.axhline(y - 0.5, color="white", linestyle="--", linewidth=0.5, alpha=0.5)
        for x in range(0, width + 1, patch_w):
            ax.axvline(x - 0.5, color="white", linestyle="--", linewidth=0.5, alpha=0.5)

        if darken is not None:
            # RGBA overlay: black everywhere, opaque only on darkened patches.
            overlay = np.zeros((height, width, 4))
            overlay[: darken.shape[0], : darken.shape[1], 3] = np.where(
                darken, 0.7, 0.0
            )
            ax.imshow(overlay, origin="lower", aspect="auto")

        ax.set_title("Spectrogram")
        ax.set_xlabel("Time Frames")
        ax.set_ylabel("Frequency Bins")
        fig.tight_layout()
        return fig
|
|