rogermt committed on
Commit
da55996
·
verified ·
1 Parent(s): bc9ef03

Upload evaluation.py

Browse files
Files changed (1) hide show
  1. evaluation.py +218 -0
evaluation.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """evaluation.py — Evaluation metrics for NSGF/NSGF++ experiments.
2
+
3
+ Implements:
4
+ - 2-Wasserstein distance (2D experiments)
5
+ - FID (Fréchet Inception Distance) for image experiments
6
+ - IS (Inception Score) for image experiments
7
+ - Visualization utilities
8
+
9
+ Reference: arXiv:2401.14069, Section 5, Appendix E
10
+ """
11
+
12
+ import os
13
+ import logging
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from typing import Dict, Optional, List, Tuple
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def compute_w2_distance(samples: torch.Tensor, targets: torch.Tensor) -> float:
    """Exact 2-Wasserstein distance between two empirical point clouds.

    Solves the discrete optimal-transport problem with the POT library,
    using uniform weights and a squared-Euclidean ground cost, and returns
    the square root of the optimal transport cost.

    Args:
        samples: Generated points, shape (n, d).
        targets: Reference points, shape (m, d).

    Returns:
        The W2 distance as a non-negative float.
    """
    import ot

    xs = samples.detach().cpu().numpy()
    ys = targets.detach().cpu().numpy()
    cost = ot.dist(xs, ys, metric="sqeuclidean")
    weights_x = np.full(len(xs), 1.0 / len(xs))
    weights_y = np.full(len(ys), 1.0 / len(ys))
    transport_cost = ot.emd2(weights_x, weights_y, cost)
    # emd2 can return a tiny negative value from numerical error; clamp
    # before the square root.
    return float(np.sqrt(max(transport_cost, 0)))
32
+
33
+
34
class InceptionV3Features(nn.Module):
    """Inception V3 wrapper for FID/IS computation.

    Rebuilds the torchvision Inception-v3 trunk as a plain ``nn.Sequential``
    (the auxiliary classifier is omitted) so a single forward pass yields
    both the pooled features used for FID and the class logits used for IS.
    All parameters are frozen; the module is inference-only.
    """
    def __init__(self, device: str = "cpu"):
        """Load the pretrained Inception-v3 and move it to *device*.

        Args:
            device: Torch device string (e.g. "cpu", "cuda").
        """
        super().__init__()
        import torchvision.models as models
        self.device = device
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favor of `weights=...`; kept as-is to preserve the
        # behavior on the torchvision version this project pins.
        inception = models.inception_v3(pretrained=True, transform_input=False)
        inception.eval()
        # Feature trunk: stem convolutions + max-pools, followed by the
        # Mixed_5*/6*/7* inception blocks, in the same order torchvision's
        # Inception3.forward applies them (classifier head excluded).
        self.blocks = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3, nn.MaxPool2d(3, stride=2),
            inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d,
            inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e,
            inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Reuse the pretrained classification head for the IS logits.
        self.fc = inception.fc
        self.to(device)
        # Freeze everything: this module is only used for metric extraction.
        for p in self.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (pooled features, class logits) for a batch of images.

        Args:
            x: Image batch (N, C, H, W). Grayscale inputs are tiled to 3
               channels and non-299x299 inputs are bilinearly resized.
               Pixel values are assumed to be in [-1, 1]; they are mapped to
               [0, 1] below — TODO(review): confirm this matches the
               preprocessing the pretrained weights expect (reference FID
               implementations typically feed the network inputs in [-1, 1]).

        Returns:
            Tuple of (features, logits); features feed FID, logits feed IS.
        """
        if x.shape[2] != 299 or x.shape[3] != 299:
            # Inception v3 expects 299x299 spatial resolution.
            x = torch.nn.functional.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
        if x.shape[1] == 1:
            # Tile single-channel (e.g. MNIST) images to RGB.
            x = x.repeat(1, 3, 1, 1)
        x = (x + 1) / 2  # map [-1, 1] -> [0, 1]
        h = self.blocks(x)
        features = self.avgpool(h).squeeze(-1).squeeze(-1)
        logits = self.fc(features)
        return features, logits
69
+
70
+
71
def compute_fid(generated: torch.Tensor, real: torch.Tensor,
                device: str = "cpu", batch_size: int = 64) -> float:
    """Fréchet Inception Distance between generated and real images.

    Extracts pooled Inception features for both sets, fits a Gaussian to
    each, and evaluates
    ``||mu_g - mu_r||^2 + Tr(S_g + S_r - 2 (S_g S_r)^{1/2})``.

    Args:
        generated: Generated image batch (N, C, H, W), in the value range
            expected by ``InceptionV3Features``.
        real: Real image batch, same layout.
        device: Torch device used for feature extraction.
        batch_size: Mini-batch size for the forward passes.

    Returns:
        The FID as a float (lower is better).
    """
    from scipy import linalg

    model = InceptionV3Features(device)

    def get_features(images: torch.Tensor) -> np.ndarray:
        # Stream images through the network in batches to bound memory use.
        feats = []
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size].to(device)
            f, _ = model(batch)
            feats.append(f.cpu().numpy())
        return np.concatenate(feats, axis=0)

    logger.info("Computing FID: extracting generated features...")
    feats_gen = get_features(generated)
    logger.info("Computing FID: extracting real features...")
    feats_real = get_features(real)

    mu_gen, sigma_gen = feats_gen.mean(0), np.cov(feats_gen, rowvar=False)
    mu_real, sigma_real = feats_real.mean(0), np.cov(feats_real, rowvar=False)
    diff = mu_gen - mu_real

    covmean, _ = linalg.sqrtm(sigma_gen @ sigma_real, disp=False)
    if not np.isfinite(covmean).all():
        # Singular covariance product (common with few samples): regularize
        # both covariances with a small diagonal offset and retry — the
        # standard fallback used by reference FID implementations.
        eps = 1e-6
        logger.warning("FID sqrtm produced non-finite values; retrying with %g*I offset.", eps)
        offset = np.eye(sigma_gen.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma_gen + offset) @ (sigma_real + offset), disp=False)
    if np.iscomplexobj(covmean):
        # Discard the small imaginary component caused by numerical error.
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_gen + sigma_real - 2 * covmean)
    return float(fid)
94
+
95
+
96
def compute_inception_score(images: torch.Tensor, device: str = "cpu",
                            batch_size: int = 64, splits: int = 10) -> Tuple[float, float]:
    """Inception Score: ``exp(E_x[KL(p(y|x) || p(y))])`` averaged over splits.

    Args:
        images: Image batch (N, C, H, W), in the value range expected by
            ``InceptionV3Features``.
        device: Torch device used for inference.
        batch_size: Mini-batch size for the forward passes.
        splits: Number of equal splits used to estimate the score's spread.

    Returns:
        Tuple of (mean, std) of the per-split scores.

    Raises:
        ValueError: If there are fewer images than ``splits`` (an integer
            split size of zero would otherwise yield NaN scores).
    """
    n_images = len(images)
    if n_images < splits:
        raise ValueError(
            f"Inception Score with {splits} splits needs at least {splits} images, got {n_images}."
        )

    model = InceptionV3Features(device)
    all_logits = []
    for i in range(0, n_images, batch_size):
        batch = images[i:i + batch_size].to(device)
        _, logits = model(batch)
        all_logits.append(logits.cpu())
    probs = torch.softmax(torch.cat(all_logits, dim=0), dim=1).numpy()

    scores = []
    # Trailing n % splits samples are ignored, matching common practice.
    split_size = len(probs) // splits
    for i in range(splits):
        part = probs[i * split_size:(i + 1) * split_size]
        py = part.mean(axis=0, keepdims=True)  # marginal class distribution
        # Per-sample KL(p(y|x) || p(y)); 1e-10 guards against log(0).
        kl = part * (np.log(part + 1e-10) - np.log(py + 1e-10))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))
116
+
117
+
118
class Evaluation:
    """Dispatches metric computation based on the configured dataset.

    Image datasets ("mnist", "cifar10") are scored with FID and/or IS as
    requested by ``config["evaluation"]["metrics"]``; any other dataset is
    treated as 2-D point data and scored with the W2 distance.
    """

    def __init__(self, config: dict, device: str = "cpu"):
        """Store the run configuration and resolve the dataset kind.

        Args:
            config: Experiment configuration; reads "dataset" (default
                "8gaussians") and, for images, "evaluation"/"metrics".
            device: Torch device passed to the image metrics.
        """
        self.config = config
        self.device = device
        self.dataset_name = config.get("dataset", "8gaussians")
        self.is_image = self.dataset_name in ("mnist", "cifar10")

    def evaluate(self, generated: torch.Tensor, real: torch.Tensor) -> Dict[str, float]:
        """Compute every configured metric for a batch of samples.

        Args:
            generated: Generated samples (images or 2-D points).
            real: Reference samples with the same layout.

        Returns:
            Mapping of metric name to value: {"w2"} for point data, or a
            subset of {"fid", "is_mean", "is_std"} for image data.
        """
        if not self.is_image:
            # 2-D experiments: a single exact-OT metric.
            w2 = compute_w2_distance(generated, real)
            logger.info(f"W2 distance: {w2:.4f}")
            return {"w2": w2}

        results: Dict[str, float] = {}
        requested = self.config.get("evaluation", {}).get("metrics", ["fid"])
        if "fid" in requested:
            logger.info("Computing FID...")
            results["fid"] = compute_fid(generated, real, self.device)
            logger.info(f"FID: {results['fid']:.2f}")
        if "is" in requested:
            logger.info("Computing Inception Score...")
            mean_score, std_score = compute_inception_score(generated, self.device)
            results["is_mean"] = mean_score
            results["is_std"] = std_score
            logger.info(f"IS: {mean_score:.2f} ± {std_score:.2f}")
        return results
145
+
146
+
147
def plot_2d_samples(samples: torch.Tensor, targets: torch.Tensor,
                    title: str = "Generated vs Target", save_path: Optional[str] = None):
    """Render target, generated, and overlay scatter panels side by side.

    Args:
        samples: Generated 2-D points, shape (n, 2).
        targets: Target 2-D points, shape (m, 2).
        title: Figure-level title.
        save_path: When given, the figure is written there (parent
            directories are created as needed). The figure is always closed.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend: no display required
    import matplotlib.pyplot as plt

    gen = samples.detach().cpu().numpy()
    tgt = targets.detach().cpu().numpy()

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    # (axis, panel title, [(points, color, alpha, legend label), ...])
    panels = [
        (axes[0], "Target Distribution", [(tgt, "blue", 0.5, None)]),
        (axes[1], "Generated Samples", [(gen, "red", 0.5, None)]),
        (axes[2], "Overlay", [(tgt, "blue", 0.3, "Target"), (gen, "red", 0.3, "Generated")]),
    ]
    for ax, panel_title, layers in panels:
        for pts, color, alpha, label in layers:
            ax.scatter(pts[:, 0], pts[:, 1], s=3, alpha=alpha, c=color, label=label)
        ax.set_title(panel_title)
        ax.set_xlim(-6, 6)
        ax.set_ylim(-6, 6)
        ax.set_aspect("equal")
    axes[2].legend()
    plt.suptitle(title)
    plt.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved plot to {save_path}")
    plt.close()
173
+
174
+
175
def plot_2d_trajectory(trajectory: List[torch.Tensor], targets: torch.Tensor,
                       title: str = "Flow Trajectory", save_path: Optional[str] = None,
                       max_particles: int = 200):
    """Draw per-particle flow paths (colored by time step) over the target.

    Args:
        trajectory: Particle positions per step; each entry has shape (n, 2).
        targets: Target 2-D points, shape (m, 2).
        title: Plot title.
        save_path: Optional output path; parent directories are created.
        max_particles: Cap on how many particle paths are drawn, to keep
            the figure legible and rendering fast.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend: no display required
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    tgt = targets.detach().cpu().numpy()
    ax.scatter(tgt[:, 0], tgt[:, 1], s=3, alpha=0.2, c="blue", label="Target")

    num_steps = len(trajectory)
    num_drawn = min(trajectory[0].shape[0], max_particles)
    for particle in range(num_drawn):
        path = np.array([trajectory[step][particle].detach().cpu().numpy() for step in range(num_steps)])
        segs = np.array([[path[j], path[j + 1]] for j in range(len(path) - 1)])
        # Color segments from cool (early steps) to warm (late steps).
        seg_colors = plt.cm.coolwarm(np.linspace(0, 1, len(segs)))
        ax.add_collection(LineCollection(segs, colors=seg_colors, linewidths=0.5, alpha=0.5))

    endpoints = trajectory[-1][:num_drawn].detach().cpu().numpy()
    ax.scatter(endpoints[:, 0], endpoints[:, 1], s=5, c="red", alpha=0.5, label="Generated")
    ax.set_xlim(-6, 6)
    ax.set_ylim(-6, 6)
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.legend()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved trajectory plot to {save_path}")
    plt.close()
202
+
203
+
204
def plot_image_grid(images: torch.Tensor, nrow: int = 8,
                    title: str = "Generated Images", save_path: Optional[str] = None):
    """Render up to ``nrow * nrow`` images as a single normalized grid.

    Args:
        images: Image batch (N, C, H, W); values are normalized from the
            (-1, 1) range by ``make_grid``.
        nrow: Images per grid row (and the row count used to cap the batch).
        title: Title drawn above the grid.
        save_path: Optional output path; parent directories are created.
            The figure is always closed.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend: no display required
    import matplotlib.pyplot as plt
    import torchvision.utils as vutils

    tiles = vutils.make_grid(images[:nrow * nrow], nrow=nrow, normalize=True, value_range=(-1, 1))
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    ax.imshow(tiles.permute(1, 2, 0).cpu().numpy())
    ax.set_title(title)
    ax.axis("off")
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        logger.info(f"Saved image grid to {save_path}")
    plt.close()