from typing import List, Tuple, Union

import cv2
import numpy as np
from numpy import ndarray

MAJOR, MINOR = map(int, cv2.__version__.split('.')[:2])
# The version dispatch in non_max_suppression below only handles OpenCV 4.x.
assert MAJOR == 4


def non_max_suppression(boxes: Union[List[ndarray], Tuple[ndarray]],
                        scores: Union[List[float], Tuple[float]],
                        labels: Union[List[int], Tuple[int]],
                        conf_thres: float = 0.25,
                        iou_thres: float = 0.65) -> Tuple[List, List, List]:
    """Run non-maximum suppression over candidate detections using OpenCV.

    Args:
        boxes: Candidate boxes, each ``[x0, y0, w, h]`` (top-left corner plus
            width and height). The caller's boxes are not modified.
        scores: Confidence score of each box, aligned with ``boxes``.
        labels: Class label of each box, aligned with ``boxes``.
        conf_thres: Score threshold passed to OpenCV; boxes scoring below it
            are filtered out.
        iou_thres: IoU threshold passed to OpenCV for suppressing overlaps.

    Returns:
        ``(boxes, scores, labels)`` lists for the kept detections, in the
        order of the indices returned by OpenCV. Each kept box is a new
        ndarray in ``[x0, y0, x1, y1]`` form. All three lists are empty when
        ``boxes`` is empty or nothing survives.

    Note:
        Only on OpenCV >= 4.7 (``cv2.dnn.NMSBoxesBatched``) are ``labels``
        used, so suppression happens per class. On older 4.x releases
        ``cv2.dnn.NMSBoxes`` is used and suppression ignores ``labels``.
    """
    if len(boxes) == 0:
        return [], [], []

    if MINOR >= 7:
        indices = cv2.dnn.NMSBoxesBatched(boxes, scores, labels,
                                          conf_thres, iou_thres)
    else:
        indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)
    # Depending on the OpenCV version the result is a flat array, an (N, 1)
    # array, or an empty tuple when nothing survives (on which `.flatten()`
    # would raise). Normalise to a flat list of Python ints.
    keep = np.asarray(indices, dtype=np.int64).reshape(-1).tolist()

    nmsd_boxes = []
    nmsd_scores = []
    nmsd_labels = []
    for idx in keep:
        # np.array copies, so the caller's box is not rewritten in place.
        box = np.array(boxes[idx])
        # x0y0wh -> x0y0x1y1
        box[2:] = box[:2] + box[2:]
        nmsd_boxes.append(box)
        nmsd_scores.append(scores[idx])
        nmsd_labels.append(labels[idx])
    return nmsd_boxes, nmsd_scores, nmsd_labels