MidasMap / src /evaluate.py
AnikS22's picture
Deploy MidasMap Gradio app, src, requirements, checkpoint
86c24cb verified
"""
Evaluation: Hungarian matching, per-class metrics, LOOCV runner.
Uses scipy linear_sum_assignment for optimal bipartite matching between
predictions and ground truth with class-specific match radii.
"""
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from typing import Dict, List, Optional, Tuple
def compute_f1(tp: int, fp: int, fn: int, eps: float = 1e-6) -> Tuple[float, float, float]:
    """Return ``(f1, precision, recall)`` derived from TP/FP/FN counts.

    ``eps`` is added to every denominator so that all-zero counts yield
    0.0 instead of raising ``ZeroDivisionError``.
    """
    predicted_positives = tp + fp + eps
    actual_positives = tp + fn + eps

    precision = tp / predicted_positives
    recall = tp / actual_positives

    # Harmonic mean of precision and recall, with the same eps guard.
    numerator = 2 * precision * recall
    denominator = precision + recall + eps
    f1 = numerator / denominator
    return f1, precision, recall
def match_detections_to_gt(
    detections: List[dict],
    gt_coords_6nm: np.ndarray,
    gt_coords_12nm: np.ndarray,
    match_radii: Optional[Dict[str, float]] = None,
) -> Dict[str, dict]:
    """
    Score detections against ground truth with per-class optimal assignment.

    For each of the classes "6nm" and "12nm", detections carrying that class
    label are paired one-to-one with the GT points of the same class by
    running ``linear_sum_assignment`` on the Euclidean distance matrix. A pair
    counts as a true positive only when its distance is <= the class's match
    radius. Unpaired detections are false positives and unpaired GT points
    are false negatives.

    Args:
        detections: list of dicts with keys 'x', 'y', 'class', 'conf'
        gt_coords_6nm: (N, 2) array of (x, y) GT for 6nm
        gt_coords_12nm: (M, 2) array of (x, y) GT for 12nm
        match_radii: per-class match radius in pixels; defaults to
            {"6nm": 9.0, "12nm": 15.0} when None

    Returns:
        Dict with keys "6nm", "12nm" and "overall", each mapping to a dict of
        tp/fp/fn/f1/precision/recall, plus "mean_f1": the mean F1 over the
        classes that have at least one GT point (0.0 if neither does).
    """
    radii = {"6nm": 9.0, "12nm": 15.0} if match_radii is None else match_radii
    class_names = ("6nm", "12nm")
    gt_lookup = dict(zip(class_names, (gt_coords_6nm, gt_coords_12nm)))

    results = {}
    totals = {"tp": 0, "fp": 0, "fn": 0}

    for name in class_names:
        preds = [det for det in detections if det["class"] == name]
        truth = gt_lookup[name]
        max_dist = radii[name]
        n_pred = len(preds)
        n_truth = len(truth)

        if n_pred == 0 or n_truth == 0:
            # Nothing to assign. A class that is empty on both sides is
            # scored as perfect; if only one side is empty every item on the
            # other side is an error and the scores are zero.
            score = 1.0 if n_pred == 0 and n_truth == 0 else 0.0
            results[name] = {
                "tp": 0, "fp": n_pred, "fn": n_truth,
                "f1": score, "precision": score, "recall": score,
            }
            totals["fp"] += n_pred
            totals["fn"] += n_truth
            continue

        # Pairwise distances: rows are detections, columns are GT points.
        pred_xy = np.array([[det["x"], det["y"]] for det in preds])
        dist = cdist(pred_xy, truth)

        # Pairs farther apart than the radius get a large cost so the
        # assignment only uses them when no in-radius partner is available.
        gated = np.where(dist > max_dist, 1e6, dist)
        rows, cols = linear_sum_assignment(gated)

        # Only assigned pairs that are still within the radius count.
        tp = int(np.count_nonzero(gated[rows, cols] <= max_dist))
        fp = n_pred - tp
        fn = n_truth - tp

        f1, prec, rec = compute_f1(tp, fp, fn)
        results[name] = {
            "tp": tp, "fp": fp, "fn": fn,
            "f1": f1, "precision": prec, "recall": rec,
        }
        totals["tp"] += tp
        totals["fp"] += fp
        totals["fn"] += fn

    # Overall metrics are computed from the pooled counts of both classes.
    f1_all, prec_all, rec_all = compute_f1(totals["tp"], totals["fp"], totals["fn"])
    results["overall"] = {
        "tp": totals["tp"], "fp": totals["fp"], "fn": totals["fn"],
        "f1": f1_all, "precision": prec_all, "recall": rec_all,
    }

    # Mean F1 only averages classes that actually have ground truth.
    scored = [
        results[name]["f1"]
        for name in class_names
        if results[name]["fn"] + results[name]["tp"] > 0
    ]
    results["mean_f1"] = np.mean(scored) if scored else 0.0
    return results
def evaluate_fold(
    detections: List[dict],
    gt_annotations: Dict[str, np.ndarray],
    match_radii: Optional[Dict[str, float]] = None,
    has_6nm: bool = True,
) -> Dict[str, dict]:
    """
    Compute matching metrics for one LOOCV fold.

    Args:
        detections: model predictions
        gt_annotations: {'6nm': Nx2, '12nm': Mx2}; a missing key is treated
            as an empty (0, 2) array
        match_radii: per-class match radii, forwarded unchanged
        has_6nm: whether this fold has 6nm GT (False for S7, S15)

    Returns:
        The metrics dict from match_detections_to_gt. When ``has_6nm`` is
        False the "6nm" entry additionally carries a "note" key; its numeric
        fields are left as computed.
    """
    coords = {
        cls: gt_annotations.get(cls, np.empty((0, 2)))
        for cls in ("6nm", "12nm")
    }
    metrics = match_detections_to_gt(
        detections, coords["6nm"], coords["12nm"], match_radii
    )

    if has_6nm:
        return metrics

    # Flag the 6nm numbers as not meaningful for folds without 6nm labels.
    metrics["6nm"]["note"] = "N/A (missing annotations)"
    return metrics
def compute_average_precision(
    detections: List[dict],
    gt_coords: np.ndarray,
    match_radius: float,
) -> float:
    """
    Compute Average Precision (AP) for a single class.

    Follows PASCAL VOC style: detections are ranked by confidence, each one
    is compared to its nearest GT point, and AP is the area under the
    precision-recall curve using all-point interpolation (precision at each
    rank is replaced by the maximum precision at that or any later rank).

    Args:
        detections: list of dicts with keys 'x', 'y', 'conf' for one class
        gt_coords: (N, 2) array of (x, y) GT points for the same class
        match_radius: maximum distance (inclusive) for a detection to match

    Returns:
        AP as a Python float. With no GT points the result is 1.0 if there
        are also no detections and 0.0 otherwise. With GT points but no
        detections the result is 0.0.
    """
    num_gt = len(gt_coords)
    if num_gt == 0:
        return 0.0 if detections else 1.0
    if not detections:
        return 0.0

    gt_xy = np.asarray(gt_coords, dtype=float)

    # Highest confidence first; sorted() is stable so ties keep input order.
    ranked = sorted(detections, key=lambda d: d["conf"], reverse=True)

    is_tp = np.zeros(len(ranked), dtype=bool)
    gt_claimed = np.zeros(num_gt, dtype=bool)
    for rank, det in enumerate(ranked):
        det_xy = np.array([det["x"], det["y"]])
        dists = np.sqrt(np.sum((gt_xy - det_xy) ** 2, axis=1))
        nearest = int(np.argmin(dists))
        # A detection is only ever compared to its nearest GT point; if that
        # point was already claimed by a higher-ranked detection, this one is
        # a duplicate and counts as a false positive.
        if dists[nearest] <= match_radius and not gt_claimed[nearest]:
            is_tp[rank] = True
            gt_claimed[nearest] = True

    tp_cumsum = np.cumsum(is_tp)
    fp_cumsum = np.cumsum(~is_tp)
    # tp + fp equals rank + 1, so the denominator is never zero.
    precision = tp_cumsum / (tp_cumsum + fp_cumsum)
    recall = tp_cumsum / num_gt

    # All-point interpolation: make precision monotonically non-increasing
    # by taking, at each rank, the best precision reached at any later rank.
    precision_envelope = np.maximum.accumulate(precision[::-1])[::-1]

    # Integrate over the recall increments (recall starts from 0).
    recall_steps = np.diff(np.concatenate(([0.0], recall)))
    return float(np.sum(precision_envelope * recall_steps))