Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os.path as osp | |
from typing import List, Union | |
from mmdet.registry import DATASETS | |
from .base_video_dataset import BaseVideoDataset | |
class MOTChallengeDataset(BaseVideoDataset):
    """Dataset for MOTChallenge.

    Args:
        visibility_thr (float, optional): The minimum visibility
            for the objects during training. Annotations whose
            ``visibility`` is below this value are dropped when the dataset
            is not in test mode. Default to -1.
        *args: Positional arguments forwarded to ``BaseVideoDataset``.
        **kwargs: Keyword arguments forwarded to ``BaseVideoDataset``.
    """

    METAINFO = {
        'classes':
        ('pedestrian', 'person_on_vehicle', 'car', 'bicycle', 'motorbike',
         'non_mot_vehicle', 'static_person', 'distractor', 'occluder',
         'occluder_on_ground', 'occluder_full', 'reflection', 'crowd')
    }

    def __init__(self, visibility_thr: float = -1, *args, **kwargs):
        # Stored before delegating to the base initializer.
        # NOTE(review): presumably the base class may call
        # ``parse_data_info`` during ``__init__``, which reads this
        # attribute -- confirm before reordering.
        self.visibility_thr = visibility_thr
        super().__init__(*args, **kwargs)

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        The difference between this function and the one in
        ``BaseVideoDataset`` is that the parsing here adds ``visibility``
        and ``mot_conf`` to every instance.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``. Must provide ``raw_img_info`` (image fields,
                including ``file_name``, ``width`` and ``height``) and
                ``raw_ann_info`` (an iterable of annotation dicts).

        Returns:
            Union[dict, List[dict]]: Parsed annotation. This implementation
            returns a single dict holding the image fields, ``img_path``
            and the list of kept ``instances``.

        Raises:
            AssertionError: If the dataset is not in test mode and no
                annotation of the image survives the filtering.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']

        # Start from a shallow copy of the image fields.
        data_info = dict(img_info)

        img_prefix = self.data_prefix.get('img_path', None)
        if img_prefix is not None:
            img_path = osp.join(img_prefix, img_info['file_name'])
        else:
            img_path = img_info['file_name']
        data_info['img_path'] = img_path

        instances = []
        for ann in ann_info:
            # The visibility filter only applies outside of test mode.
            if (not self.test_mode) and (ann['visibility'] <
                                         self.visibility_thr):
                continue
            if ann.get('ignore', False):
                continue

            # ``bbox`` is unpacked as (x, y, width, height).
            x1, y1, w, h = ann['bbox']
            # Skip boxes that do not overlap the image area at all.
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            # Skip degenerate boxes.
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            # ``cat_ids`` / ``cat2label`` are not defined in this class;
            # presumably they are provided by the base class -- confirm.
            if ann['category_id'] not in self.cat_ids:
                continue

            instances.append({
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                # Converted to corner format (x1, y1, x2, y2).
                'bbox': [x1, y1, x1 + w, y1 + h],
                'bbox_label': self.cat2label[ann['category_id']],
                'instance_id': ann['instance_id'],
                'category_id': ann['category_id'],
                'mot_conf': ann['mot_conf'],
                'visibility': ann['visibility'],
            })

        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with ``-O``; the exception type and
        # message are unchanged.
        if not self.test_mode and not instances:
            raise AssertionError('No valid instances found in '
                                 f'image {data_info["img_path"]}!')

        data_info['instances'] = instances
        return data_info