Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
from os import PathLike | |
from typing import List, Optional, Sequence, Union | |
import mmengine | |
import numpy as np | |
from mmengine.dataset import BaseDataset as _BaseDataset | |
from mmpretrain.registry import DATASETS, TRANSFORMS | |
def expanduser(path):
    """Expand a leading ``~`` or ``~user`` component of a path.

    Only ``str`` and ``os.PathLike`` inputs are expanded (via
    ``os.path.expanduser``); any other value is handed back unchanged so
    callers can pass optional or already-processed arguments through.
    """
    if not isinstance(path, (str, PathLike)):
        return path
    return osp.expanduser(path)
class BaseDataset(_BaseDataset):
    """Base dataset for image classification task.

    This dataset supports annotation files in `OpenMMLab 2.0 style annotation
    format`_.

    .. _OpenMMLab 2.0 style annotation format:
        https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md

    Compared with the :class:`mmengine.BaseDataset`, this class implements
    several useful methods.

    Args:
        ann_file (str): Annotation file path.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. A string is
            wrapped as ``dict(img_path=...)``. Defaults to ''.
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Defaults to None, which means using all ``data_infos``.
        serialize_data (bool): Whether to hold memory using serialized objects,
            when enabled, data loader workers can use shared RAM from master
            process instead of making a copy. Defaults to True.
        pipeline (Sequence): Processing pipeline. Items that are dicts are
            built through the ``TRANSFORMS`` registry; other items are used
            as they are. Defaults to an empty tuple.
        test_mode (bool, optional): ``test_mode=True`` means in test phase,
            an error will be raised when getting an item fails,
            ``test_mode=False`` means in training phase, another item will be
            returned randomly. Defaults to False.
        lazy_init (bool): Whether to defer loading the annotation file. In
            some cases, such as visualization, only the meta information of
            the dataset is needed, which is not necessary to load annotation
            file. ``BaseDataset`` can skip loading annotations to save time
            by setting ``lazy_init=True``. Defaults to False.
        max_refetch (int): If ``BaseDataset.prepare_data`` gets a None img,
            the maximum extra number of cycles to get a valid image.
            Defaults to 1000.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If is string, it should be a file path, and every line of
              the file is a name of a class.
            - If is a sequence of string, every item is a name of class.
            - If is None, use categories information in ``metainfo`` argument,
              annotation file or the class attribute ``METAINFO``.

            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: Sequence = (),
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 classes: Union[str, Sequence[str], None] = None):
        # A bare string prefix is treated as the image directory.
        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))

        ann_file = expanduser(ann_file)
        # Fold the legacy ``classes`` argument into ``metainfo``.
        metainfo = self._compat_classes(metainfo, classes)

        # Build config dicts into transform objects; pass through anything
        # that is not a dict unchanged.
        transforms = []
        for transform in pipeline:
            if isinstance(transform, dict):
                transforms.append(TRANSFORMS.build(transform))
            else:
                transforms.append(transform)

        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            filter_cfg=filter_cfg,
            indices=indices,
            serialize_data=serialize_data,
            pipeline=transforms,
            test_mode=test_mode,
            lazy_init=lazy_init,
            max_refetch=max_refetch)

    # ``img_prefix``, ``CLASSES`` and ``class_to_idx`` are read as plain
    # attributes elsewhere in this class (``enumerate(self.CLASSES)``,
    # ``len(self.CLASSES)``, ``{self.img_prefix}``), so they must be
    # properties rather than ordinary methods.
    @property
    def img_prefix(self):
        """The prefix of images."""
        return self.data_prefix['img_path']

    @property
    def CLASSES(self):
        """Return all categories names, or None if not available."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self):
        """Map mapping class name to class index.

        Returns:
            dict: mapping from class name to class index.
        """
        return {cat: i for i, cat in enumerate(self.CLASSES)}

    def get_gt_labels(self):
        """Get all ground-truth labels (categories).

        Returns:
            np.ndarray: categories for all images.
        """
        gt_labels = np.array(
            [self.get_data_info(i)['gt_label'] for i in range(len(self))])
        return gt_labels

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category id by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image category of specified index.
        """
        return [int(self.get_data_info(idx)['gt_label'])]

    def _compat_classes(self, metainfo, classes):
        """Merge the old style ``classes`` arguments to ``metainfo``.

        Args:
            metainfo (dict, optional): Meta information given by the caller.
            classes (str | Sequence[str], optional): A file path holding one
                class name per line, or a tuple/list of class names.

        Returns:
            dict: The merged meta information. A ``'classes'`` entry that is
            already present in ``metainfo`` takes precedence over the
            ``classes`` argument.

        Raises:
            ValueError: If ``classes`` is neither a str, a tuple/list nor
                None.
        """
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmengine.list_from_file(expanduser(classes))
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        elif classes is not None:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if metainfo is None:
            metainfo = {}

        if classes is not None:
            # ``metainfo`` is unpacked last, so its own keys win on conflict.
            metainfo = {'classes': tuple(class_names), **metainfo}

        return metainfo

    def full_init(self):
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True."""
        super().full_init()

        # To support the standard OpenMMLab 2.0 annotation format. Generate
        # metainfo in internal format from standard metainfo format.
        if 'categories' in self._metainfo and 'classes' not in self._metainfo:
            categories = sorted(
                self._metainfo['categories'], key=lambda x: x['id'])
            self._metainfo['classes'] = tuple(
                [cat['category_name'] for cat in categories])

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = []
        if self._fully_initialized:
            body.append(f'Number of samples: \t{self.__len__()}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')

        body.extend(self.extra_repr())

        # NOTE(review): assumes ``self.pipeline`` exposes a ``transforms``
        # sequence (set up by the parent class) -- confirm against mmengine.
        if len(self.pipeline.transforms) > 0:
            body.append('With transforms:')
            for t in self.pipeline.transforms:
                body.append(f' {t}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset.

        Returns:
            List[str]: Lines appended to the body of ``__repr__``.
        """
        body = []
        body.append(f'Annotation file: \t{self.ann_file}')
        body.append(f'Prefix of images: \t{self.img_prefix}')
        return body