import numpy as np


def cxywh2xywh(cx, cy, w, h):
    """Convert center-based CxCyWH boxes to top-left-based XYWH boxes.

    Args:
        cx, cy: box center coordinates (scalars or NumPy arrays).
        w, h: box width and height.

    Returns:
        Tuple ``(x, y, w, h)`` where ``(x, y)`` is the top-left corner.
    """
    x = cx - w / 2
    y = cy - h / 2
    return x, y, w, h


def cxywh2ltrb(cx, cy, w, h):
    """Convert center-based CxCyWH boxes to corner-based boxes.

    Args:
        cx, cy: box center coordinates (scalars or NumPy arrays).
        w, h: box width and height.

    Returns:
        Tuple ``(left, top, right, bottom)``.
    """
    half_w = w / 2
    half_h = h / 2
    return cx - half_w, cy - half_h, cx + half_w, cy + half_h


def iou(ba, bb):
    """Calculate Intersection-over-Union between corner-format boxes.

    Args:
        ba (tuple): ``(left, top, right, bottom, area)`` of the first box
            (or boxes).
        bb (tuple): ``(left, top, right, bottom, area)`` of the second box
            (or boxes). Elements are combined with NumPy element-wise
            operations, so scalars and arrays may be mixed.

    Returns:
        The IoU value(s) with size-1 dimensions squeezed away; a single
        pair therefore yields a 0-d result.
    """
    a_l, a_t, a_r, a_b, area_a = ba
    b_l, b_t, b_r, b_b, area_b = bb
    # Corners of the intersection rectangle.
    x1 = np.maximum(a_l, b_l)
    y1 = np.maximum(a_t, b_t)
    x2 = np.minimum(a_r, b_r)
    y2 = np.minimum(a_b, b_b)
    # Clamp at zero so disjoint boxes contribute no intersection area.
    inter_w = np.maximum(0, x2 - x1)
    inter_h = np.maximum(0, y2 - y1)
    intersec = inter_w * inter_h
    ratio = intersec / (area_a + area_b - intersec)
    return ratio.squeeze()


def nms(cx, cy, w, h, s, iou_thresh=0.3):
    """Bounding box Non-maximum Suppression.

    Boxes are visited from highest to lowest score; every remaining box
    whose IoU with the currently selected box exceeds ``iou_thresh`` is
    discarded.

    Args:
        cx, cy, w, h: CxCyWH box coordinates (1-D arrays or sequences).
        s: score of each box, same length as the coordinates.
        iou_thresh (float, optional): IoU threshold. Defaults to 0.3.

    Returns:
        res: list of indexes of the selected boxes, in descending score
        order. Empty when no boxes are given.
    """
    cx, cy, w, h = (np.asarray(v, dtype=float) for v in (cx, cy, w, h))
    s = np.asarray(s)
    left, top, right, bottom = cxywh2ltrb(cx, cy, w, h)
    areas = w * h
    res = []
    order = np.argsort(s, axis=-1)[::-1]
    while order.shape[0] > 0:
        i = order[0]
        res.append(i)
        rest = order[1:]
        overlaps = iou(
            (left[i], top[i], right[i], bottom[i], areas[i]),
            (left[rest], top[rest], right[rest], bottom[rest], areas[rest]),
        )
        # `iou` squeezes its result, so a single remaining candidate yields
        # a 0-d array. `np.where`/`nonzero` on 0-d input is deprecated and
        # rejected by recent NumPy releases; `flatnonzero` ravels first and
        # handles every shape.
        keep = np.flatnonzero(overlaps <= iou_thresh)
        order = rest[keep]
    return res


def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
    """Filter out insignificant boxes.

    A box is kept only when its logit is strictly greater than
    ``class_ratio`` times the best logit of its own label AND strictly
    greater than ``agnostic_ratio`` times the best logit over all labels.

    Args:
        boxes (list of records): returned query to be filtered. Each record
            is indexable with the label at position 5 and the logit at
            position 6.
        agnostic_ratio (float, optional): fraction of the overall best
            logit a box must exceed. Defaults to 0.5.
        class_ratio (float, optional): fraction of the per-label best logit
            a box must exceed. Defaults to 0.7.

    Returns:
        List of the records that passed both tests, in input order. Empty
        when ``boxes`` is empty.
    """
    boxes = list(boxes)
    if not boxes:
        # max() of an empty sequence would raise; nothing to filter.
        return []
    best_per_label = {}
    for b in boxes:
        label, logit = b[5], b[6]
        best_per_label[label] = max(best_per_label.get(label, logit), logit)
    best_overall = max(best_per_label.values())
    return [
        b
        for b in boxes
        if b[6] > class_ratio * best_per_label[b[5]]
        and b[6] > agnostic_ratio * best_overall
    ]


def _records_from_match(match, prompt_labels, is_selected):
    """Turn one column-oriented match dict into a list of per-box records."""
    labels = [prompt_labels[int(idx)] for idx in match["label"]]
    flags = [is_selected] * len(match["box_id"])
    return [
        list(rec)
        for rec in zip(
            match["box_id"],
            match["cx"],
            match["cy"],
            match["w"],
            match["h"],
            labels,
            match["logit"],
            flags,
        )
    ]


def postprocess(
    matches,
    prompt_labels,
    img_matches=None,
    agnostic_ratio=0.4,
    class_ratio=0.7,
    iou_thresh=0.3,
):
    """Merge, filter and de-duplicate the boxes returned for each image.

    For every image in ``matches`` the boxes are turned into records
    ``[box_id, cx, cy, w, h, label_name, logit, is_selected]``, optionally
    extended with the boxes of the same image found in ``img_matches``,
    filtered with ``filter_nonpos`` and reduced with ``nms``.

    Args:
        matches (list of dict): per-image results with the keys ``img_id``,
            ``img_url``, ``img_w``, ``img_h``, ``box_id``, ``cx``, ``cy``,
            ``w``, ``h``, ``label``, ``logit`` (and ``img_score`` when no
            entry of ``img_matches`` covers the image).
        prompt_labels (sequence): label names indexed by ``int(label)``.
        img_matches (list of dict, optional): additional per-image results
            with the same box keys plus ``img_score``. Boxes whose id is
            already present in ``matches`` are skipped; the others are
            added with ``is_selected`` set to 0.
        agnostic_ratio (float, optional): forwarded to ``filter_nonpos``.
        class_ratio (float, optional): forwarded to ``filter_nonpos``.
        iou_thresh (float, optional): IoU threshold forwarded to ``nms``.
            Defaults to 0.3.

    Returns:
        Tuple ``(boxes_w_img, meta)``: ``boxes_w_img`` is a list of
        ``(img_id, img_url, img_w, img_h, img_score, boxes)`` tuples and
        ``meta`` is the list of all box ids seen before filtering.
    """
    meta = []
    boxes_w_img = []
    matches_by_img = {m["img_id"]: m for m in matches}
    img_matches_by_img = (
        {} if img_matches is None else {m["img_id"]: m for m in img_matches}
    )
    for img_id, m in matches_by_img.items():
        boxes = _records_from_match(m, prompt_labels, 1)
        img_m = img_matches_by_img.get(img_id)
        if img_m is not None:
            # Also add the image-level hits that are not among the top-K
            # matches; they are flagged 0 (per the original note, these do
            # not take part in training). The id set is built once instead
            # of being rebuilt for every candidate.
            seen = {b[0] for b in boxes}
            boxes += [
                rec
                for rec in _records_from_match(img_m, prompt_labels, 0)
                if rec[0] not in seen
            ]
        # update record metadata after query
        meta.extend(b[0] for b in boxes)
        # remove some non-significant boxes
        boxes = filter_nonpos(
            boxes, agnostic_ratio=agnostic_ratio, class_ratio=class_ratio
        )
        # doing non-maximum suppression (skipped when nothing survived,
        # since there would be no columns to unpack)
        if boxes:
            cx, cy, w, h, s = (
                np.array(col) for col in zip(*[(*b[1:5], b[6]) for b in boxes])
            )
            boxes = [boxes[i] for i in nms(cx, cy, w, h, s, iou_thresh)]
        # Prefer the image-level score when available; otherwise fall back
        # to the score stored with the match itself so that the image is
        # still reported and img_score is always defined.
        img_score = img_m["img_score"] if img_m is not None else m["img_score"]
        boxes_w_img.append(
            (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes)
        )
    return boxes_w_img, meta