# object-detection-safari / box_utils.py
# Commit by mpsk: "sql-compute-grad (#2)" (88c0383)
import numpy as np
def cxywh2xywh(cx, cy, w, h):
    """Convert a box from center format to top-left-corner format.

    Args:
        cx, cy: box center (scalars or numpy arrays).
        w, h: box width and height.
    Returns:
        tuple: ``(x, y, w, h)`` where ``(x, y)`` is the top-left corner;
        width and height are passed through unchanged.
    """
    half_w, half_h = w / 2, h / 2
    return cx - half_w, cy - half_h, w, h
def cxywh2ltrb(cx, cy, w, h):
    """Convert a box from center format to edge coordinates.

    Args:
        cx, cy: box center (scalars or numpy arrays).
        w, h: box width and height.
    Returns:
        tuple: ``(left, top, right, bottom)`` -- note the order is
        Left, Top, Right, Bottom.
    """
    half_w = w / 2
    half_h = h / 2
    left, right = cx - half_w, cx + half_w
    top, bottom = cy - half_h, cy + half_h
    return left, top, right, bottom
def iou(ba, bb):
    """Calculate Intersection-Over-Union between boxes given by their edges.

    Each argument is a 5-tuple ``(left, top, right, bottom, area)``. The
    elements may be scalars or numpy arrays; ``ba`` and ``bb`` are broadcast
    against each other (e.g. one box vs. an array of boxes, as used by
    ``nms``). The area is supplied by the caller rather than recomputed here.

    Args:
        ba (tuple): ``(l, t, r, b, area)`` of the first box(es).
        bb (tuple): ``(l, t, r, b, area)`` of the second box(es).
    Returns:
        The broadcast IoU values, squeezed. Pairs whose union area is not
        positive (e.g. two zero-area boxes) get an IoU of 0 instead of NaN.
    """
    a_l, a_t, a_r, a_b, sa = ba
    b_l, b_t, b_r, b_b, sb = bb
    # Overlap rectangle; width/height clamp to 0 when the boxes are disjoint.
    inter_w = np.maximum(0, np.minimum(a_r, b_r) - np.maximum(a_l, b_l))
    inter_h = np.maximum(0, np.minimum(a_b, b_b) - np.maximum(a_t, b_t))
    intersec = inter_w * inter_h
    union = np.asarray(sa + sb - intersec)
    # Division by a zero union would produce NaN (and a RuntimeWarning);
    # silence it and substitute 0 for those pairs.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.where(union > 0, intersec / union, 0.0)
    return ratio.squeeze()
def nms(cx, cy, w, h, s, iou_thresh=0.3):
    """Greedy bounding-box Non-maximum Suppression.

    Boxes are visited in descending score order. Each visited box is kept,
    and every remaining box whose IoU with it is greater than ``iou_thresh``
    is discarded.

    Args:
        cx, cy, w, h, s: numpy arrays with one entry per box -- center
            (CxCy), size (WH) and score.
        iou_thresh (float, optional): IoU threshold; boxes with IoU strictly
            above it are suppressed. Defaults to 0.3.
    Returns:
        res (list): indexes of the selected boxes, highest score first.
    """
    l, t, r, b = cxywh2ltrb(cx, cy, w, h)
    areas = w * h
    res = []
    order = np.argsort(s, axis=-1)[::-1]
    while order.shape[0] > 0:
        i = order[0]
        res.append(i)
        rest = order[1:]
        overlaps = iou(
            (l[i], t[i], r[i], b[i], areas[i]),
            (l[rest], t[rest], r[rest], b[rest], areas[rest]),
        )
        # iou() squeezes its result, so a single remaining box yields a 0-d
        # array; np.where/nonzero on 0-d input is deprecated in NumPy, so
        # restore the 1-D shape before selecting survivors.
        keep = np.flatnonzero(np.atleast_1d(overlaps) <= iou_thresh)
        order = rest[keep]
    return res
def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
    """Filter out insignificant boxes.

    A box is kept only when its logit is greater than both
    ``class_ratio`` times the best logit of its own label and
    ``agnostic_ratio`` times the best logit over all labels.

    Args:
        boxes (list of records): returned query to be filtered. Each record
            is laid out as ``[box_id, cx, cy, w, h, label, logit, is_selected, ...]``;
            only the label (index 5) and logit (index 6) are read.
        agnostic_ratio (float): fraction of the overall best logit a box must exceed.
        class_ratio (float): fraction of its label's best logit a box must exceed.
    Returns:
        list: the kept records (same objects, original order). An empty
        input yields an empty list.
    """
    if not boxes:
        # Nothing to rank; max() over an empty collection would raise below.
        return []
    best_per_label = {}
    for b in boxes:
        label, logit = b[5], b[6]
        if label not in best_per_label or logit > best_per_label[label]:
            best_per_label[label] = logit
    best_overall = max(best_per_label.values())
    return [
        b
        for b in boxes
        if b[6] > class_ratio * best_per_label[b[5]]
        and b[6] > agnostic_ratio * best_overall
    ]
def _box_records(match, prompt_labels, is_selected):
    """Convert one column-oriented match dict into a list of per-box records.

    Each record is ``[box_id, cx, cy, w, h, label, logit, is_selected]`` where
    ``label`` is ``prompt_labels[int(raw_label)]`` and ``is_selected`` is the
    same flag for every record built from this match.
    """
    labels = [prompt_labels[int(lab)] for lab in match["label"]]
    return [
        list(rec)
        for rec in zip(
            match["box_id"],
            match["cx"],
            match["cy"],
            match["w"],
            match["h"],
            labels,
            match["logit"],
            [is_selected] * len(match["box_id"]),
        )
    ]


def postprocess(
    matches,
    prompt_labels,
    img_matches=None,
    agnostic_ratio=0.4,
    class_ratio=0.7,
    iou_thresh=0.3,
):
    """Assemble, filter and de-duplicate query hits per image.

    Args:
        matches: iterable of dicts, one per image, keyed by ``img_id``,
            ``img_url``, ``img_w``, ``img_h`` plus the parallel box columns
            ``box_id``, ``cx``, ``cy``, ``w``, ``h``, ``label``, ``logit``.
        prompt_labels: sequence mapping ``int(label)`` to a label name.
        img_matches: optional iterable of dicts with the same box columns and
            an ``img_score``; supplies extra (non-top-K) boxes per image.
        agnostic_ratio, class_ratio: forwarded to ``filter_nonpos``.
        iou_thresh: IoU threshold forwarded to ``nms``. Defaults to 0.3.
    Returns:
        tuple: ``(boxes_w_img, meta)``. ``boxes_w_img`` is a list of
        ``(img_id, img_url, img_w, img_h, img_score, boxes)`` tuples;
        ``meta`` lists the box id of every candidate box before filtering.
    """
    meta = []
    boxes_w_img = []
    matches_ = {m["img_id"]: m for m in matches}
    img_matches_ = (
        {} if img_matches is None else {m["img_id"]: m for m in img_matches}
    )
    for img_id, m in matches_.items():
        # Boxes returned by the main query are flagged as selected (1).
        boxes = _box_records(m, prompt_labels, 1)
        img_m = img_matches_.get(img_id)
        if img_m is not None:
            # Also include non-top-K hits for this image that the main query
            # did not return; they are flagged 0 (presumably so they do not
            # participate in training -- confirm with the is_selected consumer).
            seen = {b[0] for b in boxes}
            boxes += [
                b
                for b in _box_records(img_m, prompt_labels, 0)
                if b[0] not in seen
            ]
        # Record every candidate box id before filtering.
        meta.extend(b[0] for b in boxes)
        # Remove non-significant boxes; skip when empty to avoid max([]).
        if boxes:
            boxes = filter_nonpos(
                boxes, agnostic_ratio=agnostic_ratio, class_ratio=class_ratio
            )
        # Non-maximum suppression on (cx, cy, w, h, logit) columns.
        if boxes:
            cx, cy, w, h, s = (
                np.array(col) for col in zip(*((*b[1:5], b[6]) for b in boxes))
            )
            boxes = [boxes[i] for i in nms(cx, cy, w, h, s, iou_thresh)]
        # NOTE(review): an image is only emitted when it also appears in
        # img_matches, so with img_matches=None boxes_w_img is always empty.
        # Confirm this is intended.
        if img_m is not None:
            boxes_w_img.append(
                (
                    m["img_id"],
                    m["img_url"],
                    m["img_w"],
                    m["img_h"],
                    img_m["img_score"],
                    boxes,
                )
            )
    return boxes_w_img, meta