import copy
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from loguru import logger
from torchvision.ops import box_iou
class Validator:
    def __init__(
        self,
        gt: List[Dict[str, torch.Tensor]],
        preds: List[Dict[str, torch.Tensor]],
        conf_thresh=0.5,
        iou_thresh=0.5,
    ) -> None:
        """
        Format example:
        gt = [{'labels': tensor([0]), 'boxes': tensor([[561.0, 297.0, 661.0, 359.0]])}, ...]
        len(gt) is the number of images
        bboxes are in format [x1, y1, x2, y2], absolute values
        """
        self.gt = gt
        self.preds = preds
        self.conf_thresh = conf_thresh
        self.iou_thresh = iou_thresh
        self.thresholds = np.arange(0.2, 1.0, 0.05)
        self.conf_matrix = None

    def compute_metrics(self, extended=False) -> Dict[str, float]:
        filtered_preds = filter_preds(copy.deepcopy(self.preds), self.conf_thresh)
        metrics = self._compute_main_metrics(filtered_preds)
        if not extended:
            metrics.pop("extended_metrics", None)
        return metrics
    def _compute_main_metrics(self, preds):
        (
            self.metrics_per_class,
            self.conf_matrix,
            self.class_to_idx,
        ) = self._compute_metrics_and_confusion_matrix(preds)
        tps, fps, fns = 0, 0, 0
        ious = []
        extended_metrics = {}
        for key, value in self.metrics_per_class.items():
            tps += value["TPs"]
            fps += value["FPs"]
            fns += value["FNs"]
            ious.extend(value["IoUs"])
            extended_metrics[f"precision_{key}"] = (
                value["TPs"] / (value["TPs"] + value["FPs"])
                if value["TPs"] + value["FPs"] > 0
                else 0
            )
            extended_metrics[f"recall_{key}"] = (
                value["TPs"] / (value["TPs"] + value["FNs"])
                if value["TPs"] + value["FNs"] > 0
                else 0
            )
            extended_metrics[f"iou_{key}"] = np.mean(value["IoUs"])
        precision = tps / (tps + fps) if (tps + fps) > 0 else 0
        recall = tps / (tps + fns) if (tps + fns) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        iou = np.mean(ious).item() if ious else 0
        return {
            "f1": f1,
            "precision": precision,
            "recall": recall,
            "iou": iou,
            "TPs": tps,
            "FPs": fps,
            "FNs": fns,
            "extended_metrics": extended_metrics,
        }
    def _compute_matrix_multi_class(self, preds):
        metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
        for pred, gt in zip(preds, self.gt):
            pred_boxes = pred["boxes"]
            pred_labels = pred["labels"]
            gt_boxes = gt["boxes"]
            gt_labels = gt["labels"]
            # isolate each class
            labels = torch.unique(torch.cat([pred_labels, gt_labels]))
            for label in labels:
                pred_cl_boxes = pred_boxes[pred_labels == label]  # filter by bool mask
                gt_cl_boxes = gt_boxes[gt_labels == label]
                n_preds = len(pred_cl_boxes)
                n_gts = len(gt_cl_boxes)
                if not (n_preds or n_gts):
                    continue
                if not n_preds:
                    metrics_per_class[label.item()]["FNs"] += n_gts
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
                    continue
                if not n_gts:
                    metrics_per_class[label.item()]["FPs"] += n_preds
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
                    continue
                ious = box_iou(pred_cl_boxes, gt_cl_boxes)  # matrix of all pairwise IoUs
                ious_mask = ious >= self.iou_thresh
                # indices of box pairs with IoU >= threshold
                pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
                if not pred_indices.numel():  # no predictions matched ground truths
                    metrics_per_class[label.item()]["FNs"] += n_gts
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
                    metrics_per_class[label.item()]["FPs"] += n_preds
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
                    continue
                iou_values = ious[pred_indices, gt_indices]
                # sort by IoU to match highest scores first
                sorted_indices = torch.argsort(-iou_values)
                pred_indices = pred_indices[sorted_indices]
                gt_indices = gt_indices[sorted_indices]
                iou_values = iou_values[sorted_indices]
                matched_preds = set()
                matched_gts = set()
                for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
                    if gt_idx.item() not in matched_gts and pred_idx.item() not in matched_preds:
                        matched_preds.add(pred_idx.item())
                        matched_gts.add(gt_idx.item())
                        metrics_per_class[label.item()]["TPs"] += 1
                        metrics_per_class[label.item()]["IoUs"].append(iou.item())
                unmatched_preds = set(range(n_preds)) - matched_preds
                unmatched_gts = set(range(n_gts)) - matched_gts
                metrics_per_class[label.item()]["FPs"] += len(unmatched_preds)
                metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_preds))
                metrics_per_class[label.item()]["FNs"] += len(unmatched_gts)
                metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_gts))
        return metrics_per_class
    def _compute_metrics_and_confusion_matrix(self, preds):
        # Initialize per-class metrics
        metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
        # Collect all class IDs
        all_classes = set()
        for pred in preds:
            all_classes.update(pred["labels"].tolist())
        for gt in self.gt:
            all_classes.update(gt["labels"].tolist())
        all_classes = sorted(all_classes)
        class_to_idx = {cls_id: idx for idx, cls_id in enumerate(all_classes)}
        n_classes = len(all_classes)
        conf_matrix = np.zeros((n_classes + 1, n_classes + 1), dtype=int)  # +1 for background class
        for pred, gt in zip(preds, self.gt):
            pred_boxes = pred["boxes"]
            pred_labels = pred["labels"]
            gt_boxes = gt["boxes"]
            gt_labels = gt["labels"]
            n_preds = len(pred_boxes)
            n_gts = len(gt_boxes)
            if n_preds == 0 and n_gts == 0:
                continue
            ious = box_iou(pred_boxes, gt_boxes) if n_preds > 0 and n_gts > 0 else torch.tensor([])
            # Assign matches between preds and gts
            matched_pred_indices = set()
            matched_gt_indices = set()
            if ious.numel() > 0:
                # For each pred box, find the gt box with highest IoU
                ious_mask = ious >= self.iou_thresh
                pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
                iou_values = ious[pred_indices, gt_indices]
                # Sort by IoU to match highest scores first
                sorted_indices = torch.argsort(-iou_values)
                pred_indices = pred_indices[sorted_indices]
                gt_indices = gt_indices[sorted_indices]
                iou_values = iou_values[sorted_indices]
                for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
                    if (
                        pred_idx.item() in matched_pred_indices
                        or gt_idx.item() in matched_gt_indices
                    ):
                        continue
                    matched_pred_indices.add(pred_idx.item())
                    matched_gt_indices.add(gt_idx.item())
                    pred_label = pred_labels[pred_idx].item()
                    gt_label = gt_labels[gt_idx].item()
                    pred_cls_idx = class_to_idx[pred_label]
                    gt_cls_idx = class_to_idx[gt_label]
                    # Update confusion matrix
                    conf_matrix[gt_cls_idx, pred_cls_idx] += 1
                    # Update per-class metrics
                    if pred_label == gt_label:
                        metrics_per_class[gt_label]["TPs"] += 1
                        metrics_per_class[gt_label]["IoUs"].append(iou.item())
                    else:
                        # Misclassification
                        metrics_per_class[gt_label]["FNs"] += 1
                        metrics_per_class[pred_label]["FPs"] += 1
                        metrics_per_class[gt_label]["IoUs"].append(0)
                        metrics_per_class[pred_label]["IoUs"].append(0)
            # Unmatched predictions (False Positives)
            unmatched_pred_indices = set(range(n_preds)) - matched_pred_indices
            for pred_idx in unmatched_pred_indices:
                pred_label = pred_labels[pred_idx].item()
                pred_cls_idx = class_to_idx[pred_label]
                # Update confusion matrix: background row
                conf_matrix[n_classes, pred_cls_idx] += 1
                # Update per-class metrics
                metrics_per_class[pred_label]["FPs"] += 1
                metrics_per_class[pred_label]["IoUs"].append(0)
            # Unmatched ground truths (False Negatives)
            unmatched_gt_indices = set(range(n_gts)) - matched_gt_indices
            for gt_idx in unmatched_gt_indices:
                gt_label = gt_labels[gt_idx].item()
                gt_cls_idx = class_to_idx[gt_label]
                # Update confusion matrix: background column
                conf_matrix[gt_cls_idx, n_classes] += 1
                # Update per-class metrics
                metrics_per_class[gt_label]["FNs"] += 1
                metrics_per_class[gt_label]["IoUs"].append(0)
        return metrics_per_class, conf_matrix, class_to_idx
    def save_plots(self, path_to_save) -> None:
        path_to_save = Path(path_to_save)
        path_to_save.mkdir(parents=True, exist_ok=True)
        if self.conf_matrix is not None:
            class_labels = [str(cls_id) for cls_id in self.class_to_idx.keys()] + ["background"]
            plt.figure(figsize=(10, 8))
            plt.imshow(self.conf_matrix, interpolation="nearest", cmap=plt.cm.Blues)
            plt.title("Confusion Matrix")
            plt.colorbar()
            tick_marks = np.arange(len(class_labels))
            plt.xticks(tick_marks, class_labels, rotation=45)
            plt.yticks(tick_marks, class_labels)
            # Add labels to each cell
            thresh = self.conf_matrix.max() / 2.0
            for i in range(self.conf_matrix.shape[0]):
                for j in range(self.conf_matrix.shape[1]):
                    plt.text(
                        j,
                        i,
                        format(self.conf_matrix[i, j], "d"),
                        horizontalalignment="center",
                        color="white" if self.conf_matrix[i, j] > thresh else "black",
                    )
            plt.ylabel("True label")
            plt.xlabel("Predicted label")
            plt.tight_layout()
            plt.savefig(path_to_save / "confusion_matrix.png")
            plt.close()
        thresholds = self.thresholds
        precisions, recalls, f1_scores = [], [], []
        # Store the original predictions to reset after each threshold
        original_preds = copy.deepcopy(self.preds)
        for threshold in thresholds:
            # Filter predictions based on the current threshold
            filtered_preds = filter_preds(copy.deepcopy(original_preds), threshold)
            # Compute metrics with the filtered predictions
            metrics = self._compute_main_metrics(filtered_preds)
            precisions.append(metrics["precision"])
            recalls.append(metrics["recall"])
            f1_scores.append(metrics["f1"])
        # Plot Precision and Recall vs Threshold
        plt.figure()
        plt.plot(thresholds, precisions, label="Precision", marker="o")
        plt.plot(thresholds, recalls, label="Recall", marker="o")
        plt.xlabel("Threshold")
        plt.ylabel("Value")
        plt.title("Precision and Recall vs Threshold")
        plt.legend()
        plt.grid(True)
        plt.savefig(path_to_save / "precision_recall_vs_threshold.png")
        plt.close()
        # Plot F1 Score vs Threshold
        plt.figure()
        plt.plot(thresholds, f1_scores, label="F1 Score", marker="o")
        plt.xlabel("Threshold")
        plt.ylabel("F1 Score")
        plt.title("F1 Score vs Threshold")
        plt.grid(True)
        plt.savefig(path_to_save / "f1_score_vs_threshold.png")
        plt.close()
        # Find the best threshold based on F1 Score (last occurrence)
        best_idx = len(f1_scores) - np.argmax(f1_scores[::-1]) - 1
        best_threshold = thresholds[best_idx]
        best_f1 = f1_scores[best_idx]
        logger.info(
            f"Best Threshold: {round(best_threshold, 2)} with F1 Score: {round(best_f1, 3)}"
        )
def filter_preds(preds, conf_thresh):
    """Keep only predictions whose confidence score is >= conf_thresh (filters in place)."""
    for pred in preds:
        keep_idxs = pred["scores"] >= conf_thresh
        pred["scores"] = pred["scores"][keep_idxs]
        pred["boxes"] = pred["boxes"][keep_idxs]
        pred["labels"] = pred["labels"][keep_idxs]
    return preds
def scale_boxes(boxes, orig_shape, resized_shape):
    """
    boxes in format: [x1, y1, x2, y2], absolute values
    orig_shape: [height, width]
    resized_shape: [height, width]
    """
    scale_x = orig_shape[1] / resized_shape[1]
    scale_y = orig_shape[0] / resized_shape[0]
    boxes[:, 0] *= scale_x
    boxes[:, 2] *= scale_x
    boxes[:, 1] *= scale_y
    boxes[:, 3] *= scale_y
    return boxes
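

# Minimal usage sketch (not part of the original module): it builds dummy ground-truth
# and prediction dicts in the format documented in Validator.__init__, prints the
# aggregate metrics, and rescales an example box. All tensor values and the
# "validation_plots" output directory below are made-up illustrations.
if __name__ == "__main__":
    gt = [
        {"labels": torch.tensor([0]), "boxes": torch.tensor([[561.0, 297.0, 661.0, 359.0]])},
    ]
    preds = [
        {
            "labels": torch.tensor([0]),
            "boxes": torch.tensor([[560.0, 300.0, 660.0, 360.0]]),
            "scores": torch.tensor([0.9]),
        },
    ]
    validator = Validator(gt, preds, conf_thresh=0.5, iou_thresh=0.5)
    print(validator.compute_metrics())  # single well-matched box -> precision, recall, f1 all 1.0
    validator.save_plots("validation_plots")  # hypothetical output directory

    # scale_boxes maps boxes predicted on a resized image back to the original resolution
    boxes = torch.tensor([[50.0, 50.0, 100.0, 100.0]])
    print(scale_boxes(boxes, orig_shape=[720, 1280], resized_shape=[360, 640]))  # doubles the coords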