import copy
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from loguru import logger
from torchvision.ops import box_iou


class Validator:
    def __init__(
        self,
        gt: List[Dict[str, torch.Tensor]],
        preds: List[Dict[str, torch.Tensor]],
        conf_thresh=0.5,
        iou_thresh=0.5,
    ) -> None:
        """
        Format example:
        gt = [{'labels': tensor([0]), 'boxes': tensor([[561.0, 297.0, 661.0, 359.0]])}, ...]
        preds follow the same structure plus a 'scores' tensor with per-box confidences
        len(gt) is the number of images
        bboxes are in format [x1, y1, x2, y2], absolute values
        (see the usage sketch at the end of this file)
        """
        self.gt = gt
        self.preds = preds
        self.conf_thresh = conf_thresh
        self.iou_thresh = iou_thresh
        self.thresholds = np.arange(0.2, 1.0, 0.05)
        self.conf_matrix = None

    def compute_metrics(self, extended=False) -> Dict[str, float]:
        filtered_preds = filter_preds(copy.deepcopy(self.preds), self.conf_thresh)
        metrics = self._compute_main_metrics(filtered_preds)
        if not extended:
            metrics.pop("extended_metrics", None)
        return metrics

    def _compute_main_metrics(self, preds):
        (
            self.metrics_per_class,
            self.conf_matrix,
            self.class_to_idx,
        ) = self._compute_metrics_and_confusion_matrix(preds)
        tps, fps, fns = 0, 0, 0
        ious = []
        extended_metrics = {}
        for key, value in self.metrics_per_class.items():
            tps += value["TPs"]
            fps += value["FPs"]
            fns += value["FNs"]
            ious.extend(value["IoUs"])
            extended_metrics[f"precision_{key}"] = (
                value["TPs"] / (value["TPs"] + value["FPs"])
                if value["TPs"] + value["FPs"] > 0
                else 0
            )
            extended_metrics[f"recall_{key}"] = (
                value["TPs"] / (value["TPs"] + value["FNs"])
                if value["TPs"] + value["FNs"] > 0
                else 0
            )
            extended_metrics[f"iou_{key}"] = np.mean(value["IoUs"])
        precision = tps / (tps + fps) if (tps + fps) > 0 else 0
        recall = tps / (tps + fns) if (tps + fns) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        iou = np.mean(ious).item() if ious else 0
        return {
            "f1": f1,
            "precision": precision,
            "recall": recall,
            "iou": iou,
            "TPs": tps,
            "FPs": fps,
            "FNs": fns,
            "extended_metrics": extended_metrics,
        }

    def _compute_matrix_multi_class(self, preds):
        # Per-class greedy matching of predictions to ground truths at the IoU threshold.
        metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
        for pred, gt in zip(preds, self.gt):
            pred_boxes = pred["boxes"]
            pred_labels = pred["labels"]
            gt_boxes = gt["boxes"]
            gt_labels = gt["labels"]
            # Isolate each class
            labels = torch.unique(torch.cat([pred_labels, gt_labels]))
            for label in labels:
                pred_cl_boxes = pred_boxes[pred_labels == label]  # filter by bool mask
                gt_cl_boxes = gt_boxes[gt_labels == label]
                n_preds = len(pred_cl_boxes)
                n_gts = len(gt_cl_boxes)
                if not (n_preds or n_gts):
                    continue
                if not n_preds:
                    metrics_per_class[label.item()]["FNs"] += n_gts
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
                    continue
                if not n_gts:
                    metrics_per_class[label.item()]["FPs"] += n_preds
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
                    continue
                ious = box_iou(pred_cl_boxes, gt_cl_boxes)  # matrix of all pairwise IoUs
                ious_mask = ious >= self.iou_thresh
                # Indices of box pairs with IoU >= threshold
                pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
                if not pred_indices.numel():  # no predictions matched ground truths
                    metrics_per_class[label.item()]["FNs"] += n_gts
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts)
                    metrics_per_class[label.item()]["FPs"] += n_preds
                    metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds)
                    continue
                iou_values = ious[pred_indices, gt_indices]
                # Sort by IoU to match highest scores first
                sorted_indices = torch.argsort(-iou_values)
                pred_indices = pred_indices[sorted_indices]
                gt_indices = gt_indices[sorted_indices]
                iou_values = iou_values[sorted_indices]
                matched_preds = set()
                matched_gts = set()
                for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
                    if gt_idx.item() not in matched_gts and pred_idx.item() not in matched_preds:
                        matched_preds.add(pred_idx.item())
                        matched_gts.add(gt_idx.item())
                        metrics_per_class[label.item()]["TPs"] += 1
                        metrics_per_class[label.item()]["IoUs"].append(iou.item())
                unmatched_preds = set(range(n_preds)) - matched_preds
                unmatched_gts = set(range(n_gts)) - matched_gts
                metrics_per_class[label.item()]["FPs"] += len(unmatched_preds)
                metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_preds))
                metrics_per_class[label.item()]["FNs"] += len(unmatched_gts)
                metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_gts))
        return metrics_per_class

    def _compute_metrics_and_confusion_matrix(self, preds):
        # Initialize per-class metrics
        metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
        # Collect all class IDs
        all_classes = set()
        for pred in preds:
            all_classes.update(pred["labels"].tolist())
        for gt in self.gt:
            all_classes.update(gt["labels"].tolist())
        all_classes = sorted(all_classes)
        class_to_idx = {cls_id: idx for idx, cls_id in enumerate(all_classes)}
        n_classes = len(all_classes)
        conf_matrix = np.zeros((n_classes + 1, n_classes + 1), dtype=int)  # +1 for background class
        for pred, gt in zip(preds, self.gt):
            pred_boxes = pred["boxes"]
            pred_labels = pred["labels"]
            gt_boxes = gt["boxes"]
            gt_labels = gt["labels"]
            n_preds = len(pred_boxes)
            n_gts = len(gt_boxes)
            if n_preds == 0 and n_gts == 0:
                continue
            ious = box_iou(pred_boxes, gt_boxes) if n_preds > 0 and n_gts > 0 else torch.tensor([])
            # Assign matches between preds and gts
            matched_pred_indices = set()
            matched_gt_indices = set()
            if ious.numel() > 0:
                # Greedily match predictions to ground truths with IoU above the threshold
                ious_mask = ious >= self.iou_thresh
                pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
                iou_values = ious[pred_indices, gt_indices]
                # Sort by IoU to match highest scores first
                sorted_indices = torch.argsort(-iou_values)
                pred_indices = pred_indices[sorted_indices]
                gt_indices = gt_indices[sorted_indices]
                iou_values = iou_values[sorted_indices]
                for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values):
                    if (
                        pred_idx.item() in matched_pred_indices
                        or gt_idx.item() in matched_gt_indices
                    ):
                        continue
                    matched_pred_indices.add(pred_idx.item())
                    matched_gt_indices.add(gt_idx.item())
                    pred_label = pred_labels[pred_idx].item()
                    gt_label = gt_labels[gt_idx].item()
                    pred_cls_idx = class_to_idx[pred_label]
                    gt_cls_idx = class_to_idx[gt_label]
                    # Update confusion matrix
                    conf_matrix[gt_cls_idx, pred_cls_idx] += 1
                    # Update per-class metrics
                    if pred_label == gt_label:
                        metrics_per_class[gt_label]["TPs"] += 1
                        metrics_per_class[gt_label]["IoUs"].append(iou.item())
                    else:
                        # Misclassification
                        metrics_per_class[gt_label]["FNs"] += 1
                        metrics_per_class[pred_label]["FPs"] += 1
                        metrics_per_class[gt_label]["IoUs"].append(0)
                        metrics_per_class[pred_label]["IoUs"].append(0)
            # Unmatched predictions (False Positives)
            unmatched_pred_indices = set(range(n_preds)) - matched_pred_indices
            for pred_idx in unmatched_pred_indices:
                pred_label = pred_labels[pred_idx].item()
                pred_cls_idx = class_to_idx[pred_label]
                # Update confusion matrix: background row
                conf_matrix[n_classes, pred_cls_idx] += 1
                # Update per-class metrics
                metrics_per_class[pred_label]["FPs"] += 1
                metrics_per_class[pred_label]["IoUs"].append(0)
            # Unmatched ground truths (False Negatives)
            unmatched_gt_indices = set(range(n_gts)) - matched_gt_indices
            for gt_idx in unmatched_gt_indices:
                gt_label = gt_labels[gt_idx].item()
                gt_cls_idx = class_to_idx[gt_label]
                # Update confusion matrix: background column
                conf_matrix[gt_cls_idx, n_classes] += 1
                # Update per-class metrics
                metrics_per_class[gt_label]["FNs"] += 1
                metrics_per_class[gt_label]["IoUs"].append(0)
        return metrics_per_class, conf_matrix, class_to_idx

    def save_plots(self, path_to_save) -> None:
        path_to_save = Path(path_to_save)
        path_to_save.mkdir(parents=True, exist_ok=True)
        if self.conf_matrix is not None:
            class_labels = [str(cls_id) for cls_id in self.class_to_idx.keys()] + ["background"]
            plt.figure(figsize=(10, 8))
            plt.imshow(self.conf_matrix, interpolation="nearest", cmap=plt.cm.Blues)
            plt.title("Confusion Matrix")
            plt.colorbar()
            tick_marks = np.arange(len(class_labels))
            plt.xticks(tick_marks, class_labels, rotation=45)
            plt.yticks(tick_marks, class_labels)
            # Add counts to each cell
            thresh = self.conf_matrix.max() / 2.0
            for i in range(self.conf_matrix.shape[0]):
                for j in range(self.conf_matrix.shape[1]):
                    plt.text(
                        j,
                        i,
                        format(self.conf_matrix[i, j], "d"),
                        horizontalalignment="center",
                        color="white" if self.conf_matrix[i, j] > thresh else "black",
                    )
            plt.ylabel("True label")
            plt.xlabel("Predicted label")
            plt.tight_layout()
            plt.savefig(path_to_save / "confusion_matrix.png")
            plt.close()
        thresholds = self.thresholds
        precisions, recalls, f1_scores = [], [], []
        # Keep the original predictions so each threshold starts from the full set
        original_preds = copy.deepcopy(self.preds)
        for threshold in thresholds:
            # Filter predictions based on the current threshold
            filtered_preds = filter_preds(copy.deepcopy(original_preds), threshold)
            # Compute metrics with the filtered predictions
            metrics = self._compute_main_metrics(filtered_preds)
            precisions.append(metrics["precision"])
            recalls.append(metrics["recall"])
            f1_scores.append(metrics["f1"])
        # Plot Precision and Recall vs Threshold
        plt.figure()
        plt.plot(thresholds, precisions, label="Precision", marker="o")
        plt.plot(thresholds, recalls, label="Recall", marker="o")
        plt.xlabel("Threshold")
        plt.ylabel("Value")
        plt.title("Precision and Recall vs Threshold")
        plt.legend()
        plt.grid(True)
        plt.savefig(path_to_save / "precision_recall_vs_threshold.png")
        plt.close()
        # Plot F1 Score vs Threshold
        plt.figure()
        plt.plot(thresholds, f1_scores, label="F1 Score", marker="o")
        plt.xlabel("Threshold")
        plt.ylabel("F1 Score")
        plt.title("F1 Score vs Threshold")
        plt.grid(True)
        plt.savefig(path_to_save / "f1_score_vs_threshold.png")
        plt.close()
        # Find the best threshold based on F1 Score (last occurrence of the maximum)
        best_idx = len(f1_scores) - np.argmax(f1_scores[::-1]) - 1
        best_threshold = thresholds[best_idx]
        best_f1 = f1_scores[best_idx]
        logger.info(
            f"Best Threshold: {round(best_threshold, 2)} with F1 Score: {round(best_f1, 3)}"
        )


def filter_preds(preds, conf_thresh):
    # Keep only predictions whose confidence score is at or above the threshold.
    for pred in preds:
        keep_idxs = pred["scores"] >= conf_thresh
        pred["scores"] = pred["scores"][keep_idxs]
        pred["boxes"] = pred["boxes"][keep_idxs]
        pred["labels"] = pred["labels"][keep_idxs]
    return preds


def scale_boxes(boxes, orig_shape, resized_shape):
    """
    boxes in format: [x1, y1, x2, y2], absolute values
    orig_shape: [height, width]
    resized_shape: [height, width]
    """
    scale_x = orig_shape[1] / resized_shape[1]
    scale_y = orig_shape[0] / resized_shape[0]
    boxes[:, 0] *= scale_x
    boxes[:, 2] *= scale_x
    boxes[:, 1] *= scale_y
    boxes[:, 3] *= scale_y
    return boxes
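

# Minimal usage sketch (added illustration, not part of the original module): the tensors,
# scores, shapes, and the "validation_plots" output directory below are made-up example data,
# assuming a single image with one class-0 object.
if __name__ == "__main__":
    example_gt = [
        {"labels": torch.tensor([0]), "boxes": torch.tensor([[561.0, 297.0, 661.0, 359.0]])}
    ]
    example_preds = [
        {
            "labels": torch.tensor([0]),
            "scores": torch.tensor([0.87]),
            "boxes": torch.tensor([[555.0, 300.0, 658.0, 362.0]]),
        }
    ]
    validator = Validator(example_gt, example_preds, conf_thresh=0.5, iou_thresh=0.5)
    logger.info(validator.compute_metrics(extended=True))  # overall + per-class metrics
    validator.save_plots("validation_plots")  # confusion matrix and threshold-sweep plots

    # scale_boxes maps boxes predicted on a resized image back to the original resolution;
    # the shapes here are purely illustrative.
    boxes_on_resized = torch.tensor([[100.0, 50.0, 200.0, 150.0]])
    boxes_on_original = scale_boxes(
        boxes_on_resized, orig_shape=[1080, 1920], resized_shape=[640, 640]
    )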