dreamlessx committed
Commit a681eea · verified · 1 Parent(s): 0810860

Upload landmarkdiff/metrics_agg.py with huggingface_hub
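
For context, a commit like this is typically produced by a call along the following lines; the repo id shown is hypothetical, but upload_file is the standard huggingface_hub API for this:

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="landmarkdiff/metrics_agg.py",
        path_in_repo="landmarkdiff/metrics_agg.py",
        repo_id="dreamlessx/landmarkdiff",  # hypothetical repo id
    )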

Files changed (1)
  1. landmarkdiff/metrics_agg.py +308 -0
landmarkdiff/metrics_agg.py ADDED
@@ -0,0 +1,308 @@
+"""Metrics aggregation across checkpoints, experiments, and procedures.
+
+Collects evaluation results from multiple sources and computes aggregate
+statistics, confidence intervals, and significance tests for paper reporting.
+
+Usage:
+    from landmarkdiff.metrics_agg import MetricsAggregator
+
+    agg = MetricsAggregator()
+    agg.add("baseline", "rhinoplasty", {"ssim": 0.82, "lpips": 0.18})
+    agg.add("ours", "rhinoplasty", {"ssim": 0.91, "lpips": 0.09})
+    print(agg.summary_table())
+    print(agg.improvement_over("baseline"))
+"""
+
+from __future__ import annotations
+
+import json
+import math
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+
+@dataclass
+class MetricRecord:
+    """A single evaluation record."""
+
+    experiment: str
+    procedure: str
+    metrics: dict[str, float]
+    checkpoint_step: int | None = None
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
+class MetricsAggregator:
+    """Aggregate and analyze evaluation metrics.
+
+    Supports multiple experiments, procedures, and per-sample results
+    for computing confidence intervals and significance.
+    """
+
+    HIGHER_BETTER = {
+        "ssim": True, "psnr": True, "identity_sim": True,
+        "lpips": False, "fid": False, "nme": False,
+    }
+
+    def __init__(self) -> None:
+        self.records: list[MetricRecord] = []
+
+    def add(
+        self,
+        experiment: str,
+        procedure: str,
+        metrics: dict[str, float],
+        checkpoint_step: int | None = None,
+        **metadata: Any,
+    ) -> None:
+        """Add a single evaluation record."""
+        self.records.append(MetricRecord(
+            experiment=experiment,
+            procedure=procedure,
+            metrics=metrics,
+            checkpoint_step=checkpoint_step,
+            metadata=metadata,
+        ))
+
+    def add_batch(
+        self,
+        experiment: str,
+        records: list[dict[str, Any]],
+    ) -> None:
+        """Add multiple records for an experiment.
+
+        Each record dict should have 'procedure' and metric keys.
+        """
+        for rec in records:
+            proc = rec.get("procedure", "all")
+            # Keep only numeric fields as metrics; bool is excluded
+            # explicitly because it is a subclass of int.
+            metrics = {
+                k: v for k, v in rec.items()
+                if k != "procedure"
+                and isinstance(v, (int, float))
+                and not isinstance(v, bool)
+            }
+            self.add(experiment, proc, metrics)
+
+    @property
+    def experiments(self) -> list[str]:
+        """Unique experiment names in insertion order."""
+        seen: dict[str, None] = {}
+        for r in self.records:
+            seen.setdefault(r.experiment, None)
+        return list(seen.keys())
+
+    @property
+    def procedures(self) -> list[str]:
+        """Unique procedure names in insertion order."""
+        seen: dict[str, None] = {}
+        for r in self.records:
+            seen.setdefault(r.procedure, None)
+        return list(seen.keys())
+
+    @property
+    def metric_names(self) -> list[str]:
+        """All unique metric names."""
+        names: set[str] = set()
+        for r in self.records:
+            names.update(r.metrics.keys())
+        return sorted(names)
+
+    def filter(
+        self,
+        experiment: str | None = None,
+        procedure: str | None = None,
+    ) -> list[MetricRecord]:
+        """Filter records by experiment and/or procedure."""
+        results = self.records
+        if experiment is not None:
+            results = [r for r in results if r.experiment == experiment]
+        if procedure is not None:
+            results = [r for r in results if r.procedure == procedure]
+        return results
+
+    def mean(
+        self,
+        experiment: str,
+        metric: str,
+        procedure: str | None = None,
+    ) -> float:
+        """Compute the mean of a metric for an experiment."""
+        recs = self.filter(experiment=experiment, procedure=procedure)
+        vals = [r.metrics[metric] for r in recs if metric in r.metrics]
+        if not vals:
+            return float("nan")
+        return sum(vals) / len(vals)
+
+    def std(
+        self,
+        experiment: str,
+        metric: str,
+        procedure: str | None = None,
+    ) -> float:
+        """Compute the sample standard deviation of a metric."""
+        recs = self.filter(experiment=experiment, procedure=procedure)
+        vals = [r.metrics[metric] for r in recs if metric in r.metrics]
+        if len(vals) < 2:
+            return 0.0
+        m = sum(vals) / len(vals)
+        # Bessel's correction: divide by n - 1 for the sample variance.
+        var = sum((v - m) ** 2 for v in vals) / (len(vals) - 1)
+        return math.sqrt(var)
+
+    def ci_95(
+        self,
+        experiment: str,
+        metric: str,
+        procedure: str | None = None,
+    ) -> tuple[float, float]:
+        """Compute a 95% confidence interval (mean +/- 1.96 * SE).
+
+        Uses the normal approximation; 1.96 is the two-sided z critical
+        value. For very small n, a t-quantile would give a wider, more
+        conservative interval.
+        """
+        recs = self.filter(experiment=experiment, procedure=procedure)
+        vals = [r.metrics[metric] for r in recs if metric in r.metrics]
+        if not vals:
+            return (float("nan"), float("nan"))
+        n = len(vals)
+        m = sum(vals) / n
+        if n < 2:
+            return (m, m)
+        var = sum((v - m) ** 2 for v in vals) / (n - 1)
+        se = math.sqrt(var / n)  # standard error of the mean
+        return (m - 1.96 * se, m + 1.96 * se)
+
+    def improvement_over(
+        self,
+        baseline: str,
+        metric: str | None = None,
+    ) -> dict[str, dict[str, float]]:
+        """Compute relative improvement of all experiments over a baseline.
+
+        Returns:
+            {experiment: {metric: relative_improvement_pct}}
+        """
+        metrics = [metric] if metric else self.metric_names
+        result: dict[str, dict[str, float]] = {}
+
+        for exp in self.experiments:
+            if exp == baseline:
+                continue
+            improvements: dict[str, float] = {}
+            for m in metrics:
+                base_val = self.mean(baseline, m)
+                exp_val = self.mean(exp, m)
+                if math.isnan(base_val) or math.isnan(exp_val) or base_val == 0:
+                    continue
+
+                # Flip the sign for lower-is-better metrics so a positive
+                # percentage always means "improved over the baseline".
+                higher_better = self.HIGHER_BETTER.get(m, True)
+                if higher_better:
+                    pct = (exp_val - base_val) / abs(base_val) * 100
+                else:
+                    pct = (base_val - exp_val) / abs(base_val) * 100
+                improvements[m] = round(pct, 2)
+
+            result[exp] = improvements
+
+        return result
+
+    def best_experiment(
+        self,
+        metric: str,
+        procedure: str | None = None,
+    ) -> str | None:
+        """Find the experiment with the best mean for a metric."""
+        higher_better = self.HIGHER_BETTER.get(metric, True)
+        best_exp = None
+        best_val = float("-inf") if higher_better else float("inf")
+
+        for exp in self.experiments:
+            val = self.mean(exp, metric, procedure)
+            if math.isnan(val):
+                continue
+            if higher_better and val > best_val:
+                best_val = val
+                best_exp = exp
+            elif not higher_better and val < best_val:
+                best_val = val
+                best_exp = exp
+
+        return best_exp
+
+    def summary_table(
+        self,
+        metrics: list[str] | None = None,
+        procedure: str | None = None,
+        include_std: bool = False,
+    ) -> str:
+        """Generate a text summary table.
+
+        Args:
+            metrics: Metrics to include. None = all.
+            procedure: Filter by procedure. None = aggregate.
+            include_std: Show mean +/- std.
+
+        Returns:
+            Formatted text table.
+        """
+        metrics = metrics or self.metric_names
+        exps = self.experiments
+
+        # Header row followed by a separator line.
+        cols = ["Experiment"] + metrics
+        header = " | ".join(f"{c:>16s}" for c in cols)
+        lines = [header, "-" * len(header)]
+
+        for exp in exps:
+            parts = [f"{exp:>16s}"]
+            for m in metrics:
+                val = self.mean(exp, m, procedure)
+                if math.isnan(val):
+                    parts.append(f"{'--':>16s}")
+                elif include_std:
+                    s = self.std(exp, m, procedure)
+                    # 8 + 1 + 7 characters matches the 16-char column width.
+                    parts.append(f"{val:>8.4f}±{s:<7.4f}")
+                else:
+                    parts.append(f"{val:>16.4f}")
+            lines.append(" | ".join(parts))
+
+        return "\n".join(lines)
+
+    def to_json(self, path: str | Path | None = None) -> str:
+        """Export all records as JSON.
+
+        Args:
+            path: Optional file path to write to.
+
+        Returns:
+            JSON string.
+        """
+        data = {
+            "experiments": self.experiments,
+            "procedures": self.procedures,
+            "metrics": self.metric_names,
+            "records": [
+                {
+                    "experiment": r.experiment,
+                    "procedure": r.procedure,
+                    "metrics": r.metrics,
+                    "checkpoint_step": r.checkpoint_step,
+                    "metadata": r.metadata,
+                }
+                for r in self.records
+            ],
+        }
+        j = json.dumps(data, indent=2)
+
+        if path is not None:
+            out = Path(path)
+            out.parent.mkdir(parents=True, exist_ok=True)
+            out.write_text(j)
+
+        return j
+
+    @staticmethod
+    def from_json(path: str | Path) -> MetricsAggregator:
+        """Load an aggregator from JSON produced by to_json()."""
+        with open(path) as f:
+            data = json.load(f)
+
+        agg = MetricsAggregator()
+        for rec in data.get("records", []):
+            agg.add(
+                experiment=rec["experiment"],
+                procedure=rec["procedure"],
+                metrics=rec["metrics"],
+                checkpoint_step=rec.get("checkpoint_step"),
+                # Restore metadata so to_json/from_json round-trips.
+                **rec.get("metadata", {}),
+            )
+        return agg
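
As a quick reference, a minimal sketch of driving the aggregator above end to end; the experiment names, metric values, and output path are hypothetical, but all methods shown are defined in this file:

    from landmarkdiff.metrics_agg import MetricsAggregator

    agg = MetricsAggregator()
    # Per-sample records make std/CI computation meaningful (values made up).
    agg.add_batch("baseline", [
        {"procedure": "rhinoplasty", "ssim": 0.81, "lpips": 0.19},
        {"procedure": "rhinoplasty", "ssim": 0.83, "lpips": 0.17},
    ])
    agg.add_batch("ours", [
        {"procedure": "rhinoplasty", "ssim": 0.90, "lpips": 0.10},
        {"procedure": "rhinoplasty", "ssim": 0.92, "lpips": 0.08},
    ])

    print(agg.summary_table(include_std=True))   # mean ± std per experiment
    print(agg.ci_95("ours", "ssim"))             # normal-approximation 95% CI
    print(agg.best_experiment("lpips"))          # "ours" (lower LPIPS is better)
    agg.to_json("results/metrics.json")          # hypothetical output path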