# Visualization code from https://github.com/Tsingularity/dift/blob/main/src/utils/visualization.py
import io
from pathlib import Path

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image

FONT_SIZE = 40


@torch.no_grad()
def plot_feats(
    image,
    target,
    pred,
    legend=("Image", "HR Features", "Pred Features"),
    save_path=None,
    return_array=False,
    show_legend=True,
    font_size=FONT_SIZE,
):
    """
    Plot an optional image next to PCA-reduced target and predicted features.

    The target and every prediction are projected with a single PCA fitted
    jointly (see ``pca``), so colours are comparable across panels.

    Args:
        image: Optional tensor shown in the first panel. A 3-D tensor is
            permuted with ``(1, 2, 0)`` before display (i.e. treated as
            channel-first); a 2-D tensor is shown with the "inferno" colormap.
            Pass ``None`` to omit the panel.
        target: Feature tensor without a batch dimension (a batch dimension is
            added before it is handed to ``pca``).
        pred: A feature tensor, or a list/tuple of feature tensors, each
            without a batch dimension. One panel is drawn per entry.
        legend: Panel titles in display order. When ``image`` is ``None`` the
            first entry is used for the target panel. Prediction panels that
            have no matching entry are titled ``"HR Features {idx}"``.
        save_path: If given (and ``return_array`` is false), the figure is
            saved there.
        return_array: If true, render the figure to PNG in memory, close it and
            return the decoded pixels instead of saving/showing.
        show_legend: Whether to draw panel titles.
        font_size: Font size of the panel titles.

    Returns:
        A numpy array holding the rendered figure when ``return_array`` is
        true, otherwise ``None``.
    """
    # Accept a single tensor as well as any list/tuple of tensors.
    if not isinstance(pred, (list, tuple)):
        pred = [pred]

    # pca expects batched inputs and returns one reduced tensor per input.
    feats_for_pca = [target.unsqueeze(0)] + [p.unsqueeze(0) for p in pred]
    reduced_feats, _ = pca(feats_for_pca)
    target_imgs = reduced_feats[0]
    pred_imgs = reduced_feats[1:]

    has_image = image is not None
    n_cols = int(has_image) + 1 + len(pred_imgs)

    # squeeze=False always yields a 2-D array of axes, so the single-panel case
    # needs no special handling here or in remove_axes.
    fig, axes_grid = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5), squeeze=False)
    ax = axes_grid[0]
    fig.subplots_adjust(wspace=0.05, hspace=0.05)

    def _to_hwc(chw):
        # Channel-first tensor -> channel-last CPU tensor for imshow.
        return chw.permute(1, 2, 0).detach().cpu()

    ax_idx = 0

    # Optional original image.
    if has_image:
        if image.dim() == 3:
            ax[ax_idx].imshow(_to_hwc(image))
        elif image.dim() == 2:
            ax[ax_idx].imshow(image.detach().cpu(), cmap="inferno")
        if show_legend:
            ax[ax_idx].set_title(legend[0], fontsize=font_size)
        ax_idx += 1

    # Target features.
    ax[ax_idx].imshow(_to_hwc(target_imgs[0]))
    if show_legend:
        ax[ax_idx].set_title(legend[1 if has_image else 0], fontsize=font_size)
    ax_idx += 1

    # Predicted features, one panel each.
    first_pred_legend = 2 if has_image else 1
    for idx, pred_img in enumerate(pred_imgs):
        ax[ax_idx].imshow(_to_hwc(pred_img[0]))
        if show_legend:
            legend_idx = first_pred_legend + idx
            if len(legend) > legend_idx:
                title = legend[legend_idx]
            else:
                title = f"HR Features {idx}"
            ax[ax_idx].set_title(title, fontsize=font_size)
        ax_idx += 1

    remove_axes(ax)

    if return_array:
        # Turn off interactive mode temporarily; restore it (and release the
        # figure) even if rendering or decoding fails.
        was_interactive = plt.isinteractive()
        plt.ioff()
        try:
            with io.BytesIO() as buf:
                fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
                buf.seek(0)
                # np.array forces PIL to decode while the buffer is still open.
                img_array = np.array(Image.open(buf))
        finally:
            plt.close(fig)
            if was_interactive:
                plt.ion()
        return img_array

    # Standard behavior: save and/or show.
    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.show()
    return None


def remove_axes(axes):
    """
    Hide tick marks and tick labels on every given matplotlib axis.

    Args:
        axes: A numpy array of Axes (any number of dimensions), a list/tuple
            of Axes, or a single Axes object.
    """
    # Flatten through an object array so arrays, sequences and a lone Axes are
    # all handled by the same loop.
    for ax in np.asarray(axes, dtype=object).ravel():
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        ax.yaxis.set_major_formatter(plt.NullFormatter())
        ax.set_xticks([])
        ax.set_yticks([])


def pca(image_feats_list, dim=3, fit_pca=None, max_samples=None):
    """
    Project feature maps onto ``dim`` principal components, scaled to [0, 1].

    Args:
        image_feats_list: List of 4-D tensors ``(1, C, H, W)``; the spatial
            size may differ between entries.
        dim: Number of components to fit when ``fit_pca`` is not supplied.
        fit_pca: Optional already-fitted object exposing ``transform``. When
            given, no fitting is performed and it is returned unchanged.
        max_samples: Optional cap on the number of feature vectors used for
            fitting; a random subset is drawn when the cap is exceeded.

    Returns:
        ``(reduced_feats, fit_pca)`` where ``reduced_feats`` holds one
        ``(1, n_components, H, W)`` CPU tensor per input, each component
        min-max normalised per input, and ``fit_pca`` is the PCA object used.
    """

    def flatten(tensor, target_size=None):
        B, C, H, W = tensor.shape
        assert B == 1, "Batch size should be 1 for PCA flattening"
        # Promote half-precision features: the CPU decomposition below is not
        # reliably available for float16/bfloat16.
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.float()
        if target_size is not None:
            tensor = F.interpolate(
                tensor,
                (target_size, target_size),
                mode="bilinear",
                align_corners=False,
            )
        return rearrange(tensor, "b c h w -> (b h w) c").detach().cpu()

    if fit_pca is None:
        # With several inputs, resize each to a square whose side is the height
        # of the first one so every input contributes the same number of
        # samples to the fit.
        target_size = None
        if len(image_feats_list) > 1:
            target_size = image_feats_list[0].shape[2]

        x = torch.cat([flatten(feats, target_size) for feats in image_feats_list], dim=0)

        # Subsample the data if max_samples is set and exceeded.
        if max_samples is not None and x.shape[0] > max_samples:
            indices = torch.randperm(x.shape[0])[:max_samples]
            x = x[indices]

        fit_pca = TorchPCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        B, C, H, W = feats.shape
        x_red = fit_pca.transform(flatten(feats))
        # Some PCA implementations return numpy arrays from transform.
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)

        # Per-component min-max normalisation. A constant component has zero
        # span; divide by 1 there so it maps to 0 instead of NaN.
        x_red = x_red - x_red.min(dim=0, keepdim=True).values
        span = x_red.max(dim=0, keepdim=True).values
        x_red = x_red / torch.where(span > 0, span, torch.ones_like(span))

        # -1 keeps this valid when a supplied fit_pca has a component count
        # other than `dim`.
        reduced_feats.append(x_red.reshape(B, H, W, -1).permute(0, 3, 1, 2))

    return reduced_feats, fit_pca


class TorchPCA:
    """
    Minimal PCA built on ``torch.pca_lowrank``.

    Args:
        n_components: Number of principal directions to estimate.
        skip: Number of leading directions to drop from the projection.
    """

    def __init__(self, n_components, skip=0):
        self.n_components = n_components
        self.skip = skip

    def fit(self, X):
        """Estimate the mean and principal directions of ``X`` (samples x features)."""
        self.mean_ = X.mean(dim=0)
        unbiased = X - self.mean_
        # Data is centred manually above, hence center=False.
        _, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=20)
        self.components_ = V[:, self.skip :]
        # Singular values are stored for all estimated directions, including
        # any skipped ones.
        self.singular_values_ = S
        return self

    def transform(self, X):
        """Centre ``X`` with the fitted mean and project it onto the kept directions."""
        centered = X - self.mean_.unsqueeze(0)
        return centered @ self.components_