# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from os import PathLike
from typing import List, Optional, Sequence, Union
import mmengine
import numpy as np
from mmengine.dataset import BaseDataset as _BaseDataset
from .builder import DATASETS
def expanduser(path):
    """Expand ``~`` and ``~user`` constructions in a path-like value.

    Values that are not ``str`` or ``os.PathLike`` (e.g. ``None``) are
    returned unchanged, and so is a path whose user or ``$HOME`` cannot
    be resolved.
    """
    if not isinstance(path, (str, PathLike)):
        return path
    return osp.expanduser(path)
@DATASETS.register_module()
class BaseDataset(_BaseDataset):
    """Base dataset for image classification tasks.

    This dataset supports annotation files in the `OpenMMLab 2.0 style
    annotation format`_ and adds several classification-specific helpers
    (``CLASSES``, ``class_to_idx``, ``get_gt_labels`` ...) on top of
    :class:`mmengine.BaseDataset`.

    .. _OpenMMLab 2.0 style annotation format:
        https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md

    Args:
        ann_file (str): Annotation file path.
        metainfo (dict, optional): Meta information of the dataset, such as
            class information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. A plain string
            is interpreted as the image folder. Defaults to ''.
        filter_cfg (dict, optional): Config for filtering data.
            Defaults to None.
        indices (int or Sequence[int], optional): Use only the first few
            samples of the annotation file to facilitate training/testing
            on a smaller dataset. Defaults to None, meaning all samples.
        serialize_data (bool): Whether to hold memory as serialized objects
            so that data-loader workers can share RAM with the master
            process instead of copying. Defaults to True.
        pipeline (Sequence): Processing pipeline. Defaults to an empty
            tuple.
        test_mode (bool): ``test_mode=True`` means the dataset is in the
            test phase. Defaults to False.
        lazy_init (bool): Whether to skip loading annotations during
            instantiation (useful e.g. for visualization, where only meta
            information is needed). Defaults to False.
        max_refetch (int): Maximum number of extra cycles used by
            ``prepare_data`` to fetch a valid image when it gets ``None``.
            Defaults to 1000.
        classes (str | Sequence[str], optional): Names of the categories.

            - If a string, it is treated as a file path where every line
              is one class name.
            - If a sequence of strings, each item is one class name.
            - If None, use the categories from the ``metainfo`` argument,
              the annotation file, or the ``METAINFO`` class attribute.

            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: Sequence = (),
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 classes: Union[str, Sequence[str], None] = None):
        # A bare string prefix is shorthand for the image folder.
        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))

        super().__init__(
            ann_file=expanduser(ann_file),
            # Fold the legacy ``classes`` argument into ``metainfo``.
            metainfo=self._compat_classes(metainfo, classes),
            data_root=data_root,
            data_prefix=data_prefix,
            filter_cfg=filter_cfg,
            indices=indices,
            serialize_data=serialize_data,
            pipeline=pipeline,
            test_mode=test_mode,
            lazy_init=lazy_init,
            max_refetch=max_refetch)

    @property
    def img_prefix(self):
        """The prefix of images."""
        return self.data_prefix['img_path']

    @property
    def CLASSES(self):
        """Names of all categories, or None if not set."""
        return self._metainfo.get('classes')

    @property
    def class_to_idx(self):
        """Mapping from class name to class index.

        Returns:
            dict: mapping from class name to class index.
        """
        return {name: index for index, name in enumerate(self.CLASSES)}

    def get_gt_labels(self):
        """Collect the ground-truth labels (categories) of every sample.

        Returns:
            np.ndarray: categories for all images.
        """
        labels = [
            self.get_data_info(index)['gt_label']
            for index in range(len(self))
        ]
        return np.array(labels)

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get the category id of a sample by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image category of specified index.
        """
        info = self.get_data_info(idx)
        return [int(info['gt_label'])]

    def _compat_classes(self, metainfo, classes):
        """Merge the old-style ``classes`` argument into ``metainfo``."""
        if classes is None:
            class_names = None
        elif isinstance(classes, str):
            # Treat the string as a file with one class name per line.
            class_names = mmengine.list_from_file(expanduser(classes))
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if metainfo is None:
            metainfo = {}
        if class_names is not None:
            # Explicit ``metainfo`` entries win over ``classes``.
            metainfo = {'classes': tuple(class_names), **metainfo}
        return metainfo

    def full_init(self):
        """Load the annotation file and mark the dataset fully
        initialized."""
        super().full_init()

        # Support the standard OpenMMLab 2.0 annotation format: derive the
        # internal ``classes`` tuple from ``categories``, ordered by id.
        if 'categories' in self._metainfo and 'classes' not in self._metainfo:
            ordered = sorted(
                self._metainfo['categories'], key=lambda cat: cat['id'])
            self._metainfo['classes'] = tuple(
                cat['category_name'] for cat in ordered)

    def __repr__(self):
        """Summarize the dataset's basic information.

        Returns:
            str: Formatted string.
        """
        body = []
        if self._fully_initialized:
            body.append(f'Number of samples: \t{len(self)}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')
        else:
            body.append('The `CLASSES` meta info is not set.')

        body.extend(self.extra_repr())

        transforms = self.pipeline.transforms
        if transforms:
            body.append('With transforms:')
            body.extend(f'    {t}' for t in transforms)

        head = f'Dataset {type(self).__name__}'
        return '\n'.join([head] + ['    ' + line for line in body])

    def extra_repr(self) -> List[str]:
        """Extra repr lines appended by :meth:`__repr__`."""
        return [
            f'Annotation file: \t{self.ann_file}',
            f'Prefix of images: \t{self.img_prefix}',
        ]