File size: 5,253 Bytes
2de1f98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import Literal, Union

def process_mmdet_results(mmdet_results: list,
                          cat_id: int = 0,
                          multi_person: bool = True) -> list:
    """Select one category from single-frame mmdet results, sort by area.

    Args:
        mmdet_results (list):
            Per-category detection results for ONE frame, indexed by
            category id. Shape of the nested lists is
            (n_category, n_human, 5) and each bbox is
            (x1, y1, x2, y2, score).
            NOTE(review): only one frame is handled here; the body
            indexes ``mmdet_results[cat_id]`` directly, so batched
            results must be split per frame by the caller.
        cat_id (int, optional):
            Category ID. Only bboxes of this category are kept; the
            others are dropped. Defaults to 0, ID of human category.
        multi_person (bool, optional):
            Whether to keep every detected person, sorted by bbox area
            in descending order. If False, only the person with the
            biggest bbox is kept, which is cheaper because the rest of
            the list does not need a full sort.
            Defaults to True.

    Returns:
        list:
            A single-element list wrapping the sorted bboxes of the
            selected category, i.e. shape (1, n_human, 5), or
            (1, <=1, 5) when ``multi_person`` is False.
            Each bbox is (x1, y1, x2, y2, score).
    """
    only_max_arg = not multi_person
    cat_bboxes = mmdet_results[cat_id]
    # With only_max_arg the sort only guarantees that the largest bbox is
    # first, which is all we keep below.
    sorted_bbox = qsort_bbox_list(cat_bboxes, only_max_arg)
    if only_max_arg:
        sorted_bbox = sorted_bbox[0:1]
    # The result is wrapped in an outer list (presumably a frame axis
    # expected by callers — verify against call sites).
    return [sorted_bbox]


def qsort_bbox_list(bbox_list: list,
                    only_max: bool = False,
                    bbox_convention: Literal['xyxy', 'xywh'] = 'xyxy'):
    """Sort a list of bboxes by their area in pixels (W*H), descending.

    This is a quicksort that partitions around the middle element. Bboxes
    whose area ties with the pivot are placed before the pivot.

    Args:
        bbox_list (list):
            A list of bboxes. The first four items of each bbox are read
            according to ``bbox_convention``. Extra trailing items (e.g. a
            score) are carried along untouched.
        only_max (bool, optional):
            If True, only guarantee that the largest bbox is at the first
            place; the others may not be fully sorted.
            If False, return a fully sorted descending list.
            Defaults to False.
        bbox_convention (str, optional):
            Bbox type, 'xyxy' or 'xywh'. Defaults to 'xyxy'.

    Returns:
        list:
            The bboxes in descending order of area. Only the first element
            is guaranteed when ``only_max`` is True. An input of length 0
            or 1 is returned as-is (same object).

    Raises:
        TypeError: From ``get_area_of_bbox`` for an unknown
            ``bbox_convention``. This is only reached when ``bbox_list``
            has at least two items.
    """
    if len(bbox_list) <= 1:
        return bbox_list

    anchor_index = len(bbox_list) // 2
    anchor_bbox = bbox_list[anchor_index]
    anchor_area = get_area_of_bbox(anchor_bbox, bbox_convention)

    bigger_list = []
    less_list = []
    for index, bbox in enumerate(bbox_list):
        if index == anchor_index:
            continue
        if get_area_of_bbox(bbox, bbox_convention) >= anchor_area:
            bigger_list.append(bbox)
        else:
            less_list.append(bbox)

    # Both options must be forwarded. Previously the recursive calls fell
    # back to the 'xyxy' default, which mis-measured 'xywh' bboxes below
    # the top level.
    sorted_bigger = qsort_bbox_list(bigger_list, only_max, bbox_convention)
    if only_max:
        # The maximum is either in bigger_list or is the anchor itself,
        # so the smaller partition never needs to be sorted.
        return sorted_bigger + [anchor_bbox] + less_list
    sorted_less = qsort_bbox_list(less_list, only_max, bbox_convention)
    return sorted_bigger + [anchor_bbox] + sorted_less

def get_area_of_bbox(
        bbox: Union[list, tuple],
        bbox_convention: Literal['xyxy', 'xywh'] = 'xyxy') -> float:
    """Get the area of a bbox given in 'xyxy' or 'xywh' convention.

    Args:
        bbox (Union[list, tuple]):
            Bbox coordinates. For 'xyxy' the first four items are
            [x1, y1, x2, y2]; for 'xywh' they are [x, y, w, h].
            Any extra trailing items (e.g. a score) are ignored.
        bbox_convention (str, optional):
            Bbox type, 'xyxy' or 'xywh'. Defaults to 'xyxy'.

    Returns:
        float:
            Area of the bbox: |x2-x1|*|y2-y1| for 'xyxy', and |w*h| for
            'xywh'. The result keeps the numeric type of the inputs
            (e.g. int for int coordinates).

    Raises:
        TypeError: If ``bbox_convention`` is neither 'xyxy' nor 'xywh'.
    """
    if bbox_convention == 'xyxy':
        return abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
    if bbox_convention == 'xywh':
        return abs(bbox[2] * bbox[3])
    # TypeError (rather than ValueError) is kept for compatibility with
    # callers that may already catch it.
    raise TypeError(f'Wrong bbox convention: {bbox_convention}')

def calculate_iou(bbox1, bbox2):
    """Compute the Intersection over Union (IoU) of two xyxy bboxes.

    Coordinates are treated as inclusive pixel indices: a box spanning
    x1..x2 has width ``x2 - x1 + 1``, and the same holds for the height.

    Args:
        bbox1: Sequence whose first four items are (x1, y1, x2, y2).
            Extra items (e.g. a score) are ignored.
        bbox2: Same layout as ``bbox1``.

    Returns:
        float: Intersection area divided by union area. Returns 0.0 when
        the union area is not positive (degenerate boxes) instead of
        raising ZeroDivisionError.
    """
    inter_x1 = max(bbox1[0], bbox2[0])
    inter_y1 = max(bbox1[1], bbox2[1])
    inter_x2 = min(bbox1[2], bbox2[2])
    inter_y2 = min(bbox1[3], bbox2[3])

    intersection_area = (max(0, inter_x2 - inter_x1 + 1)
                         * max(0, inter_y2 - inter_y1 + 1))

    bbox1_area = (bbox1[2] - bbox1[0] + 1) * (bbox1[3] - bbox1[1] + 1)
    bbox2_area = (bbox2[2] - bbox2[0] + 1) * (bbox2[3] - bbox2[1] + 1)

    union_area = bbox1_area + bbox2_area - intersection_area
    # A non-positive union only occurs when the intersection is empty
    # (e.g. zero-width boxes such as x2 == x1 - 1). Treat it as no overlap.
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area


def non_max_suppression(bboxes, iou_threshold):
    """Greedy non-maximum suppression over scored bboxes.

    Boxes are visited in descending order of ``box[4]`` (the score). Each
    visited box is kept, and every remaining box whose IoU with it is at
    or above ``iou_threshold`` is discarded.

    Args:
        bboxes: Iterable of bboxes laid out as (x1, y1, x2, y2, score, ...).
            It is not modified.
        iou_threshold: Boxes overlapping a kept box with IoU >= this value
            are suppressed.

    Returns:
        list: The kept bboxes, highest score first.
    """
    # Highest-scoring candidates are considered first. ``sorted`` builds a
    # new list, so the caller's sequence is left untouched.
    pending = sorted(bboxes, key=lambda box: box[4], reverse=True)

    kept = []
    while pending:
        best, *rest = pending
        kept.append(best)
        # Only boxes that overlap the newly kept one less than the
        # threshold survive to the next round.
        pending = [box for box in rest
                   if calculate_iou(best, box) < iou_threshold]
    return kept