Spaces:
Running
Running
Update landmarkdiff/metrics_agg.py to v0.3.2
Browse files- landmarkdiff/metrics_agg.py +15 -19
landmarkdiff/metrics_agg.py
CHANGED
|
@@ -41,12 +41,8 @@ class MetricsAggregator:
|
|
| 41 |
"""
|
| 42 |
|
| 43 |
HIGHER_BETTER = {
|
| 44 |
-
"ssim": True,
|
| 45 |
-
"
|
| 46 |
-
"identity_sim": True,
|
| 47 |
-
"lpips": False,
|
| 48 |
-
"fid": False,
|
| 49 |
-
"nme": False,
|
| 50 |
}
|
| 51 |
|
| 52 |
def __init__(self) -> None:
|
|
@@ -61,15 +57,13 @@ class MetricsAggregator:
|
|
| 61 |
**metadata: Any,
|
| 62 |
) -> None:
|
| 63 |
"""Add a single evaluation record."""
|
| 64 |
-
self.records.append(
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
)
|
| 72 |
-
)
|
| 73 |
|
| 74 |
def add_batch(
|
| 75 |
self,
|
|
@@ -82,9 +76,7 @@ class MetricsAggregator:
|
|
| 82 |
"""
|
| 83 |
for rec in records:
|
| 84 |
proc = rec.get("procedure", "all")
|
| 85 |
-
metrics = {
|
| 86 |
-
k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))
|
| 87 |
-
}
|
| 88 |
self.add(experiment, proc, metrics)
|
| 89 |
|
| 90 |
@property
|
|
@@ -219,7 +211,10 @@ class MetricsAggregator:
|
|
| 219 |
val = self.mean(exp, metric, procedure)
|
| 220 |
if math.isnan(val):
|
| 221 |
continue
|
| 222 |
-
if
|
|
|
|
|
|
|
|
|
|
| 223 |
best_val = val
|
| 224 |
best_exp = exp
|
| 225 |
|
|
@@ -309,5 +304,6 @@ class MetricsAggregator:
|
|
| 309 |
procedure=rec["procedure"],
|
| 310 |
metrics=rec["metrics"],
|
| 311 |
checkpoint_step=rec.get("checkpoint_step"),
|
|
|
|
| 312 |
)
|
| 313 |
return agg
|
|
|
|
| 41 |
"""
|
| 42 |
|
| 43 |
HIGHER_BETTER = {
|
| 44 |
+
"ssim": True, "psnr": True, "identity_sim": True,
|
| 45 |
+
"lpips": False, "fid": False, "nme": False,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
}
|
| 47 |
|
| 48 |
def __init__(self) -> None:
|
|
|
|
| 57 |
**metadata: Any,
|
| 58 |
) -> None:
|
| 59 |
"""Add a single evaluation record."""
|
| 60 |
+
self.records.append(MetricRecord(
|
| 61 |
+
experiment=experiment,
|
| 62 |
+
procedure=procedure,
|
| 63 |
+
metrics=metrics,
|
| 64 |
+
checkpoint_step=checkpoint_step,
|
| 65 |
+
metadata=metadata,
|
| 66 |
+
))
|
|
|
|
|
|
|
| 67 |
|
| 68 |
def add_batch(
|
| 69 |
self,
|
|
|
|
| 76 |
"""
|
| 77 |
for rec in records:
|
| 78 |
proc = rec.get("procedure", "all")
|
| 79 |
+
metrics = {k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))}
|
|
|
|
|
|
|
| 80 |
self.add(experiment, proc, metrics)
|
| 81 |
|
| 82 |
@property
|
|
|
|
| 211 |
val = self.mean(exp, metric, procedure)
|
| 212 |
if math.isnan(val):
|
| 213 |
continue
|
| 214 |
+
if higher_better and val > best_val:
|
| 215 |
+
best_val = val
|
| 216 |
+
best_exp = exp
|
| 217 |
+
elif not higher_better and val < best_val:
|
| 218 |
best_val = val
|
| 219 |
best_exp = exp
|
| 220 |
|
|
|
|
| 304 |
procedure=rec["procedure"],
|
| 305 |
metrics=rec["metrics"],
|
| 306 |
checkpoint_step=rec.get("checkpoint_step"),
|
| 307 |
+
**rec.get("metadata", {}),
|
| 308 |
)
|
| 309 |
return agg
|