| """evaluation.py — Evaluation metrics for NSGF/NSGF++ experiments. |
| |
| Implements: |
| - 2-Wasserstein distance (2D experiments) |
| - FID (Fréchet Inception Distance) for image experiments |
| - IS (Inception Score) for image experiments |
| - Visualization utilities |
| |
| Reference: arXiv:2401.14069, Section 5, Appendix E |
| """ |
|
|
| import os |
| import logging |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from typing import Dict, Optional, List, Tuple |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def compute_w2_distance(samples: torch.Tensor, targets: torch.Tensor,
                        num_iter_max: int = 1_000_000) -> float:
    """Compute the exact 2-Wasserstein distance between two empirical point sets.

    Both sets get uniform weights. The optimal transport cost under the
    squared-Euclidean ground cost is solved exactly with POT's ``ot.emd2``.
    The result is the square root of that cost.

    Args:
        samples: Generated points, shape ``(n, d)``.
        targets: Reference points, shape ``(m, d)``.
        num_iter_max: Iteration cap for the exact OT solver. POT's own default
            (100000) can be reached for a few thousand points. In that case the
            solver warns and returns a non-optimal, i.e. over-estimated, cost.

    Returns:
        The W2 distance as a Python float.

    Raises:
        ValueError: If either set is empty, is not 2-D, or the two sets have
            different feature dimensions.
    """
    x = samples.detach().cpu().numpy().astype(np.float64)
    y = targets.detach().cpu().numpy().astype(np.float64)
    if x.ndim != 2 or y.ndim != 2 or x.shape[1] != y.shape[1]:
        raise ValueError(
            f"expected two (N, d) point sets with equal d, got {x.shape} and {y.shape}"
        )
    if len(x) == 0 or len(y) == 0:
        raise ValueError("cannot compute W2 distance with an empty point set")

    # Imported lazily so the rest of the module works without POT installed.
    import ot

    cost = ot.dist(x, y, metric="sqeuclidean")
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    w2_sq = ot.emd2(a, b, cost, numItermax=num_iter_max)
    # Clamp tiny negative round-off before the square root.
    return float(np.sqrt(max(float(w2_sq), 0.0)))
|
|
|
|
class InceptionV3Features(nn.Module):
    """Frozen Inception V3 backbone returning pooled features and class logits.

    The layers are taken from torchvision's ImageNet-pretrained ``inception_v3``.
    They are replayed in the order of its ``forward`` up to the final average
    pool; the auxiliary head and dropout are skipped. ``forward`` returns the
    pooled features (used for FID) and the ``fc`` logits (used for IS).

    NOTE(review): torchvision's weights differ from the TF-Inception weights
    used by the reference FID/IS tools. Absolute scores from this module may
    therefore not be comparable with published numbers.

    Args:
        device: Device the module is moved to; also stored as ``self.device``.
    """

    def __init__(self, device: str = "cpu"):
        super().__init__()
        import torchvision.models as models

        self.device = device
        inception = self._load_pretrained_inception(models)
        inception.eval()
        self.blocks = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3, nn.MaxPool2d(3, stride=2),
            inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d,
            inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e,
            inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = inception.fc
        self.to(device)
        # Inference-only wrapper. eval() also covers the wrapper itself, not
        # just the submodules copied out of the (already eval-mode) inception.
        self.eval()
        for p in self.parameters():
            p.requires_grad_(False)

    @staticmethod
    def _load_pretrained_inception(models):
        """Build ImageNet-pretrained Inception V3 on old and new torchvision."""
        weights_enum = getattr(models, "Inception_V3_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy ``pretrained`` flag.
            return models.inception_v3(pretrained=True, transform_input=False)
        return models.inception_v3(weights=weights_enum.IMAGENET1K_V1,
                                   transform_input=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features, logits)`` for a batch of images.

        Args:
            x: Image batch ``(B, C, H, W)`` with ``C`` in {1, 3}. Inputs are
                resized to 299x299, and 1-channel inputs are repeated to
                3 channels. Pixel values are assumed to be in [-1, 1]
                (NOTE(review): confirm against the sampler's output range).

        Returns:
            ``features`` of shape ``(B, F)`` (globally average-pooled
            backbone output) and ``logits = fc(features)``.
        """
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = torch.nn.functional.interpolate(
                x, size=(299, 299), mode="bilinear", align_corners=False)
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        # Deliberately no rescaling to [0, 1]. torchvision's
        # ``transform_input=True`` maps ImageNet-normalised input to
        # ``2 * x01 - 1``, so the pretrained stem expects pixels in [-1, 1].
        # That is the range the inputs are assumed to already be in.
        h = self.blocks(x)
        features = self.avgpool(h).flatten(1)
        logits = self.fc(features)
        return features, logits
|
|
|
|
def compute_fid(generated: torch.Tensor, real: torch.Tensor,
                device: str = "cpu", batch_size: int = 64) -> float:
    """Compute the Fréchet Inception Distance between two image sets.

    Inception features of each set are summarised by their mean and sample
    covariance. The Fréchet distance between the two resulting Gaussians is
    returned (lower is better).

    Args:
        generated: Generated images ``(N, C, H, W)``, in the layout accepted
            by ``InceptionV3Features``.
        real: Reference images in the same layout.
        device: Device used for feature extraction.
        batch_size: Images per forward pass.

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: If either set has fewer than two images, because the
            sample covariance is then undefined.
    """
    if len(generated) < 2 or len(real) < 2:
        raise ValueError("FID needs at least two images in each set")
    model = InceptionV3Features(device)

    def get_features(images: torch.Tensor) -> np.ndarray:
        # Batched forward passes; features are collected on the CPU.
        feats = [model(images[i:i + batch_size].to(device))[0].cpu().numpy()
                 for i in range(0, len(images), batch_size)]
        return np.concatenate(feats, axis=0)

    logger.info("Computing FID: extracting generated features...")
    feats_gen = get_features(generated)
    logger.info("Computing FID: extracting real features...")
    feats_real = get_features(real)
    return _frechet_distance(feats_gen.mean(0), np.cov(feats_gen, rowvar=False),
                             feats_real.mean(0), np.cov(feats_real, rowvar=False))


def _frechet_distance(mu1: np.ndarray, sigma1: np.ndarray,
                      mu2: np.ndarray, sigma2: np.ndarray,
                      eps: float = 1e-6) -> float:
    """Return the Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr((sigma1 sigma2)^{1/2})

    If the matrix square root is not finite (near-singular covariances, e.g.
    fewer samples than feature dimensions), ``eps`` is added to both diagonals
    and the square root is recomputed.
    """
    from scipy import linalg

    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2

    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if not np.isfinite(covmean).all():
        logger.warning("FID: singular covariance product; adding %g to the diagonals", eps)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset), disp=False)
    # Round-off can leave a negligible imaginary component.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
|
|
|
|
def compute_inception_score(images: torch.Tensor, device: str = "cpu",
                            batch_size: int = 64, splits: int = 10) -> Tuple[float, float]:
    """Compute the Inception Score of an image set.

    Args:
        images: Images ``(N, C, H, W)`` in the layout accepted by
            ``InceptionV3Features``.
        device: Device used for the Inception forward passes.
        batch_size: Images per forward pass.
        splits: Number of equal-size groups the score is averaged over.

    Returns:
        ``(mean, std)`` of the per-split scores.

    Raises:
        ValueError: If ``splits < 1`` or there are fewer images than splits.
    """
    if splits < 1:
        raise ValueError(f"splits must be >= 1, got {splits}")
    if len(images) < splits:
        raise ValueError(f"need at least {splits} images for {splits} splits, got {len(images)}")
    model = InceptionV3Features(device)
    logits = torch.cat(
        [model(images[i:i + batch_size].to(device))[1].cpu()
         for i in range(0, len(images), batch_size)],
        dim=0,
    )
    probs = torch.softmax(logits, dim=1).numpy()
    return _inception_score_from_probs(probs, splits)


def _inception_score_from_probs(probs: np.ndarray, splits: int) -> Tuple[float, float]:
    """Return ``(mean, std)`` of exp(E[KL(p(y|x) || p(y))]) over ``splits`` groups.

    Each group holds ``len(probs) // splits`` rows. Leftover rows are dropped,
    following the usual IS convention.
    """
    if splits < 1:
        raise ValueError(f"splits must be >= 1, got {splits}")
    n = len(probs)
    if n < splits:
        # Otherwise every split would be empty and the score NaN.
        raise ValueError(f"need at least {splits} samples for {splits} splits, got {n}")
    split_size = n // splits
    scores = []
    for k in range(splits):
        part = probs[k * split_size:(k + 1) * split_size]
        py = part.mean(axis=0, keepdims=True)
        kl = part * (np.log(part + 1e-10) - np.log(py + 1e-10))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))
|
|
|
|
class Evaluation:
    """Choose and run the metrics appropriate for the configured dataset.

    Datasets ``"mnist"`` and ``"cifar10"`` are scored with FID and/or the
    Inception Score. The metrics are selected by
    ``config["evaluation"]["metrics"]`` (default ``["fid"]``). Every other
    dataset is scored with the 2-Wasserstein distance.

    Optional ``config["evaluation"]`` keys: ``batch_size`` (default 64) is
    forwarded to both Inception metrics; ``is_splits`` (default 10) is the
    number of Inception Score splits.
    """

    def __init__(self, config: dict, device: str = "cpu"):
        self.config = config
        self.device = device
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in ("mnist", "cifar10")

    def _eval_settings(self) -> dict:
        """Return the ``evaluation`` sub-config, treating missing or null as empty."""
        # A bare ``evaluation:`` key in YAML loads as None, which has no ``.get``.
        return self.config.get("evaluation") or {}

    def evaluate(self, generated: torch.Tensor, real: torch.Tensor) -> Dict[str, float]:
        """Score ``generated`` against ``real``.

        Returns:
            A dict with ``"w2"`` for non-image datasets. For image datasets it
            holds whichever of ``"fid"``, ``"is_mean"`` and ``"is_std"`` the
            configured metrics request.
        """
        metrics: Dict[str, float] = {}
        if not self.is_image:
            w2 = compute_w2_distance(generated, real)
            metrics["w2"] = w2
            logger.info(f"W2 distance: {w2:.4f}")
            return metrics

        eval_cfg = self._eval_settings()
        metric_names = eval_cfg.get("metrics", ["fid"])
        if metric_names is None:
            metric_names = ["fid"]
        elif isinstance(metric_names, str):
            # A bare string would otherwise be searched by substring.
            metric_names = [metric_names]
        batch_size = int(eval_cfg.get("batch_size", 64))
        is_splits = int(eval_cfg.get("is_splits", 10))

        if "fid" in metric_names:
            logger.info("Computing FID...")
            metrics["fid"] = compute_fid(generated, real, self.device, batch_size=batch_size)
            logger.info(f"FID: {metrics['fid']:.2f}")
        if "is" in metric_names:
            logger.info("Computing Inception Score...")
            is_mean, is_std = compute_inception_score(
                generated, self.device, batch_size=batch_size, splits=is_splits)
            metrics["is_mean"] = is_mean
            metrics["is_std"] = is_std
            logger.info(f"IS: {is_mean:.2f} ± {is_std:.2f}")
        return metrics
|
|
|
|
def plot_2d_samples(samples: torch.Tensor, targets: torch.Tensor,
                    title: str = "Generated vs Target", save_path: Optional[str] = None,
                    lim: float = 6.0):
    """Plot target, generated and overlaid 2D point clouds side by side.

    The object-oriented ``matplotlib.figure.Figure`` API is used instead of
    pyplot. Calling this never switches the global backend, and so never
    closes figures the caller has open.

    Args:
        samples: Generated points; columns 0 and 1 are plotted.
        targets: Target points; columns 0 and 1 are plotted.
        title: Figure super-title.
        save_path: If given, the PNG is written here (parent dirs created).
        lim: Half-width of the square ``[-lim, lim]`` view on every panel.

    Returns:
        The matplotlib ``Figure``.
    """
    from matplotlib.figure import Figure

    s = samples.detach().cpu().numpy()
    t = targets.detach().cpu().numpy()
    fig = Figure(figsize=(15, 5))
    ax_target, ax_gen, ax_overlay = fig.subplots(1, 3)

    ax_target.scatter(t[:, 0], t[:, 1], s=3, alpha=0.5, c="blue")
    ax_target.set_title("Target Distribution")
    ax_gen.scatter(s[:, 0], s[:, 1], s=3, alpha=0.5, c="red")
    ax_gen.set_title("Generated Samples")
    ax_overlay.scatter(t[:, 0], t[:, 1], s=3, alpha=0.3, c="blue", label="Target")
    ax_overlay.scatter(s[:, 0], s[:, 1], s=3, alpha=0.3, c="red", label="Generated")
    ax_overlay.set_title("Overlay")
    ax_overlay.legend()
    for ax in (ax_target, ax_gen, ax_overlay):
        ax.set_xlim(-lim, lim)
        ax.set_ylim(-lim, lim)
        ax.set_aspect("equal")

    fig.suptitle(title)
    fig.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved plot to {save_path}")
    return fig
|
|
|
|
def plot_2d_trajectory(trajectory: List[torch.Tensor], targets: torch.Tensor,
                       title: str = "Flow Trajectory", save_path: Optional[str] = None,
                       max_particles: int = 200):
    """Plot particle paths through a flow over the target distribution.

    Each of the first ``max_particles`` particles is drawn as a polyline
    through its positions at every step. Segments are coloured from the
    start (cool) to the end (warm) of the coolwarm colormap, and the final
    positions are scattered in red. The pyplot-free ``Figure`` API is used,
    so the global backend and the caller's open figures are left untouched.

    Args:
        trajectory: Per-step particle positions. Each element is indexed as
            ``step[i]`` for particle ``i``; columns 0 and 1 are plotted.
        targets: Target points drawn in the background.
        title: Axes title.
        save_path: If given, the PNG is written here (parent dirs created).
        max_particles: Maximum number of particle paths to draw.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: If ``trajectory`` is empty.
    """
    from matplotlib import cm
    from matplotlib.collections import LineCollection
    from matplotlib.figure import Figure

    if len(trajectory) == 0:
        raise ValueError("trajectory must contain at least one step")

    n = min(trajectory[0].shape[0], max_particles)
    # One device->host copy per step, rather than one per particle per step.
    path = np.stack([step[:n].detach().cpu().numpy() for step in trajectory])  # (T, n, d)
    n_seg = path.shape[0] - 1
    # Segment k of particle i joins path[k, i] -> path[k + 1, i]. Order the
    # segments particle by particle, so one collection draws exactly like the
    # former one-collection-per-particle loop.
    segments = np.stack([path[:-1], path[1:]], axis=2)  # (T-1, n, 2, d)
    segments = segments.transpose(1, 0, 2, 3).reshape(-1, 2, path.shape[-1])
    colors = np.tile(cm.coolwarm(np.linspace(0, 1, n_seg)), (n, 1))

    fig = Figure(figsize=(8, 8))
    ax = fig.subplots()
    t = targets.detach().cpu().numpy()
    ax.scatter(t[:, 0], t[:, 1], s=3, alpha=0.2, c="blue", label="Target")
    ax.add_collection(LineCollection(segments, colors=colors, linewidths=0.5, alpha=0.5))
    final = path[-1]
    ax.scatter(final[:, 0], final[:, 1], s=5, c="red", alpha=0.5, label="Generated")
    ax.set_xlim(-6, 6)
    ax.set_ylim(-6, 6)
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.legend()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved trajectory plot to {save_path}")
    return fig
|
|
|
|
def plot_image_grid(images: torch.Tensor, nrow: int = 8,
                    title: str = "Generated Images", save_path: Optional[str] = None):
    """Show up to ``nrow * nrow`` images as a square grid.

    Pixel values are mapped from [-1, 1] to [0, 1] by ``make_grid``. Images are
    detached and moved to the CPU first, so tensors that require grad or live
    on an accelerator are accepted. The pyplot-free ``Figure`` API leaves the
    global backend and the caller's open figures untouched.

    Args:
        images: Image batch ``(N, C, H, W)``; only the first ``nrow * nrow`` are used.
        nrow: Images per grid row (and number of rows).
        title: Axes title.
        save_path: If given, the PNG is written here (parent dirs created).

    Returns:
        The matplotlib ``Figure``.
    """
    from matplotlib.figure import Figure
    import torchvision.utils as vutils

    batch = images[:nrow * nrow].detach().cpu()
    grid = vutils.make_grid(batch, nrow=nrow, normalize=True, value_range=(-1, 1))
    grid_np = grid.permute(1, 2, 0).numpy()

    fig = Figure(figsize=(10, 10))
    ax = fig.subplots()
    ax.imshow(grid_np)
    ax.set_title(title)
    ax.axis("off")
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved image grid to {save_path}")
    return fig
|
|