# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from os import PathLike
from typing import List, Optional, Sequence, Union

import mmengine
import numpy as np
from mmengine.dataset import BaseDataset as _BaseDataset

from mmpretrain.registry import DATASETS, TRANSFORMS


def expanduser(path):
    """Expand ~ and ~user constructions.

    If user or $HOME is unknown, do nothing. Values that are not ``str`` or
    ``os.PathLike`` are returned unchanged.
    """
    if isinstance(path, (str, PathLike)):
        return osp.expanduser(path)
    else:
        return path


@DATASETS.register_module()
class BaseDataset(_BaseDataset):
    """Base dataset for image classification task.

    This dataset supports annotation files in the `OpenMMLab 2.0 style
    annotation format`.

    .. _OpenMMLab 2.0 style annotation format:
        https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md

    Compared with :class:`mmengine.BaseDataset`, this class implements
    several useful methods.

    Args:
        ann_file (str): Annotation file path. ``~`` is expanded.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. ``~`` is expanded. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. A string is used
            as the ``img_path`` prefix; for a dict, every value is a prefix.
            ``~`` is expanded in all prefixes. Defaults to ''.
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Defaults to None, which means using all ``data_infos``.
        serialize_data (bool): Whether to hold memory using serialized objects,
            when enabled, data loader workers can use shared RAM from master
            process instead of making a copy. Defaults to True.
        pipeline (Sequence): Processing pipeline. Items that are dicts are
            built through the ``TRANSFORMS`` registry; other items are used
            as-is. Defaults to an empty tuple.
        test_mode (bool, optional): ``test_mode=True`` means in test phase,
            an error will be raised when getting an item fails,
            ``test_mode=False`` means in training phase, another item will be
            returned randomly. Defaults to False.
        lazy_init (bool): Whether to defer loading the annotation file until
            it is actually needed. In some cases, such as visualization, only
            the meta information of the dataset is needed, so loading the
            annotation file is unnecessary. ``BaseDataset`` skips loading
            annotations to save time when ``lazy_init=True``.
            Defaults to False.
        max_refetch (int): If ``BaseDataset.prepare_data`` returns None, the
            maximum number of extra attempts to fetch a valid sample.
            Defaults to 1000.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If it is a string, it should be a file path, and every line of
              the file is the name of a class.
            - If it is a sequence of strings, every item is the name of a
              class.
            - If it is None, use categories information in the ``metainfo``
              argument, the annotation file or the class attribute
              ``METAINFO``.

            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: Sequence = (),
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 classes: Union[str, Sequence[str], None] = None):
        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))
        elif isinstance(data_prefix, dict):
            # Expand ``~`` in every prefix, consistent with the str case.
            # A new dict is built so the caller's dict is not mutated;
            # non-path values pass through ``expanduser`` unchanged.
            data_prefix = {
                key: expanduser(value)
                for key, value in data_prefix.items()
            }

        data_root = expanduser(data_root)
        ann_file = expanduser(ann_file)
        metainfo = self._compat_classes(metainfo, classes)

        # Build dict configs into transform objects; keep ready-made
        # transforms (e.g. callables) untouched.
        transforms = []
        for transform in pipeline:
            if isinstance(transform, dict):
                transforms.append(TRANSFORMS.build(transform))
            else:
                transforms.append(transform)

        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            filter_cfg=filter_cfg,
            indices=indices,
            serialize_data=serialize_data,
            pipeline=transforms,
            test_mode=test_mode,
            lazy_init=lazy_init,
            max_refetch=max_refetch)

    @property
    def img_prefix(self):
        """The prefix of images (``data_prefix['img_path']``)."""
        return self.data_prefix['img_path']

    @property
    def CLASSES(self):
        """Return all category names, or None if no class info is set."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self):
        """Map class name to class index.

        Note: ``CLASSES`` must not be None, otherwise iterating it fails.

        Returns:
            dict: mapping from class name to class index.
        """
        return {cat: i for i, cat in enumerate(self.CLASSES)}

    def get_gt_labels(self):
        """Get all ground-truth labels (categories).

        Returns:
            np.ndarray: categories for all images.
        """
        gt_labels = np.array(
            [self.get_data_info(i)['gt_label'] for i in range(len(self))])
        return gt_labels

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category id by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image category of specified index.
        """
        return [int(self.get_data_info(idx)['gt_label'])]

    def _compat_classes(self, metainfo, classes):
        """Merge the old style ``classes`` arguments to ``metainfo``.

        Keys already present in ``metainfo`` take precedence over the
        ``classes`` argument.

        Raises:
            ValueError: If ``classes`` is neither a str, tuple, list nor None.
        """
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmengine.list_from_file(expanduser(classes))
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        elif classes is not None:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if metainfo is None:
            metainfo = {}

        if classes is not None:
            # Build a new dict so the caller's ``metainfo`` is not mutated.
            metainfo = {'classes': tuple(class_names), **metainfo}

        return metainfo

    def full_init(self):
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True."""
        super().full_init()

        #  To support the standard OpenMMLab 2.0 annotation format. Generate
        #  metainfo in internal format from standard metainfo format.
        if 'categories' in self._metainfo and 'classes' not in self._metainfo:
            categories = sorted(
                self._metainfo['categories'], key=lambda x: x['id'])
            self._metainfo['classes'] = tuple(
                [cat['category_name'] for cat in categories])

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = []
        if self._fully_initialized:
            body.append(f'Number of samples: \t{self.__len__()}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')

        body.extend(self.extra_repr())

        if len(self.pipeline.transforms) > 0:
            body.append('With transforms:')
            for t in self.pipeline.transforms:
                body.append(f'    {t}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = []
        body.append(f'Annotation file: \t{self.ann_file}')
        body.append(f'Prefix of images: \t{self.img_prefix}')
        return body