Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| import os.path as osp | |
| from collections import defaultdict | |
| from typing import Any, List, Tuple | |
| import mmengine.fileio as fileio | |
| from mmengine.dataset import BaseDataset | |
| from mmengine.logging import print_log | |
| from mmdet.datasets.api_wrappers import COCO | |
| from mmdet.registry import DATASETS | |
class BaseVideoDataset(BaseDataset):
    """Base video dataset for VID, MOT and VIS tasks.

    Annotations are read through the COCO API from ``self.ann_file`` and the
    images are grouped by their ``video_id``. Every element of the data list
    is therefore a *video* dict with the keys ``video_id``, ``images`` (a
    list of parsed per-frame dicts) and ``video_length``.

    Args:
        backend_args (dict, optional): Arguments used to instantiate the
            file backend that fetches ``ann_file``. Defaults to None.
    """

    META = dict(classes=None)
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True

    def __init__(self, *args, backend_args: dict = None, **kwargs):
        # Must be set before ``super().__init__`` so it is already available
        # should the base class trigger annotation loading.
        self.backend_args = backend_args
        super().__init__(*args, **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``.

        Besides building the returned list, this method sets ``cat_ids``,
        ``cat2label``, ``cat_img_map`` and ``img_ids_with_ann`` on the
        instance; the latter two are consumed by ``filter_data``.

        Returns:
            list[dict]: One dict per video with the keys ``video_id``,
            ``images`` and ``video_length``.

        Raises:
            AssertionError: If ``ANN_ID_UNIQUE`` is set and the collected
                annotation ids contain duplicates.
        """
        # Forward ``backend_args`` so that non-local annotation files are
        # fetched with the backend configured by the user.
        with fileio.get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            self.coco = COCO(local_path)
        # The order of returned `cat_ids` will not
        # change with the order of the classes
        self.cat_ids = self.coco.get_cat_ids(
            cat_names=self.metainfo['classes'])
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
        # used in `filter_data`
        self.img_ids_with_ann = set()

        img_ids = self.coco.get_img_ids()
        total_ann_ids = []

        # if ``video_id`` is not in the annotation file, we will assign a big
        # unique video_id for this video.
        single_video_id = 100000
        videos = {}
        for img_id in img_ids:
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id
            if 'video_id' not in raw_img_info:
                # Each image without a video becomes its own one-frame video.
                single_video_id = single_video_id + 1
                video_id = single_video_id
            else:
                video_id = raw_img_info['video_id']

            if video_id not in videos:
                videos[video_id] = {
                    'video_id': video_id,
                    'images': [],
                    'video_length': 0
                }

            videos[video_id]['video_length'] += 1
            ann_ids = self.coco.get_ann_ids(
                img_ids=[img_id], cat_ids=self.cat_ids)
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            parsed_data_info = self.parse_data_info(
                dict(raw_img_info=raw_img_info, raw_ann_info=raw_ann_info))

            if len(parsed_data_info['instances']) > 0:
                self.img_ids_with_ann.add(parsed_data_info['img_id'])

            videos[video_id]['images'].append(parsed_data_info)

        data_list = list(videos.values())

        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"

        # The COCO helper is only needed while loading.
        del self.coco

        return data_list

    def parse_data_info(self, raw_data_info: dict) -> dict:
        """Parse raw annotation to target format.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``. It holds ``raw_img_info`` (the image record)
                and ``raw_ann_info`` (the list of its annotations).

        Returns:
            dict: The image record extended with ``img_path`` and
            ``instances``. Each instance has ``ignore_flag``, ``bbox``
            (``[x1, y1, x2, y2]``), ``bbox_label``, ``instance_id`` and,
            when a segmentation is present, ``mask``.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']

        data_info = {}
        data_info.update(img_info)

        if self.data_prefix.get('img_path', None) is not None:
            img_path = osp.join(self.data_prefix['img_path'],
                                img_info['file_name'])
        else:
            img_path = img_info['file_name']
        data_info['img_path'] = img_path

        instances = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Skip boxes that do not overlap the image at all.
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            # Skip degenerate boxes.
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue

            instance = {
                # Crowd annotations are kept but flagged as ignored.
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                'bbox': [x1, y1, x1 + w, y1 + h],
                'bbox_label': self.cat2label[ann['category_id']],
            }
            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']

            # image dataset usually has no `instance_id`.
            # Therefore, we set it to `i`.
            # Compare against None explicitly: an `instance_id` of 0 is a
            # valid id and must not be replaced by the annotation index.
            instance_id = ann.get('instance_id', None)
            instance['instance_id'] = i if instance_id is None else instance_id

            instances.append(instance)
        data_info['instances'] = instances
        return data_info

    def filter_data(self) -> List[dict]:
        """Filter image annotations according to filter_cfg.

        In test mode nothing is filtered. Otherwise an image is dropped when
        ``filter_empty_gt`` is enabled and it has no valid instance of the
        required categories, or when its shorter side is below ``min_size``.
        With ``filter_cfg=None`` the defaults ``filter_empty_gt=True`` and
        ``min_size=32`` apply. The video dicts are updated in place and
        videos are kept even if all of their images were removed.

        Returns:
            list[dict]: Filtered results (one dict per video).
        """
        if self.test_mode:
            return self.data_list

        if self.filter_cfg is None:
            filter_empty_gt, min_size = True, 32
        else:
            filter_empty_gt = self.filter_cfg.get('filter_empty_gt', True)
            min_size = self.filter_cfg.get('min_size', 32)

        num_imgs_before_filter = sum(
            len(info['images']) for info in self.data_list)
        num_imgs_after_filter = 0

        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for class_id in self.cat_ids:
            ids_in_cat.update(self.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if filter_empty_gt is enabled
        ids_in_cat &= self.img_ids_with_ann

        new_data_list = []
        for video_data_info in self.data_list:
            imgs_data_info = video_data_info['images']
            valid_imgs_data_info = []

            for data_info in imgs_data_info:
                img_id = data_info['img_id']
                width = data_info['width']
                height = data_info['height']
                if filter_empty_gt and img_id not in ids_in_cat:
                    continue
                if min(width, height) >= min_size:
                    valid_imgs_data_info.append(data_info)

            # Keep `video_length` in step with the number of removed frames.
            num_removed = len(imgs_data_info) - len(valid_imgs_data_info)
            video_data_info['video_length'] -= num_removed
            video_data_info['images'] = valid_imgs_data_info
            num_imgs_after_filter += len(valid_imgs_data_info)
            new_data_list.append(video_data_info)

        print_log(
            'The number of samples before and after filtering: '
            f'{num_imgs_before_filter} / {num_imgs_after_filter}', 'current')
        return new_data_list

    def prepare_data(self, idx) -> Any:
        """Get data processed by ``self.pipeline``.

        Note that ``idx`` is a video index by default since the base element
        of a video dataset is a video. However, in some cases, we need to
        specify both the video index and the frame index. For example, in
        training mode, we may want to sample specific frames such that all
        the frames are sampled once in an epoch; in test mode, we may want to
        output the data of a single image rather than the whole video to
        save memory.

        In test mode the selected frames are collated from a list of dicts
        into a dict of lists before being passed to the pipeline. In
        training mode the video dict is passed as is, with ``key_frame_id``
        set when a frame index was given.

        Args:
            idx (int | tuple[int, int]): The video index, or a
                ``(video_index, frame_index)`` pair.

        Returns:
            Any: Depends on ``self.pipeline``.

        Raises:
            AssertionError: If ``idx`` is a tuple whose length is not 2.
        """
        if isinstance(idx, tuple):
            # The parentheses join both literals into ONE assertion message.
            assert len(idx) == 2, ('The length of idx must be 2: '
                                   '(video_index, frame_index)')
            video_idx, frame_idx = idx[0], idx[1]
        else:
            video_idx, frame_idx = idx, None

        data_info = self.get_data_info(video_idx)
        if self.test_mode:
            # Support two test_mode: frame-level and video-level
            final_data_info = defaultdict(list)
            if frame_idx is None:
                frames_idx_list = list(range(data_info['video_length']))
            else:
                frames_idx_list = [frame_idx]
            for index in frames_idx_list:
                frame_ann = data_info['images'][index]
                frame_ann['video_id'] = data_info['video_id']
                # Collate data_list (list of dict to dict of list)
                for key, value in frame_ann.items():
                    final_data_info[key].append(value)
                # copy the info in video-level into img-level
                # TODO: the value of this key is the same as that of
                # `video_length` in test mode
                final_data_info['ori_video_length'].append(
                    data_info['video_length'])

            final_data_info['video_length'] = [len(frames_idx_list)
                                               ] * len(frames_idx_list)
            return self.pipeline(final_data_info)
        else:
            # Specify `key_frame_id` for the frame sampling in the pipeline
            if frame_idx is not None:
                data_info['key_frame_id'] = frame_idx
            return self.pipeline(data_info)

    def get_cat_ids(self, index) -> List[int]:
        """Following image detection, we provide this interface function. Get
        category ids by video index and frame index.

        Args:
            index: The index of the dataset. It supports two kinds of inputs:
                Tuple:
                    video_idx (int): Index of video.
                    frame_idx (int): Index of frame.
                Int: Index of video.

        Returns:
            List[int]: The ``bbox_label`` of every instance in the specified
            frame, or in all frames of the video when only a video index is
            given.

        Raises:
            AssertionError: If ``index`` is a tuple whose length is not 2.
        """
        if isinstance(index, tuple):
            assert len(
                index
            ) == 2, f'Expect the length of index is 2, but got {len(index)}'
            video_idx, frame_idx = index
            instances = self.get_data_info(
                video_idx)['images'][frame_idx]['instances']
            return [instance['bbox_label'] for instance in instances]
        else:
            return [
                instance['bbox_label']
                for img in self.get_data_info(index)['images']
                for instance in img['instances']
            ]

    def num_all_imgs(self) -> int:
        """Get the number of all the images in this video dataset."""
        return sum(
            len(self.get_data_info(i)['images']) for i in range(len(self)))

    def get_len_per_video(self, idx) -> int:
        """Get length of one video.

        Args:
            idx (int): Index of video.

        Returns:
            int: The length of the video, i.e. its number of images.
        """
        return len(self.get_data_info(idx)['images'])