# D-FINE/src/solver/validator.py
import copy
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
import matplotlib.pyplot as plt
import numpy as np
import torch
from loguru import logger
from torchvision.ops import box_iou
class Validator:
def __init__(
self,
gt: List[Dict[str, torch.Tensor]],
preds: List[Dict[str, torch.Tensor]],
conf_thresh=0.5,
iou_thresh=0.5,
) -> None:
"""
        Format example:
        gt = [{'labels': tensor([0]), 'boxes': tensor([[561.0, 297.0, 661.0, 359.0]])}, ...]
        preds follow the same structure and additionally carry a 'scores' tensor of per-box confidences.
        len(gt) is the number of images.
        Boxes are in [x1, y1, x2, y2] format, absolute pixel values.
        conf_thresh drops low-confidence predictions; iou_thresh is the minimum IoU required for a match.
"""
self.gt = gt
self.preds = preds
self.conf_thresh = conf_thresh
self.iou_thresh = iou_thresh
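        # confidence thresholds swept in save_plots for the precision/recall/F1 curves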
self.thresholds = np.arange(0.2, 1.0, 0.05)
self.conf_matrix = None
def compute_metrics(self, extended=False) -> Dict[str, float]:
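        """Filter predictions by self.conf_thresh and return overall detection metrics.

        If extended=True, per-class precision/recall/IoU values are kept under
        the 'extended_metrics' key; otherwise that key is removed.
        """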
filtered_preds = filter_preds(copy.deepcopy(self.preds), self.conf_thresh)
metrics = self._compute_main_metrics(filtered_preds)
if not extended:
metrics.pop("extended_metrics", None)
return metrics
def _compute_main_metrics(self, preds):
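        """Aggregate per-class TPs/FPs/FNs and IoUs into overall precision, recall, F1 and mean IoU."""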
(
self.metrics_per_class,
self.conf_matrix,
self.class_to_idx,
) = self._compute_metrics_and_confusion_matrix(preds)
tps, fps, fns = 0, 0, 0
ious = []
extended_metrics = {}
for key, value in self.metrics_per_class.items():
tps += value["TPs"]
fps += value["FPs"]
fns += value["FNs"]
ious.extend(value["IoUs"])
extended_metrics[f"precision_{key}"] = (
value["TPs"] / (value["TPs"] + value["FPs"])
if value["TPs"] + value["FPs"] > 0
else 0
)
extended_metrics[f"recall_{key}"] = (
value["TPs"] / (value["TPs"] + value["FNs"])
if value["TPs"] + value["FNs"] > 0
else 0
)
extended_metrics[f"iou_{key}"] = np.mean(value["IoUs"])
precision = tps / (tps + fps) if (tps + fps) > 0 else 0
recall = tps / (tps + fns) if (tps + fns) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
iou = np.mean(ious).item() if ious else 0
return {
"f1": f1,
"precision": precision,
"recall": recall,
"iou": iou,
"TPs": tps,
"FPs": fps,
"FNs": fns,
"extended_metrics": extended_metrics,
}
def _compute_matrix_multi_class(self, preds):
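        """Compute per-class TPs/FPs/FNs and IoUs by matching predictions to ground truths within each class."""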
metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
for pred, gt in zip(preds, self.gt):
pred_boxes = pred["boxes"]
pred_labels = pred["labels"]
gt_boxes = gt["boxes"]
gt_labels = gt["labels"]
# isolate each class
labels = torch.unique(torch.cat([pred_labels, gt_labels]))
for label in labels:
pred_cl_boxes = pred_boxes[pred_labels == label] # filter by bool mask
gt_cl_boxes = gt_boxes[gt_labels == label]
n_preds = len(pred_cl_boxes)
n_gts = len(gt_cl_boxes)
if not (n_preds or n_gts):
continue
if not n_preds:
metrics_per_class[label.item()]["FNs"] += n_gts
metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
continue
if not n_gts:
metrics_per_class[label.item()]["FPs"] += n_preds
metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
continue
ious = box_iou(pred_cl_boxes, gt_cl_boxes) # matrix of all IoUs
ious_mask = ious >= self.iou_thresh
                # indices of (pred, gt) pairs with IoU >= threshold
pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
                if not pred_indices.numel():  # no predictions matched any ground truth
metrics_per_class[label.item()]["FNs"] += n_gts
metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
metrics_per_class[label.item()]["FPs"] += n_preds
metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
continue
iou_values = ious[pred_indices, gt_indices]
                # sort by IoU so the highest-overlap pairs are matched first
sorted_indices = torch.argsort(-iou_values)
pred_indices = pred_indices[sorted_indices]
gt_indices = gt_indices[sorted_indices]
iou_values = iou_values[sorted_indices]
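                # greedy one-to-one matching: each prediction and ground truth is used at most once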
matched_preds = set()
matched_gts = set()
for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
if gt_idx.item() not in matched_gts and pred_idx.item() not in matched_preds:
matched_preds.add(pred_idx.item())
matched_gts.add(gt_idx.item())
metrics_per_class[label.item()]["TPs"] += 1
metrics_per_class[label.item()]["IoUs"].append(iou.item())
unmatched_preds = set(range(n_preds)) - matched_preds
unmatched_gts = set(range(n_gts)) - matched_gts
metrics_per_class[label.item()]["FPs"] += len(unmatched_preds)
metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_preds))
metrics_per_class[label.item()]["FNs"] += len(unmatched_gts)
metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_gts))
return metrics_per_class
def _compute_metrics_and_confusion_matrix(self, preds):
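        """Greedily match predictions to ground truths by IoU (regardless of class), then derive
        per-class metrics and a confusion matrix with an extra row/column for background
        (unmatched predictions and unmatched ground truths).
        """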
# Initialize per-class metrics
metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
# Collect all class IDs
all_classes = set()
for pred in preds:
all_classes.update(pred["labels"].tolist())
for gt in self.gt:
all_classes.update(gt["labels"].tolist())
        all_classes = sorted(all_classes)
class_to_idx = {cls_id: idx for idx, cls_id in enumerate(all_classes)}
n_classes = len(all_classes)
conf_matrix = np.zeros((n_classes + 1, n_classes + 1), dtype=int) # +1 for background class
for pred, gt in zip(preds, self.gt):
pred_boxes = pred["boxes"]
pred_labels = pred["labels"]
gt_boxes = gt["boxes"]
gt_labels = gt["labels"]
n_preds = len(pred_boxes)
n_gts = len(gt_boxes)
if n_preds == 0 and n_gts == 0:
continue
ious = box_iou(pred_boxes, gt_boxes) if n_preds > 0 and n_gts > 0 else torch.tensor([])
# Assign matches between preds and gts
matched_pred_indices = set()
matched_gt_indices = set()
if ious.numel() > 0:
# For each pred box, find the gt box with highest IoU
ious_mask = ious >= self.iou_thresh
pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
iou_values = ious[pred_indices, gt_indices]
# Sorting by IoU to match highest scores first
sorted_indices = torch.argsort(-iou_values)
pred_indices = pred_indices[sorted_indices]
gt_indices = gt_indices[sorted_indices]
iou_values = iou_values[sorted_indices]
for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
if (
pred_idx.item() in matched_pred_indices
or gt_idx.item() in matched_gt_indices
):
continue
matched_pred_indices.add(pred_idx.item())
matched_gt_indices.add(gt_idx.item())
pred_label = pred_labels[pred_idx].item()
gt_label = gt_labels[gt_idx].item()
pred_cls_idx = class_to_idx[pred_label]
gt_cls_idx = class_to_idx[gt_label]
# Update confusion matrix
conf_matrix[gt_cls_idx, pred_cls_idx] += 1
# Update per-class metrics
if pred_label == gt_label:
metrics_per_class[gt_label]["TPs"] += 1
metrics_per_class[gt_label]["IoUs"].append(iou.item())
else:
# Misclassification
metrics_per_class[gt_label]["FNs"] += 1
metrics_per_class[pred_label]["FPs"] += 1
metrics_per_class[gt_label]["IoUs"].append(0)
metrics_per_class[pred_label]["IoUs"].append(0)
# Unmatched predictions (False Positives)
unmatched_pred_indices = set(range(n_preds)) - matched_pred_indices
for pred_idx in unmatched_pred_indices:
pred_label = pred_labels[pred_idx].item()
pred_cls_idx = class_to_idx[pred_label]
# Update confusion matrix: background row
conf_matrix[n_classes, pred_cls_idx] += 1
# Update per-class metrics
metrics_per_class[pred_label]["FPs"] += 1
metrics_per_class[pred_label]["IoUs"].append(0)
# Unmatched ground truths (False Negatives)
unmatched_gt_indices = set(range(n_gts)) - matched_gt_indices
for gt_idx in unmatched_gt_indices:
gt_label = gt_labels[gt_idx].item()
gt_cls_idx = class_to_idx[gt_label]
# Update confusion matrix: background column
conf_matrix[gt_cls_idx, n_classes] += 1
# Update per-class metrics
metrics_per_class[gt_label]["FNs"] += 1
metrics_per_class[gt_label]["IoUs"].append(0)
return metrics_per_class, conf_matrix, class_to_idx
def save_plots(self, path_to_save) -> None:
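        """Save the confusion matrix and precision/recall/F1-vs-threshold plots to path_to_save,
        and log the confidence threshold that maximizes F1.
        """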
path_to_save = Path(path_to_save)
path_to_save.mkdir(parents=True, exist_ok=True)
if self.conf_matrix is not None:
class_labels = [str(cls_id) for cls_id in self.class_to_idx.keys()] + ["background"]
plt.figure(figsize=(10, 8))
plt.imshow(self.conf_matrix, interpolation="nearest", cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(class_labels))
plt.xticks(tick_marks, class_labels, rotation=45)
plt.yticks(tick_marks, class_labels)
# Add labels to each cell
thresh = self.conf_matrix.max() / 2.0
for i in range(self.conf_matrix.shape[0]):
for j in range(self.conf_matrix.shape[1]):
plt.text(
j,
i,
format(self.conf_matrix[i, j], "d"),
horizontalalignment="center",
color="white" if self.conf_matrix[i, j] > thresh else "black",
)
plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.tight_layout()
plt.savefig(path_to_save / "confusion_matrix.png")
plt.close()
thresholds = self.thresholds
precisions, recalls, f1_scores = [], [], []
# Store the original predictions to reset after each threshold
original_preds = copy.deepcopy(self.preds)
for threshold in thresholds:
# Filter predictions based on the current threshold
filtered_preds = filter_preds(copy.deepcopy(original_preds), threshold)
# Compute metrics with the filtered predictions
metrics = self._compute_main_metrics(filtered_preds)
precisions.append(metrics["precision"])
recalls.append(metrics["recall"])
f1_scores.append(metrics["f1"])
# Plot Precision and Recall vs Threshold
plt.figure()
plt.plot(thresholds, precisions, label="Precision", marker="o")
plt.plot(thresholds, recalls, label="Recall", marker="o")
plt.xlabel("Threshold")
plt.ylabel("Value")
plt.title("Precision and Recall vs Threshold")
plt.legend()
plt.grid(True)
plt.savefig(path_to_save / "precision_recall_vs_threshold.png")
plt.close()
# Plot F1 Score vs Threshold
plt.figure()
plt.plot(thresholds, f1_scores, label="F1 Score", marker="o")
plt.xlabel("Threshold")
plt.ylabel("F1 Score")
plt.title("F1 Score vs Threshold")
plt.grid(True)
plt.savefig(path_to_save / "f1_score_vs_threshold.png")
plt.close()
        # Find the best threshold based on F1 score (last occurrence of the maximum)
best_idx = len(f1_scores) - np.argmax(f1_scores[::-1]) - 1
best_threshold = thresholds[best_idx]
best_f1 = f1_scores[best_idx]
logger.info(
f"Best Threshold: {round(best_threshold, 2)} with F1 Score: {round(best_f1, 3)}"
)
def filter_preds(preds, conf_thresh):
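    """Keep only predictions with confidence >= conf_thresh (modifies preds in place and returns them)."""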
for pred in preds:
keep_idxs = pred["scores"] >= conf_thresh
pred["scores"] = pred["scores"][keep_idxs]
pred["boxes"] = pred["boxes"][keep_idxs]
pred["labels"] = pred["labels"][keep_idxs]
return preds
def scale_boxes(boxes, orig_shape, resized_shape):
"""
boxes in format: [x1, y1, x2, y2], absolute values
orig_shape: [height, width]
resized_shape: [height, width]
"""
scale_x = orig_shape[1] / resized_shape[1]
scale_y = orig_shape[0] / resized_shape[0]
boxes[:, 0] *= scale_x
boxes[:, 2] *= scale_x
boxes[:, 1] *= scale_y
boxes[:, 3] *= scale_y
return boxes
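
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, synthetic data): one image with two
    # ground-truth boxes and two predictions, only one of which is confident
    # and overlaps a ground truth.
    gt = [
        {
            "labels": torch.tensor([0, 1]),
            "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0], [60.0, 60.0, 120.0, 120.0]]),
        }
    ]
    preds = [
        {
            "labels": torch.tensor([0, 1]),
            "boxes": torch.tensor([[12.0, 11.0, 49.0, 52.0], [200.0, 200.0, 240.0, 240.0]]),
            "scores": torch.tensor([0.9, 0.3]),
        }
    ]
    validator = Validator(gt, preds, conf_thresh=0.5, iou_thresh=0.5)
    print(validator.compute_metrics(extended=True))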