Upload evaluation.py
Browse files- evaluation.py +218 -0
evaluation.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""evaluation.py — Evaluation metrics for NSGF/NSGF++ experiments.
|
| 2 |
+
|
| 3 |
+
Implements:
|
| 4 |
+
- 2-Wasserstein distance (2D experiments)
|
| 5 |
+
- FID (Fréchet Inception Distance) for image experiments
|
| 6 |
+
- IS (Inception Score) for image experiments
|
| 7 |
+
- Visualization utilities
|
| 8 |
+
|
| 9 |
+
Reference: arXiv:2401.14069, Section 5, Appendix E
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import logging
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from typing import Dict, Optional, List, Tuple
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def compute_w2_distance(samples: torch.Tensor, targets: torch.Tensor) -> float:
    """Compute the empirical 2-Wasserstein distance via the POT library.

    Both inputs are treated as uniform empirical distributions; the exact
    optimal-transport cost is solved with the network-simplex EMD solver.
    """
    import ot
    src = samples.detach().cpu().numpy()
    dst = targets.detach().cpu().numpy()
    # Squared-Euclidean ground cost, so emd2 returns W2 squared.
    cost = ot.dist(src, dst, metric="sqeuclidean")
    # Uniform weights over each point cloud.
    weights_src = np.full(len(src), 1.0 / len(src))
    weights_dst = np.full(len(dst), 1.0 / len(dst))
    w2_squared = ot.emd2(weights_src, weights_dst, cost)
    # Clamp tiny negative values from solver round-off before the sqrt.
    return float(np.sqrt(max(w2_squared, 0)))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class InceptionV3Features(nn.Module):
    """Inception V3 wrapper exposing pool-3 features and class logits.

    Used by FID (features, [N, 2048]) and IS (logits, [N, 1000]).
    Inputs are expected in [-1, 1]; forward() resizes to 299x299, expands
    grayscale to RGB, and rescales to [0, 1] before the network.

    NOTE(review): torchvision's ImageNet weights were trained with
    ImageNet mean/std normalization; this wrapper only rescales to [0, 1],
    so absolute FID/IS values may not be comparable to reference
    implementations — confirm this matches the intended protocol.
    """

    def __init__(self, device: str = "cpu"):
        super().__init__()
        import torchvision.models as models
        self.device = device
        # `pretrained=` is deprecated since torchvision 0.13; prefer the
        # `weights=` enum API (same ImageNet-1k weights) and fall back
        # for older torchvision versions that lack the enum.
        try:
            inception = models.inception_v3(
                weights=models.Inception_V3_Weights.IMAGENET1K_V1,
                transform_input=False,
            )
        except AttributeError:
            inception = models.inception_v3(pretrained=True, transform_input=False)
        inception.eval()
        # Backbone up to (excluding) the final average pool, in the
        # order of the original Inception V3 forward pass.
        self.blocks = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3, nn.MaxPool2d(3, stride=2),
            inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d,
            inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e,
            inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = inception.fc
        self.to(device)
        # Frozen feature extractor — never trained here.
        for p in self.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (pool-3 features [N, 2048], class logits [N, 1000])."""
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = torch.nn.functional.interpolate(
                x, size=(299, 299), mode="bilinear", align_corners=False
            )
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)  # grayscale -> 3-channel RGB
        x = (x + 1) / 2  # rescale [-1, 1] -> [0, 1]
        h = self.blocks(x)
        features = self.avgpool(h).squeeze(-1).squeeze(-1)
        logits = self.fc(features)
        return features, logits
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def compute_fid(generated: torch.Tensor, real: torch.Tensor,
                device: str = "cpu", batch_size: int = 64) -> float:
    """Fréchet Inception Distance between generated and real image sets.

    Args:
        generated: [N, C, H, W] images in [-1, 1].
        real: [M, C, H, W] images in [-1, 1].
        device: device for Inception feature extraction.
        batch_size: minibatch size for extraction (bounds memory use).

    Returns:
        FID = ||mu_g - mu_r||^2 + Tr(S_g + S_r - 2 (S_g S_r)^{1/2}).
    """
    from scipy import linalg
    model = InceptionV3Features(device)

    def get_features(images: torch.Tensor) -> np.ndarray:
        # Minibatched extraction so large sets don't exhaust memory.
        feats = []
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size].to(device)
            f, _ = model(batch)
            feats.append(f.cpu().numpy())
        return np.concatenate(feats, axis=0)

    logger.info("Computing FID: extracting generated features...")
    feats_gen = get_features(generated)
    logger.info("Computing FID: extracting real features...")
    feats_real = get_features(real)
    mu_gen, sigma_gen = feats_gen.mean(0), np.cov(feats_gen, rowvar=False)
    mu_real, sigma_real = feats_real.mean(0), np.cov(feats_real, rowvar=False)
    diff = mu_gen - mu_real
    covmean, _ = linalg.sqrtm(sigma_gen @ sigma_real, disp=False)
    if not np.isfinite(covmean).all():
        # Near-singular covariances (common with few samples) can make
        # sqrtm produce inf/nan; regularize with a small diagonal offset,
        # as done in the reference FID implementation.
        offset = np.eye(sigma_gen.shape[0]) * 1e-6
        covmean, _ = linalg.sqrtm(
            (sigma_gen + offset) @ (sigma_real + offset), disp=False
        )
    if np.iscomplexobj(covmean):
        # Tiny imaginary parts are numerical noise from sqrtm.
        covmean = covmean.real
    fid = diff @ diff + np.trace(sigma_gen + sigma_real - 2 * covmean)
    return float(fid)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def compute_inception_score(images: torch.Tensor, device: str = "cpu",
                            batch_size: int = 64, splits: int = 10) -> Tuple[float, float]:
    """Inception Score of `images` ([N, C, H, W] in [-1, 1]).

    Returns:
        (mean, std) of exp(E_x[KL(p(y|x) || p(y))]) computed over
        `splits` equal-size splits; trailing samples that do not fill a
        split are dropped, matching the standard protocol.
    """
    model = InceptionV3Features(device)
    all_logits = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size].to(device)
        _, logits = model(batch)
        all_logits.append(logits.cpu())
    all_logits = torch.cat(all_logits, dim=0)
    probs = torch.softmax(all_logits, dim=1).numpy()
    n = len(probs)
    # Guard: with n < splits the integer split size would be 0, making
    # every split empty and the score NaN. Use as many non-empty splits
    # as the sample count allows.
    splits = max(1, min(splits, n))
    split_size = n // splits
    scores = []
    for i in range(splits):
        part = probs[i * split_size:(i + 1) * split_size]
        py = part.mean(axis=0, keepdims=True)  # marginal p(y) for this split
        # KL(p(y|x) || p(y)) per sample; epsilon avoids log(0).
        kl = part * (np.log(part + 1e-10) - np.log(py + 1e-10))
        kl = kl.sum(axis=1).mean()
        scores.append(np.exp(kl))
    return float(np.mean(scores)), float(np.std(scores))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class Evaluation:
    """Dispatch metric computation based on the configured dataset.

    Image datasets (mnist/cifar10) get FID and/or IS per the config's
    `evaluation.metrics` list; everything else is treated as a 2D
    synthetic dataset and scored with the 2-Wasserstein distance.
    """

    def __init__(self, config: dict, device: str = "cpu"):
        self.config = config
        self.device = device
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in ("mnist", "cifar10")

    def evaluate(self, generated: torch.Tensor, real: torch.Tensor) -> Dict[str, float]:
        """Compute the configured metrics for generated vs. real samples."""
        # 2D case: single metric, early return.
        if not self.is_image:
            w2 = compute_w2_distance(generated, real)
            logger.info(f"W2 distance: {w2:.4f}")
            return {"w2": w2}
        metrics: Dict[str, float] = {}
        eval_cfg = self.config.get("evaluation", {})
        metric_names = eval_cfg.get("metrics", ["fid"])
        if "fid" in metric_names:
            logger.info("Computing FID...")
            metrics["fid"] = compute_fid(generated, real, self.device)
            logger.info(f"FID: {metrics['fid']:.2f}")
        if "is" in metric_names:
            logger.info("Computing Inception Score...")
            is_mean, is_std = compute_inception_score(generated, self.device)
            metrics["is_mean"] = is_mean
            metrics["is_std"] = is_std
            logger.info(f"IS: {is_mean:.2f} ± {is_std:.2f}")
        return metrics
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def plot_2d_samples(samples: torch.Tensor, targets: torch.Tensor,
                    title: str = "Generated vs Target", save_path: Optional[str] = None):
    """Three-panel scatter: target alone, generated alone, and overlay.

    If `save_path` is given, the figure is written there (directories
    are created as needed); the figure is always closed at the end.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend: render to file only
    import matplotlib.pyplot as plt
    gen_np = samples.detach().cpu().numpy()
    tgt_np = targets.detach().cpu().numpy()
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    # (axis, [(points, color, alpha, legend label)], panel title)
    panels = [
        (axes[0], [(tgt_np, "blue", 0.5, None)], "Target Distribution"),
        (axes[1], [(gen_np, "red", 0.5, None)], "Generated Samples"),
        (axes[2], [(tgt_np, "blue", 0.3, "Target"),
                   (gen_np, "red", 0.3, "Generated")], "Overlay"),
    ]
    for ax, layers, panel_title in panels:
        for pts, color, alpha, label in layers:
            ax.scatter(pts[:, 0], pts[:, 1], s=3, alpha=alpha, c=color, label=label)
        ax.set_title(panel_title)
        ax.set_xlim(-6, 6)
        ax.set_ylim(-6, 6)
        ax.set_aspect("equal")
    axes[2].legend()
    plt.suptitle(title)
    plt.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved plot to {save_path}")
    plt.close()
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def plot_2d_trajectory(trajectory: List[torch.Tensor], targets: torch.Tensor,
                       title: str = "Flow Trajectory", save_path: Optional[str] = None,
                       max_particles: int = 200):
    """Plot particle paths through the flow over the target distribution.

    Args:
        trajectory: list of [N, 2] tensors, one per flow step.
        targets: [M, 2] tensor of target samples (background scatter).
        title: figure title.
        save_path: if given, the figure is written there (dirs created).
        max_particles: cap on the number of particle paths drawn.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    t = targets.detach().cpu().numpy()
    ax.scatter(t[:, 0], t[:, 1], s=3, alpha=0.2, c="blue", label="Target")
    T = len(trajectory)
    n = min(trajectory[0].shape[0], max_particles)
    # Move each step to CPU once ([T, n, 2]) instead of one detach/cpu
    # round-trip per particle per step as before (T*n transfers).
    traj = np.stack([step[:n].detach().cpu().numpy() for step in trajectory])
    if T >= 2:  # a single-step trajectory has no segments to draw
        # Color gradient over time is identical for every particle.
        colors = plt.cm.coolwarm(np.linspace(0, 1, T - 1))
        for i in range(n):
            points = traj[:, i]
            segments = np.stack([points[:-1], points[1:]], axis=1)
            lc = LineCollection(segments, colors=colors, linewidths=0.5, alpha=0.5)
            ax.add_collection(lc)
    final = traj[-1]
    ax.scatter(final[:, 0], final[:, 1], s=5, c="red", alpha=0.5, label="Generated")
    ax.set_xlim(-6, 6)
    ax.set_ylim(-6, 6)
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.legend()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved trajectory plot to {save_path}")
    plt.close()
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def plot_image_grid(images: torch.Tensor, nrow: int = 8,
                    title: str = "Generated Images", save_path: Optional[str] = None):
    """Render up to nrow*nrow images (in [-1, 1]) as a single grid figure."""
    import matplotlib
    matplotlib.use("Agg")  # headless backend
    import matplotlib.pyplot as plt
    import torchvision.utils as vutils
    max_images = nrow * nrow
    grid = vutils.make_grid(
        images[:max_images], nrow=nrow, normalize=True, value_range=(-1, 1)
    )
    grid_np = grid.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(grid_np)
    ax.set_title(title)
    ax.axis("off")
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved image grid to {save_path}")
    plt.close()
|