import cv2 import numpy as np import torch from mmdet.structures import DetDataSample from mmengine.structures import InstanceData class FilterSegMask: def __init__(self): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Removes smaller masks that are contained in a bigger mask # @timer_func def remove_overlapping_masks(self, predicted_mask, method="mask", containments_threshold=0.5): # Convert masks to binary images masks = [mask.cpu().numpy() for mask in predicted_mask.pred_instances.masks] masks_binary = [(mask > 0).astype(np.uint8) for mask in masks] masks_tensor = predicted_mask.pred_instances.masks masks_tensor = [mask.to(self.device) for mask in masks_tensor] # Compute bounding boxes and areas boxes = [cv2.boundingRect(mask) for mask in masks_binary] # Compute pairwise containment containments = np.zeros((len(masks), len(masks))) for i in range(len(masks)): box_i = boxes[i] for j in range(i + 1, len(masks)): box_j = boxes[j] if method == "mask": containment = self._calculate_containment_mask(masks_tensor[i], masks_tensor[j]) containments[i, j] = containment containment = self._calculate_containment_mask(masks_tensor[j], masks_tensor[i]) containments[j, i] = containment elif method == "bbox": containment = self._calculate_containment_bbox(box_i, box_j) containments[i, j] = containment containment = self._calculate_containment_bbox(box_j, box_i) containments[j, i] = containment # Keep only the biggest masks for overlapping pairs keep_mask = np.ones(len(masks), dtype=np.bool_) for i in range(len(masks)): if not keep_mask[i]: continue if np.any(containments[i] > containments_threshold): contained_indices = np.where(containments[i] > containments_threshold)[0] for j in contained_indices: if np.count_nonzero(masks_binary[i]) >= np.count_nonzero(masks_binary[j]): keep_mask[j] = False else: keep_mask[i] = False # Create a new DetDataSample with only selected instances filtered_result = DetDataSample(metainfo=predicted_mask.metainfo) pred_instances = InstanceData(metainfo=predicted_mask.metainfo) masks = [mask for i, mask in enumerate(masks) if keep_mask[i]] list_of_tensor_masks = [torch.from_numpy(mask) for mask in masks] stacked_masks = torch.stack(list_of_tensor_masks) updated_filtered_result = self._stacked_masks_update_data_sample( filtered_result, stacked_masks, pred_instances, keep_mask, predicted_mask ) return updated_filtered_result def _stacked_masks_update_data_sample(self, filtered_result, stacked_masks, pred_instances, keep_mask, result): pred_instances.masks = stacked_masks pred_instances.bboxes = self._update_datasample_cat(result.pred_instances.bboxes.tolist(), keep_mask) pred_instances.scores = self._update_datasample_cat(result.pred_instances.scores.tolist(), keep_mask) pred_instances.kernels = self._update_datasample_cat(result.pred_instances.kernels.tolist(), keep_mask) pred_instances.labels = self._update_datasample_cat(result.pred_instances.labels.tolist(), keep_mask) pred_instances.priors = self._update_datasample_cat(result.pred_instances.priors.tolist(), keep_mask) filtered_result.pred_instances = pred_instances return filtered_result def _calculate_containment_bbox(self, box_a, box_b): xA = max(box_a[0], box_b[0]) # max x0 yA = max(box_a[1], box_b[1]) # max y0 xB = min(box_a[0] + box_a[2], box_b[0] + box_b[2]) # min x1 yB = min(box_a[1] + box_a[3], box_b[1] + box_b[3]) # min y1 box_a_area = box_a[2] * box_a[3] box_b_area = box_b[2] * box_b[3] intersection_area = max(0, xB - xA + 1) * max(0, yB - yA + 1) containment = intersection_area / box_a_area if box_a_area > 0 else 0 return containment def _calculate_containment_mask(self, mask_a, mask_b): intersection = torch.logical_and(mask_a, mask_b).sum().float() containment = intersection / mask_b.sum().float() if mask_b.sum() > 0 else 0 return containment def _update_datasample_cat(self, cat_list, keep_mask): cat_keep = [cat for i, cat in enumerate(cat_list) if keep_mask[i]] tensor_cat_keep = torch.tensor(cat_keep) return tensor_cat_keep # @timer_func def filter_on_pred_threshold(self, result_pred, pred_score_threshold=0.5): id_list = [] for id, pred_score in enumerate(result_pred.pred_instances.scores): if pred_score > pred_score_threshold: id_list.append(id) # Create a new DetDataSample with only selected instances new_filtered_result = DetDataSample(metainfo=result_pred.metainfo) new_pred_instances = InstanceData(metainfo=result_pred.metainfo) new_pred_instances.masks = result_pred.pred_instances.masks[id_list] new_pred_instances.bboxes = result_pred.pred_instances.bboxes[id_list] new_pred_instances.scores = result_pred.pred_instances.scores[id_list] new_pred_instances.kernels = result_pred.pred_instances.kernels[id_list] new_pred_instances.labels = result_pred.pred_instances.labels[id_list] new_pred_instances.priors = result_pred.pred_instances.priors[id_list] new_filtered_result.pred_instances = new_pred_instances return new_filtered_result if __name__ == "__main__": pass