| |
| import json |
| import logging |
| import os.path as osp |
| import warnings |
| from typing import List, Union |
|
|
| import mmcv |
| from mmengine.dist import get_rank |
| from mmengine.fileio import dump, get, get_text, load |
| from mmengine.logging import print_log |
| from mmengine.utils import ProgressBar |
|
|
| from mmdet.registry import DATASETS |
| from .base_det_dataset import BaseDetDataset |
|
|
|
|
| @DATASETS.register_module() |
| class CrowdHumanDataset(BaseDetDataset): |
| r"""Dataset for CrowdHuman. |
| |
| Args: |
| data_root (str): The root directory for |
| ``data_prefix`` and ``ann_file``. |
| ann_file (str): Annotation file path. |
| extra_ann_file (str | optional):The path of extra image metas |
| for CrowdHuman. It can be created by CrowdHumanDataset |
| automatically or by tools/misc/get_crowdhuman_id_hw.py |
| manually. Defaults to None. |
| """ |
|
|
| METAINFO = { |
| 'classes': ('person', ), |
| |
| 'palette': [(220, 20, 60)] |
| } |
|
|
| def __init__(self, data_root, ann_file, extra_ann_file=None, **kwargs): |
| |
| |
| |
| if extra_ann_file is not None: |
| self.extra_ann_exist = True |
| self.extra_anns = load(extra_ann_file) |
| else: |
| ann_file_name = osp.basename(ann_file) |
| if 'train' in ann_file_name: |
| self.extra_ann_file = osp.join(data_root, 'id_hw_train.json') |
| elif 'val' in ann_file_name: |
| self.extra_ann_file = osp.join(data_root, 'id_hw_val.json') |
| self.extra_ann_exist = False |
| if not osp.isfile(self.extra_ann_file): |
| print_log( |
| 'extra_ann_file does not exist, prepare to collect ' |
| 'image height and width...', |
| level=logging.INFO) |
| self.extra_anns = {} |
| else: |
| self.extra_ann_exist = True |
| self.extra_anns = load(self.extra_ann_file) |
| super().__init__(data_root=data_root, ann_file=ann_file, **kwargs) |
|
|
| def load_data_list(self) -> List[dict]: |
| """Load annotations from an annotation file named as ``self.ann_file`` |
| |
| Returns: |
| List[dict]: A list of annotation. |
| """ |
| anno_strs = get_text( |
| self.ann_file, backend_args=self.backend_args).strip().split('\n') |
| print_log('loading CrowdHuman annotation...', level=logging.INFO) |
| data_list = [] |
| prog_bar = ProgressBar(len(anno_strs)) |
| for i, anno_str in enumerate(anno_strs): |
| anno_dict = json.loads(anno_str) |
| parsed_data_info = self.parse_data_info(anno_dict) |
| data_list.append(parsed_data_info) |
| prog_bar.update() |
| if not self.extra_ann_exist and get_rank() == 0: |
| |
| try: |
| dump(self.extra_anns, self.extra_ann_file, file_format='json') |
| except: |
| warnings.warn( |
| 'Cache files can not be saved automatically! To speed up' |
| 'loading the dataset, please manually generate the cache' |
| ' file by file tools/misc/get_crowdhuman_id_hw.py') |
|
|
| print_log( |
| f'\nsave extra_ann_file in {self.data_root}', |
| level=logging.INFO) |
|
|
| del self.extra_anns |
| print_log('\nDone', level=logging.INFO) |
| return data_list |
|
|
| def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: |
| """Parse raw annotation to target format. |
| |
| Args: |
| raw_data_info (dict): Raw data information load from ``ann_file`` |
| |
| Returns: |
| Union[dict, List[dict]]: Parsed annotation. |
| """ |
| data_info = {} |
| img_path = osp.join(self.data_prefix['img'], |
| f"{raw_data_info['ID']}.jpg") |
| data_info['img_path'] = img_path |
| data_info['img_id'] = raw_data_info['ID'] |
|
|
| if not self.extra_ann_exist: |
| img_bytes = get(img_path, backend_args=self.backend_args) |
| img = mmcv.imfrombytes(img_bytes, backend='cv2') |
| data_info['height'], data_info['width'] = img.shape[:2] |
| self.extra_anns[raw_data_info['ID']] = img.shape[:2] |
| del img, img_bytes |
| else: |
| data_info['height'], data_info['width'] = self.extra_anns[ |
| raw_data_info['ID']] |
|
|
| instances = [] |
| for i, ann in enumerate(raw_data_info['gtboxes']): |
| instance = {} |
| if ann['tag'] not in self.metainfo['classes']: |
| instance['bbox_label'] = -1 |
| instance['ignore_flag'] = 1 |
| else: |
| instance['bbox_label'] = self.metainfo['classes'].index( |
| ann['tag']) |
| instance['ignore_flag'] = 0 |
| if 'extra' in ann: |
| if 'ignore' in ann['extra']: |
| if ann['extra']['ignore'] != 0: |
| instance['bbox_label'] = -1 |
| instance['ignore_flag'] = 1 |
|
|
| x1, y1, w, h = ann['fbox'] |
| bbox = [x1, y1, x1 + w, y1 + h] |
| instance['bbox'] = bbox |
|
|
| |
| |
| |
| |
| instance['fbox'] = bbox |
| hbox = ann['hbox'] |
| instance['hbox'] = [ |
| hbox[0], hbox[1], hbox[0] + hbox[2], hbox[1] + hbox[3] |
| ] |
| vbox = ann['vbox'] |
| instance['vbox'] = [ |
| vbox[0], vbox[1], vbox[0] + vbox[2], vbox[1] + vbox[3] |
| ] |
|
|
| instances.append(instance) |
|
|
| data_info['instances'] = instances |
| return data_info |
|
|