File size: 4,873 Bytes
3f1124e
 
 
 
88c0383
3f1124e
 
 
 
 
 
88c0383
3f1124e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88c0383
 
 
 
 
 
 
 
 
 
3f1124e
 
 
 
 
 
 
 
 
 
 
 
 
88c0383
 
3f1124e
 
 
 
 
 
88c0383
 
3f1124e
 
 
 
88c0383
3f1124e
 
88c0383
3f1124e
88c0383
3f1124e
 
 
88c0383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1a5f04
3f1124e
 
88c0383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1a5f04
 
3f1124e
 
 
 
 
88c0383
3f1124e
 
88c0383
 
 
3f1124e
 
f1a5f04
88c0383
 
 
f1a5f04
88c0383
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import numpy as np


def cxywh2xywh(cx, cy, w, h):
    """Convert a box from centre/size (CxCyWH) to corner/size (XYWH) form.

    Args:
        cx, cy: coordinates of the box centre (scalars or NumPy arrays).
        w, h: box width and height; passed through unchanged.

    Returns:
        tuple: ``(x, y, w, h)`` where ``(x, y) = (cx - w / 2, cy - h / 2)``
        is the minimum corner of the box.
    """
    half_w, half_h = w / 2, h / 2
    return cx - half_w, cy - half_h, w, h


def cxywh2ltrb(cx, cy, w, h):
    """Convert boxes from centre/size (CxCyWH) to corner (LTRB) form.

    Note the output order is (left, top, right, bottom), not
    left/right/top/bottom.

    Args:
        cx, cy: coordinates of the box centre (scalars or NumPy arrays).
        w, h: box width and height.

    Returns:
        tuple: ``(l, t, r, b)`` with ``l = cx - w / 2``, ``t = cy - h / 2``,
        ``r = cx + w / 2`` and ``b = cy + h / 2``.
    """
    half_w = w / 2
    half_h = h / 2
    return cx - half_w, cy - half_h, cx + half_w, cy + half_h


def iou(ba, bb):
    """Calculate Intersection-over-Union between box ``ba`` and box(es) ``bb``.

    Args:
        ba (tuple): ``(left, top, right, bottom, area)`` of the first box(es).
        bb (tuple): ``(left, top, right, bottom, area)`` of the second box(es).
            Entries may be scalars or NumPy arrays; they are combined with
            NumPy broadcasting, so ``ba`` can be one box and ``bb`` a batch.
            The ``area`` entries are used as given (not recomputed from the
            corners), so they must be consistent with the corners.

    Returns:
        numpy.ndarray: IoU values with the broadcast shape of the inputs,
        squeezed (one pair of boxes yields a 0-d array). Pairs whose union
        is zero (e.g. two zero-area boxes) get IoU 0 rather than NaN.
    """
    a_l, a_t, a_r, a_b, area_a = ba
    b_l, b_t, b_r, b_b, area_b = bb

    # Corners of the intersection rectangle; a negative extent means no overlap.
    x1 = np.maximum(a_l, b_l)
    y1 = np.maximum(a_t, b_t)
    x2 = np.minimum(a_r, b_r)
    y2 = np.minimum(a_b, b_b)
    inter = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)

    union = area_a + area_b - inter
    # Guard 0/0 (degenerate boxes): silence the warning and report 0 overlap.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.true_divide(inter, union)
    ratio = np.where(union > 0, ratio, 0.0)
    return ratio.squeeze()


def nms(cx, cy, w, h, s, iou_thresh=0.3):
    """Greedy bounding-box Non-Maximum Suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    remaining box whose IoU with it is greater than ``iou_thresh``.

    Args:
        cx, cy, w, h: box centres and sizes in CxCyWH format (indexed with
            integer arrays below, so presumably 1-D NumPy arrays of equal
            length).
        s: scores, one per box.
        iou_thresh (float, optional): IoU threshold. Boxes with IoU
            ``<= iou_thresh`` against a kept box survive. Defaults to 0.3.

    Returns:
        list: indexes of the kept boxes, in descending score order. Empty
        input yields an empty list.
    """
    l, t, r, b = cxywh2ltrb(cx, cy, w, h)
    areas = w * h
    keep = []
    order = np.argsort(s, axis=-1)[::-1]  # indexes by descending score
    while order.shape[0] > 0:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        # iou() squeezes its result, so a single remaining candidate comes
        # back as a 0-d array; force 1-D so the boolean mask below (and any
        # nonzero/where) never sees 0-d input.
        overlaps = np.atleast_1d(
            iou(
                (l[best], t[best], r[best], b[best], areas[best]),
                (l[rest], t[rest], r[rest], b[rest], areas[rest]),
            )
        )
        order = rest[overlaps <= iou_thresh]
    return keep


def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
    """Filter out insignificant boxes, judged by their logit.

    A record is kept only if its logit is strictly greater than both
    ``class_ratio`` times the best logit among records with the same label,
    and ``agnostic_ratio`` times the best logit over all records.

    Args:
        boxes (list of records): returned query to be filtered. Each record
            is indexable as ``(box_id, cx, cy, w, h, label, logit,
            is_selected, ...)``; only ``label`` (index 5) and ``logit``
            (index 6) are read.
        agnostic_ratio (float): fraction of the overall best logit a record
            must exceed. Defaults to 0.5.
        class_ratio (float): fraction of the best logit of the record's own
            label that it must exceed. Defaults to 0.7.

    Returns:
        list: the surviving records (same objects), in input order. An
        empty input yields an empty list.
    """
    if not boxes:
        return []

    # Best logit per label.
    best_per_label = {}
    for record in boxes:
        label, logit = record[5], record[6]
        if label not in best_per_label or logit > best_per_label[label]:
            best_per_label[label] = logit
    best_overall = max(best_per_label.values())

    return [
        record
        for record in boxes
        if record[6] > class_ratio * best_per_label[record[5]]
        and record[6] > agnostic_ratio * best_overall
    ]


def _match_to_records(match, prompt_labels, is_selected):
    """Flatten one column-oriented match dict into a list of per-box records.

    Each record is ``[box_id, cx, cy, w, h, label_text, logit, is_selected]``
    where ``label_text = prompt_labels[int(label)]`` and ``is_selected`` is
    the constant flag passed in.
    """
    labels = [prompt_labels[int(lbl)] for lbl in match["label"]]
    flags = [is_selected] * len(match["box_id"])
    return [
        list(rec)
        for rec in zip(
            match["box_id"],
            match["cx"],
            match["cy"],
            match["w"],
            match["h"],
            labels,
            match["logit"],
            flags,
        )
    ]


def postprocess(
    matches,
    prompt_labels,
    img_matches=None,
    agnostic_ratio=0.4,
    class_ratio=0.7,
    iou_thresh=0.3,
):
    """Turn per-image box query results into filtered, de-duplicated boxes.

    Args:
        matches: iterable of dicts, one per image, keyed by ``img_id``. Read
            keys: column lists ``box_id``, ``cx``, ``cy``, ``w``, ``h``,
            ``label`` (integer index into ``prompt_labels``) and ``logit``;
            scalars ``img_url``, ``img_w``, ``img_h``; and ``img_score``
            when ``img_matches`` is None. Its boxes get ``is_selected = 1``.
        prompt_labels: sequence mapping integer label ids to label text.
        img_matches: optional iterable of dicts with the same column layout
            plus ``img_score``. For an image present here, its boxes whose
            ``box_id`` is not already among that image's ``matches`` boxes
            are added with ``is_selected = 0``.
        agnostic_ratio, class_ratio: thresholds forwarded to
            ``filter_nonpos``.
        iou_thresh: IoU threshold forwarded to ``nms``. Defaults to 0.3.

    Returns:
        tuple ``(boxes_w_img, meta)``:
          * ``boxes_w_img``: list of
            ``(img_id, img_url, img_w, img_h, img_score, boxes)`` where
            ``boxes`` are the records surviving filtering and NMS (possibly
            empty). When ``img_matches`` is given, images missing from it
            are skipped. When it is None, every image is emitted using
            ``matches``' own ``img_score``.
          * ``meta``: box_id of every candidate box, collected before
            filtering.
    """
    meta = []
    boxes_w_img = []
    matches_ = {m["img_id"]: m for m in matches}
    img_matches_ = (
        {m["img_id"]: m for m in img_matches} if img_matches is not None else {}
    )
    for img_id, m in matches_.items():
        # Top-K hits for this image.
        boxes = _match_to_records(m, prompt_labels, 1)
        img_m = img_matches_.get(img_id)
        if img_m is not None:
            # Also add the non-Top-K hits; they are flagged is_selected = 0
            # because they are not meant to participate in training. Ids
            # already among the Top-K boxes are skipped (set: O(1) lookups).
            topk_ids = {b[0] for b in boxes}
            boxes += [
                b
                for b in _match_to_records(img_m, prompt_labels, 0)
                if b[0] not in topk_ids
            ]
        # Update record metadata with every candidate, before filtering.
        meta.extend(b[0] for b in boxes)

        if img_m is not None:
            img_score = img_m["img_score"]
        elif img_matches is None:
            # No image-level matches: the score comes from `matches` itself.
            img_score = m["img_score"]
        else:
            # img_matches was supplied but has no entry for this image.
            continue

        # Remove non-significant boxes, then de-duplicate overlapping ones.
        # An image with no (remaining) boxes keeps an empty box list.
        if boxes:
            boxes = filter_nonpos(
                boxes, agnostic_ratio=agnostic_ratio, class_ratio=class_ratio
            )
        if boxes:
            cx, cy, w, h, s = [
                np.array(col) for col in zip(*[(*b[1:5], b[6]) for b in boxes])
            ]
            boxes = [boxes[i] for i in nms(cx, cy, w, h, s, iou_thresh)]

        boxes_w_img.append(
            (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes)
        )
    return boxes_w_img, meta