# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os.path as osp
from typing import List, Union

from mmdet.registry import DATASETS
from .base_video_dataset import BaseVideoDataset


@DATASETS.register_module()
class MOTChallengeDataset(BaseVideoDataset):
    """Video dataset for MOTChallenge annotations.

    Args:
        visibility_thr (float, optional): Annotations whose ``visibility``
            is strictly below this value are skipped when the dataset is
            not in test mode. Defaults to -1.
        *args: Forwarded unchanged to ``BaseVideoDataset``.
        **kwargs: Forwarded unchanged to ``BaseVideoDataset``.
    """

    METAINFO = {
        'classes':
        ('pedestrian', 'person_on_vehicle', 'car', 'bicycle', 'motorbike',
         'non_mot_vehicle', 'static_person', 'distractor', 'occluder',
         'occluder_on_ground', 'occluder_full', 'reflection', 'crowd')
    }

    def __init__(self, visibility_thr: float = -1, *args, **kwargs):
        # The threshold is stored before delegating to the parent
        # constructor. NOTE(review): presumably the parent may already call
        # ``parse_data_info`` while initialising -- confirm before
        # reordering these two statements.
        self.visibility_thr = visibility_thr
        super().__init__(*args, **kwargs)

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Convert one raw image record into the target annotation format.

        Each kept instance additionally carries the ``visibility`` and
        ``mot_conf`` values of its annotation.

        An annotation is skipped when any of the following holds, checked
        in this order:

        - outside test mode, its ``visibility`` is below
          ``self.visibility_thr``;
        - it is flagged ``ignore``;
        - its box has zero overlap with the image area;
        - its ``area`` is not positive, or its width or height is below 1;
        - its ``category_id`` is not contained in ``self.cat_ids``.

        Args:
            raw_data_info (dict): Raw record providing ``raw_img_info`` and
                ``raw_ann_info``.

        Returns:
            Union[dict, List[dict]]: The image information extended with
            ``img_path`` and the list of kept ``instances``.

        Raises:
            AssertionError: If no instance is kept outside test mode.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']

        # Start from a shallow copy of the image-level fields.
        data_info = dict(img_info)

        img_prefix = self.data_prefix.get('img_path')
        if img_prefix is None:
            data_info['img_path'] = img_info['file_name']
        else:
            data_info['img_path'] = osp.join(img_prefix,
                                             img_info['file_name'])

        training = not self.test_mode
        instances = []
        for ann in ann_info:
            # Visibility is only consulted outside test mode.
            if training and ann['visibility'] < self.visibility_thr:
                continue
            if ann.get('ignore', False):
                continue

            # The raw box is (x, y, width, height); measure how much of it
            # lies inside the image bounds.
            x_min, y_min, box_w, box_h = ann['bbox']
            x_max = x_min + box_w
            overlap_w = max(0, min(x_max, img_info['width']) - max(x_min, 0))
            y_max = y_min + box_h
            overlap_h = max(0,
                            min(y_max, img_info['height']) - max(y_min, 0))
            if overlap_w * overlap_h == 0:
                continue
            if ann['area'] <= 0 or box_w < 1 or box_h < 1:
                continue

            category_id = ann['category_id']
            if category_id not in self.cat_ids:
                continue

            instances.append({
                # Crowd annotations are kept but marked as ignored.
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                # Stored as corner coordinates (x1, y1, x2, y2).
                'bbox': [x_min, y_min, x_max, y_max],
                'bbox_label': self.cat2label[category_id],
                'instance_id': ann['instance_id'],
                'category_id': category_id,
                'mot_conf': ann['mot_conf'],
                'visibility': ann['visibility'],
            })

        if training:
            assert instances, (f'No valid instances found in '
                               f'image {data_info["img_path"]}!')

        data_info['instances'] = instances
        return data_info