KyanChen committed on
Commit
3e06e1c
1 Parent(s): 0ae0261

Upload 787 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. mmdet/__init__.py +27 -0
  2. mmdet/__pycache__/__init__.cpython-310.pyc +0 -0
  3. mmdet/__pycache__/registry.cpython-310.pyc +0 -0
  4. mmdet/__pycache__/version.cpython-310.pyc +0 -0
  5. mmdet/apis/__init__.py +9 -0
  6. mmdet/apis/det_inferencer.py +590 -0
  7. mmdet/apis/inference.py +233 -0
  8. mmdet/datasets/__init__.py +27 -0
  9. mmdet/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  10. mmdet/datasets/__pycache__/base_det_dataset.cpython-310.pyc +0 -0
  11. mmdet/datasets/__pycache__/cityscapes.cpython-310.pyc +0 -0
  12. mmdet/datasets/__pycache__/coco.cpython-310.pyc +0 -0
  13. mmdet/datasets/__pycache__/coco_panoptic.cpython-310.pyc +0 -0
  14. mmdet/datasets/__pycache__/crowdhuman.cpython-310.pyc +0 -0
  15. mmdet/datasets/__pycache__/dataset_wrappers.cpython-310.pyc +0 -0
  16. mmdet/datasets/__pycache__/deepfashion.cpython-310.pyc +0 -0
  17. mmdet/datasets/__pycache__/lvis.cpython-310.pyc +0 -0
  18. mmdet/datasets/__pycache__/objects365.cpython-310.pyc +0 -0
  19. mmdet/datasets/__pycache__/openimages.cpython-310.pyc +0 -0
  20. mmdet/datasets/__pycache__/utils.cpython-310.pyc +0 -0
  21. mmdet/datasets/__pycache__/voc.cpython-310.pyc +0 -0
  22. mmdet/datasets/__pycache__/wider_face.cpython-310.pyc +0 -0
  23. mmdet/datasets/__pycache__/xml_style.cpython-310.pyc +0 -0
  24. mmdet/datasets/api_wrappers/__init__.py +4 -0
  25. mmdet/datasets/api_wrappers/__pycache__/__init__.cpython-310.pyc +0 -0
  26. mmdet/datasets/api_wrappers/__pycache__/coco_api.cpython-310.pyc +0 -0
  27. mmdet/datasets/api_wrappers/coco_api.py +137 -0
  28. mmdet/datasets/base_det_dataset.py +120 -0
  29. mmdet/datasets/cityscapes.py +61 -0
  30. mmdet/datasets/coco.py +196 -0
  31. mmdet/datasets/coco_panoptic.py +287 -0
  32. mmdet/datasets/crowdhuman.py +159 -0
  33. mmdet/datasets/dataset_wrappers.py +169 -0
  34. mmdet/datasets/deepfashion.py +19 -0
  35. mmdet/datasets/lvis.py +638 -0
  36. mmdet/datasets/objects365.py +284 -0
  37. mmdet/datasets/openimages.py +484 -0
  38. mmdet/datasets/samplers/__init__.py +9 -0
  39. mmdet/datasets/samplers/__pycache__/__init__.cpython-310.pyc +0 -0
  40. mmdet/datasets/samplers/__pycache__/batch_sampler.cpython-310.pyc +0 -0
  41. mmdet/datasets/samplers/__pycache__/class_aware_sampler.cpython-310.pyc +0 -0
  42. mmdet/datasets/samplers/__pycache__/multi_source_sampler.cpython-310.pyc +0 -0
  43. mmdet/datasets/samplers/batch_sampler.py +68 -0
  44. mmdet/datasets/samplers/class_aware_sampler.py +192 -0
  45. mmdet/datasets/samplers/multi_source_sampler.py +214 -0
  46. mmdet/datasets/transforms/__init__.py +36 -0
  47. mmdet/datasets/transforms/__pycache__/__init__.cpython-310.pyc +0 -0
  48. mmdet/datasets/transforms/__pycache__/augment_wrappers.cpython-310.pyc +0 -0
  49. mmdet/datasets/transforms/__pycache__/colorspace.cpython-310.pyc +0 -0
  50. mmdet/datasets/transforms/__pycache__/formatting.cpython-310.pyc +0 -0
mmdet/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmengine
from mmengine.utils import digit_version

from .version import __version__, version_info

# Supported MMCV range is half-open: minimum <= installed < maximum.
mmcv_minimum_version = '2.0.0rc4'
mmcv_maximum_version = '2.1.0'
mmcv_version = digit_version(mmcv.__version__)

# Supported MMEngine range is half-open: minimum <= installed < maximum.
mmengine_minimum_version = '0.7.0'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)

# The checks below raise ``AssertionError`` explicitly instead of relying on
# ``assert`` statements: ``assert`` is removed when Python runs with ``-O``,
# which would silently skip the compatibility gate. The exception type and
# the messages are the same as before.
if not (digit_version(mmcv_minimum_version) <= mmcv_version <
        digit_version(mmcv_maximum_version)):
    raise AssertionError(
        f'MMCV=={mmcv.__version__} is used but incompatible. '
        f'Please install mmcv>={mmcv_minimum_version}, '
        f'<{mmcv_maximum_version}.')

if not (digit_version(mmengine_minimum_version) <= mmengine_version <
        digit_version(mmengine_maximum_version)):
    raise AssertionError(
        f'MMEngine=={mmengine.__version__} is used but incompatible. '
        f'Please install mmengine>={mmengine_minimum_version}, '
        f'<{mmengine_maximum_version}.')

__all__ = ['__version__', 'version_info', 'digit_version']
mmdet/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (817 Bytes). View file
 
mmdet/__pycache__/registry.cpython-310.pyc ADDED
Binary file (2.58 kB). View file
 
mmdet/__pycache__/version.cpython-310.pyc ADDED
Binary file (803 Bytes). View file
 
mmdet/apis/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .det_inferencer import DetInferencer
3
+ from .inference import (async_inference_detector, inference_detector,
4
+ init_detector)
5
+
6
+ __all__ = [
7
+ 'init_detector', 'async_inference_detector', 'inference_detector',
8
+ 'DetInferencer'
9
+ ]
mmdet/apis/det_inferencer.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import os.path as osp
4
+ import warnings
5
+ from typing import Dict, Iterable, List, Optional, Sequence, Union
6
+
7
+ import mmcv
8
+ import mmengine
9
+ import numpy as np
10
+ import torch.nn as nn
11
+ from mmengine.dataset import Compose
12
+ from mmengine.fileio import (get_file_backend, isdir, join_path,
13
+ list_dir_or_file)
14
+ from mmengine.infer.infer import BaseInferencer, ModelType
15
+ from mmengine.model.utils import revert_sync_batchnorm
16
+ from mmengine.registry import init_default_scope
17
+ from mmengine.runner.checkpoint import _load_checkpoint_to_model
18
+ from mmengine.visualization import Visualizer
19
+ from rich.progress import track
20
+
21
+ from mmdet.evaluation import INSTANCE_OFFSET
22
+ from mmdet.registry import DATASETS
23
+ from mmdet.structures import DetDataSample
24
+ from mmdet.structures.mask import encode_mask_results, mask2bbox
25
+ from mmdet.utils import ConfigType
26
+ from ..evaluation import get_classes
27
+
28
+ try:
29
+ from panopticapi.evaluation import VOID
30
+ from panopticapi.utils import id2rgb
31
+ except ImportError:
32
+ id2rgb = None
33
+ VOID = None
34
+
35
+ InputType = Union[str, np.ndarray]
36
+ InputsType = Union[InputType, Sequence[InputType]]
37
+ PredType = List[DetDataSample]
38
+ ImgType = Union[np.ndarray, Sequence[np.ndarray]]
39
+
40
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
41
+ '.tiff', '.webp')
42
+
43
+
44
class DetInferencer(BaseInferencer):
    """Object Detection Inferencer.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            "rtmdet-s" or 'rtmdet_s_8xb32-300e_coco' or
            "configs/rtmdet/rtmdet_s_8xb32-300e_coco.py".
            If model is not specified, user must provide the
            `weights` saved by MMEngine which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not specified
            and model is a model name of metafile, the weights will be loaded
            from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to mmdet.
        palette (str): Color palette used for visualization. The order of
            priority is palette -> config -> checkpoint. Defaults to 'none'.
    """

    # Names of the extra keyword arguments that ``__call__`` routes to each
    # stage via ``self._dispatch_kwargs``.
    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = {
        'return_vis',
        'show',
        'wait_time',
        'draw_pred',
        'pred_score_thr',
        'img_out_dir',
        'no_save_vis',
    }
    postprocess_kwargs: set = {
        'print_result',
        'pred_out_dir',
        'return_datasample',
        'no_save_pred',
    }

    def __init__(self,
                 model: Optional[Union[ModelType, str]] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmdet',
                 palette: str = 'none') -> None:
        # A global counter tracking the number of images processed, for
        # naming of the output images
        self.num_visualized_imgs = 0
        self.num_predicted_imgs = 0
        # ``self.palette`` is read by ``_load_weights_to_model``, so it is
        # assigned before the parent initializer runs.
        self.palette = palette
        init_default_scope(scope)
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)
        self.model = revert_sync_batchnorm(self.model)

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Loading model weights and meta information from cfg and checkpoint.

        Sets ``model.dataset_meta`` (class names and palette). Class names
        come from the checkpoint meta when present and fall back to the COCO
        classes otherwise.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """

        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
            checkpoint_meta = checkpoint.get('meta', {})
            # save the dataset_meta in the model for convenience
            if 'dataset_meta' in checkpoint_meta:
                # mmdet 3.x, all keys should be lowercase
                model.dataset_meta = {
                    k.lower(): v
                    for k, v in checkpoint_meta['dataset_meta'].items()
                }
            elif 'CLASSES' in checkpoint_meta:
                # < mmdet 3.x
                classes = checkpoint_meta['CLASSES']
                model.dataset_meta = {'classes': classes}
            else:
                warnings.warn(
                    'dataset_meta or class names are not saved in the '
                    'checkpoint\'s meta data, use COCO classes by default.')
                model.dataset_meta = {'classes': get_classes('coco')}
        else:
            warnings.warn('Checkpoint is not loaded, and the inference '
                          'result is calculated by the randomly initialized '
                          'model!')
            warnings.warn('weights is None, use COCO classes by default.')
            model.dataset_meta = {'classes': get_classes('coco')}

        # Priority: args.palette -> config -> checkpoint
        if self.palette != 'none':
            model.dataset_meta['palette'] = self.palette
        else:
            test_dataset_cfg = copy.deepcopy(cfg.test_dataloader.dataset)
            # lazy init. We only need the metainfo.
            test_dataset_cfg['lazy_init'] = True
            metainfo = DATASETS.build(test_dataset_cfg).metainfo
            cfg_palette = metainfo.get('palette', None)
            if cfg_palette is not None:
                model.dataset_meta['palette'] = cfg_palette
            else:
                if 'palette' not in model.dataset_meta:
                    warnings.warn(
                        'palette does not exist, random is used by default. '
                        'You can also set the palette to customize.')
                    model.dataset_meta['palette'] = 'random'

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        Note that the pipeline config taken from ``cfg`` is edited in place:
        ``img_id`` is removed from the last transform's ``meta_keys`` and the
        ``LoadImageFromFile`` step is replaced by ``mmdet.InferencerLoader``.

        Raises:
            ValueError: If the pipeline has no ``LoadImageFromFile`` step.
        """
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # For inference, the key of ``img_id`` is not used.
        if 'meta_keys' in pipeline_cfg[-1]:
            pipeline_cfg[-1]['meta_keys'] = tuple(
                meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
                if meta_key != 'img_id')

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader'
        return Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Returns the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1

    def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
        """Initialize visualizers.

        Args:
            cfg (ConfigType): Config containing the visualizer information.

        Returns:
            Visualizer or None: Visualizer initialized with config.
        """
        visualizer = super()._init_visualizer(cfg)
        # The return type is Optional and ``visualize`` handles a missing
        # visualizer itself, so only attach the meta info when one exists
        # instead of failing here with an AttributeError.
        if visualizer is not None:
            visualizer.dataset_meta = self.model.dataset_meta
        return visualizer

    def _inputs_to_list(self, inputs: InputsType) -> list:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type:

        - list or tuple: return inputs
        - str:
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to file, a url or other types of string according
              to the task.

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        if isinstance(inputs, str):
            backend = get_file_backend(inputs)
            if hasattr(backend, 'isdir') and isdir(inputs):
                # Backends like HttpsBackend do not implement `isdir`, so only
                # those backends that implement `isdir` could accept the inputs
                # as a directory
                filename_list = list_dir_or_file(
                    inputs, list_dir=False, suffix=IMG_EXTENSIONS)
                inputs = [
                    join_path(inputs, filename) for filename in filename_list
                ]

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        return list(inputs)

    def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
        """Process the inputs into a model-feedable format.

        Customize your preprocess by overriding this method. Preprocess should
        return an iterable object, of which each item will be used as the
        input of ``model.test_step``.

        ``BaseInferencer.preprocess`` will return an iterable chunked data,
        which will be used in __call__ like this:

        .. code-block:: python

            def __call__(self, inputs, batch_size=1, **kwargs):
                chunked_data = self.preprocess(inputs, batch_size, **kwargs)
                for batch in chunked_data:
                    preds = self.forward(batch, **kwargs)

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """
        chunked_data = self._get_chunk_data(inputs, batch_size)
        yield from map(self.collate_fn, chunked_data)

    def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
        """Get batch data from inputs.

        Args:
            inputs (Iterable): An iterable dataset.
            chunk_size (int): Equivalent to batch size. Must be at least 1.

        Yields:
            list: batch data, a list of ``(raw_input, pipeline_output)``
            pairs. The last chunk may be shorter than ``chunk_size``.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        # With a non-positive chunk size the inner loop never consumes the
        # iterator, so the generator would yield empty chunks forever.
        if chunk_size < 1:
            raise ValueError(
                f'chunk_size must be a positive integer, got {chunk_size}')

        inputs_iter = iter(inputs)
        while True:
            try:
                chunk_data = []
                for _ in range(chunk_size):
                    inputs_ = next(inputs_iter)
                    chunk_data.append((inputs_, self.pipeline(inputs_)))
                yield chunk_data
            except StopIteration:
                # Flush the trailing partial chunk, if any, then stop.
                if chunk_data:
                    yield chunk_data
                break

    # TODO: Video and Webcam are currently not supported and
    #  may consume too much memory if your input folder has a lot of images.
    #  We will be optimized later.
    def __call__(self,
                 inputs: InputsType,
                 batch_size: int = 1,
                 return_vis: bool = False,
                 show: bool = False,
                 wait_time: int = 0,
                 no_save_vis: bool = False,
                 draw_pred: bool = True,
                 pred_score_thr: float = 0.3,
                 return_datasample: bool = False,
                 print_result: bool = False,
                 no_save_pred: bool = True,
                 out_dir: str = '',
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            batch_size (int): Inference batch size. Defaults to 1.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the visualization results in a
                popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            return_datasample (bool): Whether to return results as
                :obj:`DetDataSample`. Defaults to False.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            no_save_pred (bool): Whether to force not to save prediction
                results. Defaults to True.
            out_dir (str): Dir to save the inference results or
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
                and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(
            ori_inputs, batch_size=batch_size, **preprocess_kwargs)

        results_dict = {'predictions': [], 'visualization': []}
        # Each batch carries the raw inputs next to the pipeline outputs so
        # that the visualizer can draw on the original images.
        for ori_imgs, data in track(inputs, description='Inference'):
            preds = self.forward(data, **forward_kwargs)
            visualization = self.visualize(
                ori_imgs,
                preds,
                return_vis=return_vis,
                show=show,
                wait_time=wait_time,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                no_save_vis=no_save_vis,
                img_out_dir=out_dir,
                **visualize_kwargs)
            results = self.postprocess(
                preds,
                visualization,
                return_datasample=return_datasample,
                print_result=print_result,
                no_save_pred=no_save_pred,
                pred_out_dir=out_dir,
                **postprocess_kwargs)
            results_dict['predictions'].extend(results['predictions'])
            if results['visualization'] is not None:
                results_dict['visualization'].extend(results['visualization'])
        return results_dict

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  no_save_vis: bool = False,
                  img_out_dir: str = '',
                  **kwargs) -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[:obj:`DetDataSample`]): Predictions of the model.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Returns visualization results only if
            applicable.

        Raises:
            ValueError: If visualization is requested but no visualizer is
                available, or if an input is neither a str nor an ndarray.
        """
        if no_save_vis is True:
            img_out_dir = ''

        # Nothing to show, save or return: skip the drawing work entirely.
        if not show and img_out_dir == '' and not return_vis:
            return None

        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes)
                # Reverse the channel axis of the decoded image.
                img = img[:, :, ::-1]
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                img = single_input.copy()
                # Arrays have no file name, so name them by a running index.
                img_num = str(self.num_visualized_imgs).zfill(8)
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, 'vis',
                                img_name) if img_out_dir != '' else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                out_file=out_file,
            )
            results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results

    def postprocess(
        self,
        preds: PredType,
        visualization: Optional[List[np.ndarray]] = None,
        return_datasample: bool = False,
        print_result: bool = False,
        no_save_pred: bool = False,
        pred_out_dir: str = '',
        **kwargs,
    ) -> Dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Args:
            preds (List[:obj:`DetDataSample`]): Predictions of the model.
            visualization (Optional[np.ndarray]): Visualized predictions.
            return_datasample (bool): Whether to use Datasample to store
                inference results. If False, dict will be used.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            no_save_pred (bool): Whether to force not to save prediction
                results. Defaults to False.
            pred_out_dir: Dir to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``.

            - ``visualization`` (Any): Returned by :meth:`visualize`.
            - ``predictions`` (dict or DataSample): Returned by
                :meth:`forward` and processed in :meth:`postprocess`.
                If ``return_datasample=False``, it usually should be a
                json-serializable dict containing only basic data elements such
                as strings and numbers.
        """
        if no_save_pred is True:
            pred_out_dir = ''

        result_dict = {}
        results = preds
        if not return_datasample:
            results = []
            for pred in preds:
                result = self.pred2dict(pred, pred_out_dir)
                results.append(result)
        elif pred_out_dir != '':
            warnings.warn('Currently does not support saving datasample '
                          'when return_datasample is set to True. '
                          'Prediction results are not saved!')
        # Add img to the results after printing and dumping
        result_dict['predictions'] = results
        if print_result:
            print(result_dict)
        result_dict['visualization'] = visualization
        return result_dict

    # TODO: The data format and fields saved in json need further discussion.
    #  Maybe should include model name, timestamp, filename, image info etc.
    def pred2dict(self,
                  data_sample: DetDataSample,
                  pred_out_dir: str = '') -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        It's better to contain only basic data elements such as strings and
        numbers in order to guarantee it's json-serializable.

        Args:
            data_sample (:obj:`DetDataSample`): Predictions of the model.
            pred_out_dir: Dir to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Prediction results.

        Raises:
            RuntimeError: If the sample holds a panoptic prediction but
                ``panopticapi`` is not installed.
        """
        is_save_pred = True
        if pred_out_dir == '':
            is_save_pred = False

        if is_save_pred and 'img_path' in data_sample:
            img_path = osp.basename(data_sample.img_path)
            img_path = osp.splitext(img_path)[0]
            out_img_path = osp.join(pred_out_dir, 'preds',
                                    img_path + '_panoptic_seg.png')
            out_json_path = osp.join(pred_out_dir, 'preds', img_path + '.json')
        elif is_save_pred:
            # No source path available: name the outputs by a running index.
            out_img_path = osp.join(
                pred_out_dir, 'preds',
                f'{self.num_predicted_imgs}_panoptic_seg.png')
            out_json_path = osp.join(pred_out_dir, 'preds',
                                     f'{self.num_predicted_imgs}.json')
            self.num_predicted_imgs += 1

        result = {}
        if 'pred_instances' in data_sample:
            masks = data_sample.pred_instances.get('masks')
            pred_instances = data_sample.pred_instances.numpy()
            result = {
                'bboxes': pred_instances.bboxes.tolist(),
                'labels': pred_instances.labels.tolist(),
                'scores': pred_instances.scores.tolist()
            }
            if masks is not None:
                if pred_instances.bboxes.sum() == 0:
                    # Fake bbox, such as the SOLO.
                    bboxes = mask2bbox(masks.cpu()).numpy().tolist()
                    result['bboxes'] = bboxes
                encode_masks = encode_mask_results(pred_instances.masks)
                for encode_mask in encode_masks:
                    # ``bytes`` is not json-serializable, so decode to str.
                    if isinstance(encode_mask['counts'], bytes):
                        encode_mask['counts'] = encode_mask['counts'].decode()
                result['masks'] = encode_masks

        if 'pred_panoptic_seg' in data_sample:
            if VOID is None:
                raise RuntimeError(
                    'panopticapi is not installed, please install it by: '
                    'pip install git+https://github.com/cocodataset/'
                    'panopticapi.git.')

            pan = data_sample.pred_panoptic_seg.sem_seg.cpu().numpy()[0]
            # Entries whose value modulo INSTANCE_OFFSET equals the number
            # of classes are remapped to VOID.
            pan[pan % INSTANCE_OFFSET == len(
                self.model.dataset_meta['classes'])] = VOID
            pan = id2rgb(pan).astype(np.uint8)

            if is_save_pred:
                mmcv.imwrite(pan[:, :, ::-1], out_img_path)
                result['panoptic_seg_path'] = out_img_path
            else:
                result['panoptic_seg'] = pan

        if is_save_pred:
            mmengine.dump(result, out_json_path)

        return result
mmdet/apis/inference.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional, Sequence, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from mmcv.ops import RoIPool
11
+ from mmcv.transforms import Compose
12
+ from mmengine.config import Config
13
+ from mmengine.model.utils import revert_sync_batchnorm
14
+ from mmengine.registry import init_default_scope
15
+ from mmengine.runner import load_checkpoint
16
+
17
+ from mmdet.registry import DATASETS
18
+ from ..evaluation import get_classes
19
+ from ..registry import MODELS
20
+ from ..structures import DetDataSample, SampleList
21
+ from ..utils import get_test_pipeline_cfg
22
+
23
+
24
def init_detector(
    config: Union[str, Path, Config],
    checkpoint: Optional[str] = None,
    palette: str = 'none',
    device: str = 'cuda:0',
    cfg_options: Optional[dict] = None,
) -> nn.Module:
    """Build a detector from a config and optionally load its weights.

    The returned model is moved to ``device``, switched to eval mode, and
    carries two convenience attributes: ``cfg`` (the config used) and
    ``dataset_meta`` (class names and palette).

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights and COCO class names are used.
        palette (str): Color palette used for visualization. Any value other
            than 'none' is used directly. With 'none', the palette of the
            test dataset in the config is used, then the one stored in the
            checkpoint meta, and finally 'random'. Defaults to 'none'.
        device (str): The device the model is moved to.
            Defaults to cuda:0.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config`` object.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')

    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    elif 'init_cfg' in config.model.backbone:
        # Only reached when no overrides were given: clear the backbone's
        # ``init_cfg`` entry.
        config.model.backbone.init_cfg = None
    init_default_scope(config.get('default_scope', 'mmdet'))

    model = revert_sync_batchnorm(MODELS.build(config.model))

    # Resolve the class names: checkpoint meta first, COCO as the fallback.
    if checkpoint is None:
        warnings.simplefilter('once')
        warnings.warn('checkpoint is None, use COCO classes by default.')
        model.dataset_meta = {'classes': get_classes('coco')}
    else:
        loaded = load_checkpoint(model, checkpoint, map_location='cpu')
        # Weights converted from elsewhere may not have meta fields.
        meta = loaded.get('meta', {})

        if 'dataset_meta' in meta:
            # mmdet 3.x, all keys should be lowercase
            model.dataset_meta = {
                key.lower(): value
                for key, value in meta['dataset_meta'].items()
            }
        elif 'CLASSES' in meta:
            # < mmdet 3.x
            model.dataset_meta = {'classes': meta['CLASSES']}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use COCO classes by default.')
            model.dataset_meta = {'classes': get_classes('coco')}

    # Priority: args.palette -> config -> checkpoint
    if palette != 'none':
        model.dataset_meta['palette'] = palette
    else:
        dataset_cfg = copy.deepcopy(config.test_dataloader.dataset)
        # lazy init. We only need the metainfo.
        dataset_cfg['lazy_init'] = True
        cfg_palette = DATASETS.build(dataset_cfg).metainfo.get(
            'palette', None)
        if cfg_palette is not None:
            model.dataset_meta['palette'] = cfg_palette
        elif 'palette' not in model.dataset_meta:
            warnings.warn(
                'palette does not exist, random is used by default. '
                'You can also set the palette to customize.')
            model.dataset_meta['palette'] = 'random'

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
112
+
113
+
114
+ ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
115
+
116
+
117
def inference_detector(
    model: nn.Module,
    imgs: ImagesType,
    test_pipeline: Optional[Compose] = None
) -> Union[DetDataSample, SampleList]:
    """Inference image(s) with the detector.

    Images are run through the model one at a time (one ``test_step`` call
    per image) with gradients disabled.

    Args:
        model (nn.Module): The loaded detector. It must expose ``cfg`` and
            ``data_preprocessor``.
        imgs (str, ndarray, Sequence[str/ndarray]):
           Either image files or loaded images.
        test_pipeline (:obj:`Compose`): Test pipeline. If None, it is built
            from ``model.cfg``.

    Returns:
        :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    is_batch = isinstance(imgs, (list, tuple))
    if not is_batch:
        imgs = [imgs]
    elif not imgs:
        # An empty batch has nothing to run. Returning early also avoids
        # the ``imgs[0]`` lookup below, which would raise IndexError.
        return []

    cfg = model.cfg

    if test_pipeline is None:
        cfg = cfg.copy()
        test_pipeline = get_test_pipeline_cfg(cfg)
        if isinstance(imgs[0], np.ndarray):
            # Calling this method across libraries will result
            # in module unregistered error if not prefixed with mmdet.
            test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'

        test_pipeline = Compose(test_pipeline)

    if model.data_preprocessor.device.type == 'cpu':
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    result_list = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # TODO: remove img_id.
            data_ = dict(img=img, img_id=0)
        else:
            # TODO: remove img_id.
            data_ = dict(img_path=img, img_id=0)
        # build the data pipeline
        data_ = test_pipeline(data_)

        # ``test_step`` is fed a batch, so wrap the single sample in lists.
        data_['inputs'] = [data_['inputs']]
        data_['data_samples'] = [data_['data_samples']]

        # forward the model
        with torch.no_grad():
            results = model.test_step(data_)[0]

        result_list.append(results)

    return result_list if is_batch else result_list[0]
185
+
186
+
187
+ # TODO: Awaiting refactoring
188
async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    NOTE(review): this coroutine reads the pipeline from
    ``cfg.data.test.pipeline`` and calls ``model.aforward_test``, whereas the
    synchronous ``inference_detector`` uses ``get_test_pipeline_cfg`` and
    ``model.test_step``. It looks like it still targets an older API and is
    awaiting refactoring -- confirm before relying on it.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str | ndarray | Sequence[str | ndarray]): Either image files or
            loaded images.

    Returns:
        Awaitable detection results.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    cfg = model.cfg
    if isinstance(imgs[0], np.ndarray):
        # Loaded arrays need the ndarray loader as the first pipeline step.
        cfg = cfg.copy()
        cfg.data.test.pipeline[0].type = 'LoadImageFromNDArray'

    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        if isinstance(img, np.ndarray):
            # A loaded image is passed through directly.
            sample = dict(img=img)
        else:
            # A file name is wrapped in the info dict the pipeline reads.
            sample = dict(img_info=dict(filename=img), img_prefix=None)
        data = test_pipeline(sample)
        datas.append(data)

    # The check runs regardless of the device the model is on.
    assert not any(isinstance(m, RoIPool) for m in model.modules()), \
        'CPU inference with RoIPool is not supported currently.'

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    # NOTE(review): only ``data`` (the last processed sample) is forwarded;
    # the accumulated ``datas`` list is never used. Presumably a collate
    # step was lost -- verify the intended input of ``aforward_test``.
    return await model.aforward_test(data, rescale=True)
mmdet/datasets/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .base_det_dataset import BaseDetDataset
3
+ from .cityscapes import CityscapesDataset
4
+ from .coco import CocoDataset
5
+ from .coco_panoptic import CocoPanopticDataset
6
+ from .crowdhuman import CrowdHumanDataset
7
+ from .dataset_wrappers import MultiImageMixDataset
8
+ from .deepfashion import DeepFashionDataset
9
+ from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
10
+ from .objects365 import Objects365V1Dataset, Objects365V2Dataset
11
+ from .openimages import OpenImagesChallengeDataset, OpenImagesDataset
12
+ from .samplers import (AspectRatioBatchSampler, ClassAwareSampler,
13
+ GroupMultiSourceSampler, MultiSourceSampler)
14
+ from .utils import get_loading_pipeline
15
+ from .voc import VOCDataset
16
+ from .wider_face import WIDERFaceDataset
17
+ from .xml_style import XMLDataset
18
+
19
+ __all__ = [
20
+ 'XMLDataset', 'CocoDataset', 'DeepFashionDataset', 'VOCDataset',
21
+ 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset', 'LVISV1Dataset',
22
+ 'WIDERFaceDataset', 'get_loading_pipeline', 'CocoPanopticDataset',
23
+ 'MultiImageMixDataset', 'OpenImagesDataset', 'OpenImagesChallengeDataset',
24
+ 'AspectRatioBatchSampler', 'ClassAwareSampler', 'MultiSourceSampler',
25
+ 'GroupMultiSourceSampler', 'BaseDetDataset', 'CrowdHumanDataset',
26
+ 'Objects365V1Dataset', 'Objects365V2Dataset'
27
+ ]
mmdet/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.25 kB). View file
 
mmdet/datasets/__pycache__/base_det_dataset.cpython-310.pyc ADDED
Binary file (4.52 kB). View file
 
mmdet/datasets/__pycache__/cityscapes.cpython-310.pyc ADDED
Binary file (1.95 kB). View file
 
mmdet/datasets/__pycache__/coco.cpython-310.pyc ADDED
Binary file (6.23 kB). View file
 
mmdet/datasets/__pycache__/coco_panoptic.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
mmdet/datasets/__pycache__/crowdhuman.cpython-310.pyc ADDED
Binary file (4.45 kB). View file
 
mmdet/datasets/__pycache__/dataset_wrappers.cpython-310.pyc ADDED
Binary file (5.72 kB). View file
 
mmdet/datasets/__pycache__/deepfashion.cpython-310.pyc ADDED
Binary file (903 Bytes). View file
 
mmdet/datasets/__pycache__/lvis.cpython-310.pyc ADDED
Binary file (23.5 kB). View file
 
mmdet/datasets/__pycache__/objects365.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
mmdet/datasets/__pycache__/openimages.cpython-310.pyc ADDED
Binary file (14.2 kB). View file
 
mmdet/datasets/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.81 kB). View file
 
mmdet/datasets/__pycache__/voc.cpython-310.pyc ADDED
Binary file (1.34 kB). View file
 
mmdet/datasets/__pycache__/wider_face.cpython-310.pyc ADDED
Binary file (2.98 kB). View file
 
mmdet/datasets/__pycache__/xml_style.cpython-310.pyc ADDED
Binary file (5.68 kB). View file
 
mmdet/datasets/api_wrappers/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .coco_api import COCO, COCOeval, COCOPanoptic
3
+
4
+ __all__ = ['COCO', 'COCOeval', 'COCOPanoptic']
mmdet/datasets/api_wrappers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (263 Bytes). View file
 
mmdet/datasets/api_wrappers/__pycache__/coco_api.cpython-310.pyc ADDED
Binary file (4.5 kB). View file
 
mmdet/datasets/api_wrappers/coco_api.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ # This file adds snake case aliases for the coco api
3
+
4
+ import warnings
5
+ from collections import defaultdict
6
+ from typing import List, Optional, Union
7
+
8
+ import pycocotools
9
+ from pycocotools.coco import COCO as _COCO
10
+ from pycocotools.cocoeval import COCOeval as _COCOeval
11
+
12
+
13
class COCO(_COCO):
    """This class is almost the same as official pycocotools package.

    It implements some snake case function aliases. So that the COCO class has
    the same interface as LVIS class.

    Args:
        annotation_file (str, optional): Path of annotation file, forwarded
            to the pycocotools ``COCO`` constructor. Defaults to None.
    """

    def __init__(self, annotation_file=None):
        # Compare versions numerically. A plain string comparison is
        # lexicographic, so e.g. '2.0.7' >= '12.0.2' evaluates to True and
        # would trigger the deprecation warning for the wrong package.
        installed = self._version_tuple(
            getattr(pycocotools, '__version__', '0'))
        if installed >= (12, 0, 2):
            warnings.warn(
                'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"',  # noqa: E501
                UserWarning)
        super().__init__(annotation_file=annotation_file)
        # snake case aliases of the index attributes
        self.img_ann_map = self.imgToAnns
        self.cat_img_map = self.catToImgs

    @staticmethod
    def _version_tuple(version):
        """Convert a dotted version string into a tuple of ints.

        Parsing stops at the first component that does not start with a
        digit, and trailing non-digit characters of a component are dropped
        (``'12.0.2rc1'`` -> ``(12, 0, 2)``).

        Args:
            version (str): Version string such as ``'12.0.2'``.

        Returns:
            tuple[int, ...]: Numeric components, possibly empty.
        """
        numbers = []
        for piece in str(version).split('.'):
            digits = ''
            for char in piece:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
        return tuple(numbers)

    def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
        """Snake case alias of ``getAnnIds``."""
        return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)

    def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
        """Snake case alias of ``getCatIds``."""
        return self.getCatIds(cat_names, sup_names, cat_ids)

    def get_img_ids(self, img_ids=[], cat_ids=[]):
        """Snake case alias of ``getImgIds``."""
        return self.getImgIds(img_ids, cat_ids)

    def load_anns(self, ids):
        """Snake case alias of ``loadAnns``."""
        return self.loadAnns(ids)

    def load_cats(self, ids):
        """Snake case alias of ``loadCats``."""
        return self.loadCats(ids)

    def load_imgs(self, ids):
        """Snake case alias of ``loadImgs``."""
        return self.loadImgs(ids)
46
+
47
+
48
+ # just for the ease of import
49
+ COCOeval = _COCOeval
50
+
51
+
52
class COCOPanoptic(COCO):
    """This wrapper is for loading the panoptic style annotation file.

    The format is shown in the CocoPanopticDataset class.

    Args:
        annotation_file (str, optional): Path of annotation file.
            Defaults to None.
    """

    def __init__(self, annotation_file: Optional[str] = None) -> None:
        super(COCOPanoptic, self).__init__(annotation_file)

    def createIndex(self) -> None:
        """Create index.

        Builds ``self.anns`` (segment id -> list of segment annotations),
        ``self.imgToAnns``, ``self.catToImgs``, ``self.imgs`` and
        ``self.cats`` from ``self.dataset``. Each segment annotation gets an
        ``image_id`` key and each image info gets a ``segm_file`` key.
        """
        # create index
        print('creating index...')
        # anns stores 'segment_id -> annotation'
        anns, cats, imgs = {}, {}, {}
        img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
        if 'annotations' in self.dataset:
            for ann in self.dataset['annotations']:
                for seg_ann in ann['segments_info']:
                    # to match with instance.json
                    seg_ann['image_id'] = ann['image_id']
                    img_to_anns[ann['image_id']].append(seg_ann)
                    # segment_id is not unique in coco dataset:
                    # annotations from different images may share the
                    # same segment_id, so every id maps to a list.
                    anns.setdefault(seg_ann['id'], []).append(seg_ann)

            # filter out annotations from other images
            img_to_anns_ = defaultdict(list)
            for k, v in img_to_anns.items():
                img_to_anns_[k] = [x for x in v if x['image_id'] == k]
            img_to_anns = img_to_anns_

        if 'images' in self.dataset:
            for img_info in self.dataset['images']:
                file_name = img_info['file_name']
                # Only swap the trailing extension. Replacing every 'jpg'
                # substring would also rewrite directory or stem parts
                # that happen to contain it.
                if file_name.endswith('.jpg'):
                    segm_file = file_name[:-len('.jpg')] + '.png'
                else:
                    segm_file = file_name
                img_info['segm_file'] = segm_file
                imgs[img_info['id']] = img_info

        if 'categories' in self.dataset:
            for cat in self.dataset['categories']:
                cats[cat['id']] = cat

        if 'annotations' in self.dataset and 'categories' in self.dataset:
            for ann in self.dataset['annotations']:
                for seg_ann in ann['segments_info']:
                    cat_to_imgs[seg_ann['category_id']].append(ann['image_id'])

        print('index created!')

        self.anns = anns
        self.imgToAnns = img_to_anns
        self.catToImgs = cat_to_imgs
        self.imgs = imgs
        self.cats = cats

    def load_anns(self,
                  ids: Union[List[int], int] = []) -> Optional[List[dict]]:
        """Load anns with the specified ids.

        ``self.anns`` maps each segment id to a list of annotations instead
        of a single annotation, so the results for several ids are
        concatenated.

        Args:
            ids (Union[List[int], int]): Integer ids specifying anns.

        Returns:
            anns (List[dict], optional): Loaded ann objects. None when
            ``ids`` is neither a sized iterable nor an int.
        """
        if hasattr(ids, '__iter__') and hasattr(ids, '__len__'):
            anns = []
            for ann_id in ids:
                anns += self.anns[ann_id]
            return anns
        if isinstance(ids, int):
            return self.anns[ids]
        return None
mmdet/datasets/base_det_dataset.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ from typing import List, Optional
4
+
5
+ from mmengine.dataset import BaseDataset
6
+ from mmengine.fileio import load
7
+ from mmengine.utils import is_abs
8
+
9
+ from ..registry import DATASETS
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class BaseDetDataset(BaseDataset):
14
+ """Base dataset for detection.
15
+
16
+ Args:
17
+ proposal_file (str, optional): Proposals file path. Defaults to None.
18
+ file_client_args (dict): Arguments to instantiate the
19
+ corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
20
+ backend_args (dict, optional): Arguments to instantiate the
21
+ corresponding backend. Defaults to None.
22
+ """
23
+
24
+ def __init__(self,
25
+ *args,
26
+ seg_map_suffix: str = '.png',
27
+ proposal_file: Optional[str] = None,
28
+ file_client_args: dict = None,
29
+ backend_args: dict = None,
30
+ **kwargs) -> None:
31
+ self.seg_map_suffix = seg_map_suffix
32
+ self.proposal_file = proposal_file
33
+ self.backend_args = backend_args
34
+ if file_client_args is not None:
35
+ raise RuntimeError(
36
+ 'The `file_client_args` is deprecated, '
37
+ 'please use `backend_args` instead, please refer to'
38
+ 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
39
+ )
40
+ super().__init__(*args, **kwargs)
41
+
42
+ def full_init(self) -> None:
43
+ """Load annotation file and set ``BaseDataset._fully_initialized`` to
44
+ True.
45
+
46
+ If ``lazy_init=False``, ``full_init`` will be called during the
47
+ instantiation and ``self._fully_initialized`` will be set to True. If
48
+ ``obj._fully_initialized=False``, the class method decorated by
49
+ ``force_full_init`` will call ``full_init`` automatically.
50
+
51
+ Several steps to initialize annotation:
52
+
53
+ - load_data_list: Load annotations from annotation file.
54
+ - load_proposals: Load proposals from proposal file, if
55
+ `self.proposal_file` is not None.
56
+ - filter data information: Filter annotations according to
57
+ filter_cfg.
58
+ - slice_data: Slice dataset according to ``self._indices``
59
+ - serialize_data: Serialize ``self.data_list`` if
60
+ ``self.serialize_data`` is True.
61
+ """
62
+ if self._fully_initialized:
63
+ return
64
+ # load data information
65
+ self.data_list = self.load_data_list()
66
+ # get proposals from file
67
+ if self.proposal_file is not None:
68
+ self.load_proposals()
69
+ # filter illegal data, such as data that has no annotations.
70
+ self.data_list = self.filter_data()
71
+
72
+ # Get subset data according to indices.
73
+ if self._indices is not None:
74
+ self.data_list = self._get_unserialized_subset(self._indices)
75
+
76
+ # serialize data_list
77
+ if self.serialize_data:
78
+ self.data_bytes, self.data_address = self._serialize_data()
79
+
80
+ self._fully_initialized = True
81
+
82
+ def load_proposals(self) -> None:
83
+ """Load proposals from proposals file.
84
+
85
+ The `proposals_list` should be a dict[img_path: proposals]
86
+ with the same length as `data_list`. And the `proposals` should be
87
+ a `dict` or :obj:`InstanceData` usually contains following keys.
88
+
89
+ - bboxes (np.ndarry): Has a shape (num_instances, 4),
90
+ the last dimension 4 arrange as (x1, y1, x2, y2).
91
+ - scores (np.ndarry): Classification scores, has a shape
92
+ (num_instance, ).
93
+ """
94
+ # TODO: Add Unit Test after fully support Dump-Proposal Metric
95
+ if not is_abs(self.proposal_file):
96
+ self.proposal_file = osp.join(self.data_root, self.proposal_file)
97
+ proposals_list = load(
98
+ self.proposal_file, backend_args=self.backend_args)
99
+ assert len(self.data_list) == len(proposals_list)
100
+ for data_info in self.data_list:
101
+ img_path = data_info['img_path']
102
+ # `file_name` is the key to obtain the proposals from the
103
+ # `proposals_list`.
104
+ file_name = osp.join(
105
+ osp.split(osp.split(img_path)[0])[-1],
106
+ osp.split(img_path)[-1])
107
+ proposals = proposals_list[file_name]
108
+ data_info['proposals'] = proposals
109
+
110
+ def get_cat_ids(self, idx: int) -> List[int]:
111
+ """Get COCO category ids by index.
112
+
113
+ Args:
114
+ idx (int): Index of data.
115
+
116
+ Returns:
117
+ List[int]: All categories in the image of specified index.
118
+ """
119
+ instances = self.get_data_info(idx)['instances']
120
+ return [instance['bbox_label'] for instance in instances]
mmdet/datasets/cityscapes.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ # Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/cityscapes.py # noqa
3
+ # and https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
4
+
5
+ from typing import List
6
+
7
+ from mmdet.registry import DATASETS
8
+ from .coco import CocoDataset
9
+
10
+
11
+ @DATASETS.register_module()
12
+ class CityscapesDataset(CocoDataset):
13
+ """Dataset for Cityscapes."""
14
+
15
+ METAINFO = {
16
+ 'classes': ('person', 'rider', 'car', 'truck', 'bus', 'train',
17
+ 'motorcycle', 'bicycle'),
18
+ 'palette': [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
19
+ (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)]
20
+ }
21
+
22
+ def filter_data(self) -> List[dict]:
23
+ """Filter annotations according to filter_cfg.
24
+
25
+ Returns:
26
+ List[dict]: Filtered results.
27
+ """
28
+ if self.test_mode:
29
+ return self.data_list
30
+
31
+ if self.filter_cfg is None:
32
+ return self.data_list
33
+
34
+ filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
35
+ min_size = self.filter_cfg.get('min_size', 0)
36
+
37
+ # obtain images that contain annotation
38
+ ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
39
+ # obtain images that contain annotations of the required categories
40
+ ids_in_cat = set()
41
+ for i, class_id in enumerate(self.cat_ids):
42
+ ids_in_cat |= set(self.cat_img_map[class_id])
43
+ # merge the image id sets of the two conditions and use the merged set
44
+ # to filter out images if self.filter_empty_gt=True
45
+ ids_in_cat &= ids_with_ann
46
+
47
+ valid_data_infos = []
48
+ for i, data_info in enumerate(self.data_list):
49
+ img_id = data_info['img_id']
50
+ width = data_info['width']
51
+ height = data_info['height']
52
+ all_is_crowd = all([
53
+ instance['ignore_flag'] == 1
54
+ for instance in data_info['instances']
55
+ ])
56
+ if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
57
+ continue
58
+ if min(width, height) >= min_size:
59
+ valid_data_infos.append(data_info)
60
+
61
+ return valid_data_infos
mmdet/datasets/coco.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import os.path as osp
4
+ from typing import List, Union
5
+
6
+ from mmengine.fileio import get_local_path
7
+
8
+ from mmdet.registry import DATASETS
9
+ from .api_wrappers import COCO
10
+ from .base_det_dataset import BaseDetDataset
11
+
12
+
13
@DATASETS.register_module()
class CocoDataset(BaseDetDataset):
    """Dataset for COCO."""

    METAINFO = {
        'classes':
        ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
         'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
         'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
         'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
         'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
         'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
         'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
         'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
         'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
         'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
         'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
         'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
         'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
        # palette is a list of color tuples, which is used for visualization.
        'palette':
        [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
         (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
         (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
         (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
         (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
         (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
         (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
         (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
         (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
         (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
         (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
         (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
         (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
         (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
         (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
         (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
         (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
         (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
         (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
         (246, 0, 122), (191, 162, 208)]
    }
    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``

        Besides building the data list, this sets ``self.cat_ids``,
        ``self.cat2label`` and ``self.cat_img_map``. The temporary
        ``self.coco`` API object is deleted before returning.

        Returns:
            List[dict]: A list of annotation.
        """  # noqa: E501
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            self.coco = self.COCOAPI(local_path)
        # The order of returned `cat_ids` will not
        # change with the order of the `classes`
        self.cat_ids = self.coco.get_cat_ids(
            cat_names=self.metainfo['classes'])
        self.cat2label = {
            cat_id: label
            for label, cat_id in enumerate(self.cat_ids)
        }
        self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)

        data_list = []
        total_ann_ids = []
        for img_id in self.coco.get_img_ids():
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            raw_data_info = {
                'raw_ann_info': raw_ann_info,
                'raw_img_info': raw_img_info
            }
            data_list.append(self.parse_data_info(raw_data_info))

        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"

        del self.coco

        return data_list

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Annotations are skipped when they are marked ``ignore``, do not
        overlap the image, have non-positive area or a side shorter than one
        pixel, or belong to a category outside ``self.cat_ids``.

        Args:
            raw_data_info (dict): Raw data information load from ``ann_file``

        Returns:
            Union[dict, List[dict]]: Parsed annotation.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']

        # TODO: need to change data_prefix['img'] to data_prefix['img_path']
        img_path = osp.join(self.data_prefix['img_path'], img_info['file_name'])
        seg_prefix = self.data_prefix.get('seg_path', None)
        seg_map_path = None
        if seg_prefix:
            # same stem as the image, with the configured seg map suffix
            seg_map_path = osp.join(
                seg_prefix,
                img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)

        data_info = {
            'img_path': img_path,
            'img_id': img_info['img_id'],
            'seg_map_path': seg_map_path,
            'height': img_info['height'],
            'width': img_info['width'],
        }

        instances = []
        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # size of the box part that lies inside the image
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue

            instance = {
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                # convert (x, y, w, h) to (x1, y1, x2, y2)
                'bbox': [x1, y1, x1 + w, y1 + h],
                'bbox_label': self.cat2label[ann['category_id']],
            }
            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']

            instances.append(instance)

        data_info['instances'] = instances
        return data_info

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        Nothing is filtered in test mode or when ``filter_cfg`` is None.
        Otherwise a sample is kept only if its shorter side is at least
        ``min_size`` and, when ``filter_empty_gt`` is enabled, it belongs to
        one of the wanted categories.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode or self.filter_cfg is None:
            return self.data_list

        drop_empty = self.filter_cfg.get('filter_empty_gt', False)
        shortest_allowed = self.filter_cfg.get('min_size', 0)

        # image ids present in the loaded data list
        annotated_ids = {info['img_id'] for info in self.data_list}
        # image ids that contain annotations of the required categories
        wanted_ids = set()
        for class_id in self.cat_ids:
            wanted_ids.update(self.cat_img_map[class_id])
        # keep only ids satisfying both conditions; used when
        # filter_empty_gt is enabled
        wanted_ids &= annotated_ids

        kept = []
        for info in self.data_list:
            img_id = info['img_id']
            width = info['width']
            height = info['height']
            rejected_as_empty = drop_empty and img_id not in wanted_ids
            if not rejected_as_empty and min(width,
                                             height) >= shortest_allowed:
                kept.append(info)

        return kept
mmdet/datasets/coco_panoptic.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ from typing import Callable, List, Optional, Sequence, Union
4
+
5
+ from mmdet.registry import DATASETS
6
+ from .api_wrappers import COCOPanoptic
7
+ from .coco import CocoDataset
8
+
9
+
10
+ @DATASETS.register_module()
11
+ class CocoPanopticDataset(CocoDataset):
12
+ """Coco dataset for Panoptic segmentation.
13
+
14
+ The annotation format is shown as follows. The `ann` field is optional
15
+ for testing.
16
+
17
+ .. code-block:: none
18
+
19
+ [
20
+ {
21
+ 'filename': f'{image_id:012}.png',
22
+ 'image_id':9
23
+ 'segments_info':
24
+ [
25
+ {
26
+ 'id': 8345037, (segment_id in panoptic png,
27
+ convert from rgb)
28
+ 'category_id': 51,
29
+ 'iscrowd': 0,
30
+ 'bbox': (x1, y1, w, h),
31
+ 'area': 24315
32
+ },
33
+ ...
34
+ ]
35
+ },
36
+ ...
37
+ ]
38
+
39
+ Args:
40
+ ann_file (str): Annotation file path. Defaults to ''.
41
+ metainfo (dict, optional): Meta information for dataset, such as class
42
+ information. Defaults to None.
43
+ data_root (str, optional): The root directory for ``data_prefix`` and
44
+ ``ann_file``. Defaults to None.
45
+ data_prefix (dict, optional): Prefix for training data. Defaults to
46
+ ``dict(img=None, ann=None, seg=None)``. The prefix ``seg`` which is
47
+ for panoptic segmentation map must be not None.
48
+ filter_cfg (dict, optional): Config for filter data. Defaults to None.
49
+ indices (int or Sequence[int], optional): Support using first few
50
+ data in annotation file to facilitate training/testing on a smaller
51
+ dataset. Defaults to None which means using all ``data_infos``.
52
+ serialize_data (bool, optional): Whether to hold memory using
53
+ serialized objects, when enabled, data loader workers can use
54
+ shared RAM from master process instead of making a copy. Defaults
55
+ to True.
56
+ pipeline (list, optional): Processing pipeline. Defaults to [].
57
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
58
+ Defaults to False.
59
+ lazy_init (bool, optional): Whether to load annotation during
60
+ instantiation. In some cases, such as visualization, only the meta
61
+ information of the dataset is needed, which is not necessary to
62
+ load annotation file. ``Basedataset`` can skip load annotations to
63
+ save time by set ``lazy_init=False``. Defaults to False.
64
+ max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
65
+ None img. The maximum extra number of cycles to get a valid
66
+ image. Defaults to 1000.
67
+ """
68
+
69
+ METAINFO = {
70
+ 'classes':
71
+ ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
72
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
73
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
74
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
75
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
76
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
77
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
78
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
79
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
80
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
81
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
82
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
83
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
84
+ 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff',
85
+ 'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light',
86
+ 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
87
+ 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
88
+ 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
89
+ 'wall-wood', 'water-other', 'window-blind', 'window-other',
90
+ 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
91
+ 'cabinet-merged', 'table-merged', 'floor-other-merged',
92
+ 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
93
+ 'paper-merged', 'food-other-merged', 'building-other-merged',
94
+ 'rock-merged', 'wall-other-merged', 'rug-merged'),
95
+ 'thing_classes':
96
+ ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
97
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
98
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
99
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
100
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
101
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
102
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
103
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
104
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
105
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
106
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
107
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
108
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
109
+ 'stuff_classes':
110
+ ('banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain',
111
+ 'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house',
112
+ 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
113
+ 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
114
+ 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
115
+ 'wall-wood', 'water-other', 'window-blind', 'window-other',
116
+ 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
117
+ 'cabinet-merged', 'table-merged', 'floor-other-merged',
118
+ 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
119
+ 'paper-merged', 'food-other-merged', 'building-other-merged',
120
+ 'rock-merged', 'wall-other-merged', 'rug-merged'),
121
+ 'palette':
122
+ [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
123
+ (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
124
+ (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
125
+ (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
126
+ (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
127
+ (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
128
+ (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
129
+ (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
130
+ (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
131
+ (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
132
+ (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
133
+ (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
134
+ (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
135
+ (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
136
+ (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
137
+ (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
138
+ (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
139
+ (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
140
+ (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
141
+ (246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203),
142
+ (150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100),
143
+ (92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255),
144
+ (124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0),
145
+ (193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176),
146
+ (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
147
+ (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
148
+ (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
149
+ (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
150
+ (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
151
+ (146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140),
152
+ (96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152),
153
+ (208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0),
154
+ (0, 114, 143), (102, 102, 156), (250, 141, 255)]
155
+ }
156
+ COCOAPI = COCOPanoptic
157
+ # ann_id is not unique in coco panoptic dataset.
158
+ ANN_ID_UNIQUE = False
159
+
160
def __init__(self,
             ann_file: str = '',
             metainfo: Optional[dict] = None,
             data_root: Optional[str] = None,
             data_prefix: dict = dict(img=None, ann=None, seg=None),
             filter_cfg: Optional[dict] = None,
             indices: Optional[Union[int, Sequence[int]]] = None,
             serialize_data: bool = True,
             pipeline: List[Union[dict, Callable]] = [],
             test_mode: bool = False,
             lazy_init: bool = False,
             max_refetch: int = 1000,
             backend_args: dict = None,
             **kwargs) -> None:
    """Initialize the dataset by delegating to the parent constructor.

    Every named argument is handed to the parent class as a keyword
    argument without modification, followed by any extra ``kwargs``.
    The default ``data_prefix`` declares ``img``, ``ann`` and ``seg``
    entries.
    """
    # Collect the explicitly named arguments first so the forwarding call
    # stays a single line; ``kwargs`` can never collide with these keys
    # because each of them is a named parameter of this method.
    forwarded = dict(
        ann_file=ann_file,
        metainfo=metainfo,
        data_root=data_root,
        data_prefix=data_prefix,
        filter_cfg=filter_cfg,
        indices=indices,
        serialize_data=serialize_data,
        pipeline=pipeline,
        test_mode=test_mode,
        lazy_init=lazy_init,
        max_refetch=max_refetch,
        backend_args=backend_args)
    super().__init__(**forwarded, **kwargs)
188
+
189
def parse_data_info(self, raw_data_info: dict) -> dict:
    """Parse raw annotation to target format.

    Args:
        raw_data_info (dict): Raw data information loaded from
            ``ann_file``. It must provide ``raw_img_info`` (with the keys
            ``img_id``, ``file_name``, ``height`` and ``width``) and
            ``raw_ann_info`` (a list of segment annotations with the keys
            ``image_id``, ``id``, ``bbox``, ``area``, ``category_id`` and
            optionally ``iscrowd``).

    Returns:
        dict: Parsed annotation with the keys ``img_path``, ``img_id``,
        ``seg_map_path``, ``height``, ``width``, ``instances`` and
        ``segments_info``.
    """
    img_info = raw_data_info['raw_img_info']
    # filter out unmatched annotations which have
    # same segment_id but belong to other image
    ann_info = [
        ann for ann in raw_data_info['raw_ann_info']
        if ann['image_id'] == img_info['img_id']
    ]

    file_name = img_info['file_name']
    img_path = osp.join(self.data_prefix['img'], file_name)

    seg_prefix = self.data_prefix.get('seg', None)
    if seg_prefix:
        # Swap only the file extension. A plain ``replace('jpg', 'png')``
        # would also rewrite a "jpg" that is part of a directory name or
        # of the file stem and produce a path that does not exist.
        stem, ext = osp.splitext(file_name)
        seg_file_name = f'{stem}.png' if ext == '.jpg' else file_name
        seg_map_path = osp.join(seg_prefix, seg_file_name)
    else:
        seg_map_path = None

    data_info = {}
    data_info['img_path'] = img_path
    data_info['img_id'] = img_info['img_id']
    data_info['seg_map_path'] = seg_map_path
    data_info['height'] = img_info['height']
    data_info['width'] = img_info['width']

    instances = []
    segments_info = []
    for ann in ann_info:
        x1, y1, w, h = ann['bbox']
        # Skip degenerate segments entirely: they contribute neither an
        # instance nor a segment record.
        if ann['area'] <= 0 or w < 1 or h < 1:
            continue
        category_id = ann['category_id']
        contiguous_cat_id = self.cat2label[category_id]

        is_thing = self.coco.load_cats(ids=category_id)[0]['isthing']
        if is_thing and ann.get('iscrowd', False):
            # Crowd regions of a thing category are recorded as a
            # non-thing segment and never become a box instance.
            is_thing = False

        if is_thing:
            instances.append({
                'bbox': [x1, y1, x1 + w, y1 + h],
                'bbox_label': contiguous_cat_id,
                'ignore_flag': 0
            })

        segments_info.append({
            'id': ann['id'],
            'category': contiguous_cat_id,
            'is_thing': is_thing
        })
    data_info['instances'] = instances
    data_info['segments_info'] = segments_info
    return data_info
253
+
254
def filter_data(self) -> List[dict]:
    """Drop images that are too small or carry no thing annotation.

    Filtering is skipped in test mode and when no ``filter_cfg`` is set;
    in both cases ``self.data_list`` is returned as is.

    Returns:
        List[dict]: The entries of ``self.data_list`` that pass the
        ``filter_empty_gt`` and ``min_size`` rules of ``filter_cfg``.
    """
    if self.test_mode or self.filter_cfg is None:
        return self.data_list

    drop_empty = self.filter_cfg.get('filter_empty_gt', False)
    size_floor = self.filter_cfg.get('min_size', 0)

    # Ids of images that own at least one segment flagged as a thing.
    thing_img_ids = {
        info['img_id']
        for info in self.data_list
        if any(seg['is_thing'] for seg in info['segments_info'])
    }

    kept = []
    for info in self.data_list:
        has_thing = info['img_id'] in thing_img_ids
        big_enough = min(info['width'], info['height']) >= size_floor
        if big_enough and (has_thing or not drop_empty):
            kept.append(info)
    return kept
mmdet/datasets/crowdhuman.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+ import logging
4
+ import os.path as osp
5
+ import warnings
6
+ from typing import List, Union
7
+
8
+ import mmcv
9
+ from mmengine.dist import get_rank
10
+ from mmengine.fileio import dump, get, get_text, load
11
+ from mmengine.logging import print_log
12
+ from mmengine.utils import ProgressBar
13
+
14
+ from mmdet.registry import DATASETS
15
+ from .base_det_dataset import BaseDetDataset
16
+
17
+
18
@DATASETS.register_module()
class CrowdHumanDataset(BaseDetDataset):
    r"""Dataset for CrowdHuman.

    Args:
        data_root (str): The root directory for
            ``data_prefix`` and ``ann_file``.
        ann_file (str): Annotation file path.
        extra_ann_file (str | optional): The path of extra image metas
            for CrowdHuman. It can be created by CrowdHumanDataset
            automatically or by tools/misc/get_crowdhuman_id_hw.py
            manually. Defaults to None, in which case a cache file inside
            ``data_root`` is looked up (and created on first load).
    """

    METAINFO = {
        'classes': ('person', ),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(220, 20, 60)]
    }

    def __init__(self, data_root, ann_file, extra_ann_file=None, **kwargs):
        # extra_ann_file record the size of each image. This file is
        # automatically created when you first load the CrowdHuman
        # dataset by mmdet.
        if extra_ann_file is not None:
            self.extra_ann_exist = True
            self.extra_anns = load(extra_ann_file)
        else:
            ann_file_name = osp.basename(ann_file)
            if 'train' in ann_file_name:
                cache_name = 'id_hw_train.json'
            elif 'val' in ann_file_name:
                cache_name = 'id_hw_val.json'
            else:
                # Previously ``self.extra_ann_file`` stayed undefined for
                # any other annotation file name and the next line failed
                # with an AttributeError. Derive a cache name from the
                # annotation file instead.
                ann_stem = osp.splitext(ann_file_name)[0]
                cache_name = f'id_hw_{ann_stem}.json'
            self.extra_ann_file = osp.join(data_root, cache_name)
            self.extra_ann_exist = osp.isfile(self.extra_ann_file)
            if self.extra_ann_exist:
                self.extra_anns = load(self.extra_ann_file)
            else:
                print_log(
                    'extra_ann_file does not exist, prepare to collect '
                    'image height and width...',
                    level=logging.INFO)
                self.extra_anns = {}
        super().__init__(data_root=data_root, ann_file=ann_file, **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``

        Each non-blank line of the file is parsed as one JSON record. When
        no image-size cache existed at construction time, the sizes
        collected while parsing are written to ``self.extra_ann_file`` by
        rank 0 on a best-effort basis.

        Returns:
            List[dict]: A list of annotation.
        """  # noqa: E501
        anno_text = get_text(self.ann_file, backend_args=self.backend_args)
        # Ignore blank lines so a stray empty line (or an empty file) does
        # not make ``json.loads`` fail.
        anno_strs = [
            line for line in anno_text.strip().split('\n') if line.strip()
        ]
        print_log('loading CrowdHuman annotation...', level=logging.INFO)
        data_list = []
        prog_bar = ProgressBar(len(anno_strs))
        for anno_str in anno_strs:
            anno_dict = json.loads(anno_str)
            parsed_data_info = self.parse_data_info(anno_dict)
            data_list.append(parsed_data_info)
            prog_bar.update()
        if not self.extra_ann_exist and get_rank() == 0:
            # TODO: support file client
            try:
                dump(self.extra_anns, self.extra_ann_file, file_format='json')
            except Exception:
                # Writing the cache is only an optimisation: keep going,
                # but do not swallow KeyboardInterrupt / SystemExit.
                warnings.warn(
                    'Cache files can not be saved automatically! To speed up '
                    'loading the dataset, please manually generate the cache'
                    ' file by file tools/misc/get_crowdhuman_id_hw.py')
            else:
                # Only report the cache as saved when the dump succeeded.
                print_log(
                    f'\nsave extra_ann_file in {self.data_root}',
                    level=logging.INFO)

        # The size table is no longer needed once every record is parsed.
        del self.extra_anns
        print_log('\nDone', level=logging.INFO)
        return data_list

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            raw_data_info (dict): Raw data information load from
                ``ann_file``. It must provide ``ID`` and ``gtboxes``; each
                box needs ``tag``, ``fbox``, ``hbox`` and ``vbox`` and may
                carry ``extra['ignore']``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation.
        """
        img_id = raw_data_info['ID']
        data_info = {}
        img_path = osp.join(self.data_prefix['img'], f'{img_id}.jpg')
        data_info['img_path'] = img_path
        data_info['img_id'] = img_id

        if not self.extra_ann_exist:
            # No cached sizes: decode the image once and remember its
            # (height, width) so it can be dumped by ``load_data_list``.
            img_bytes = get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, backend='cv2')
            data_info['height'], data_info['width'] = img.shape[:2]
            self.extra_anns[img_id] = img.shape[:2]
            del img, img_bytes
        else:
            data_info['height'], data_info['width'] = self.extra_anns[img_id]

        # Loop-invariant lookup hoisted out of the per-box loop.
        classes = self.metainfo['classes']
        instances = []
        for ann in raw_data_info['gtboxes']:
            instance = {}
            if ann['tag'] not in classes:
                instance['bbox_label'] = -1
                instance['ignore_flag'] = 1
            else:
                instance['bbox_label'] = classes.index(ann['tag'])
                instance['ignore_flag'] = 0
            # An explicit non-zero ``extra['ignore']`` overrides the tag.
            if 'extra' in ann and ann['extra'].get('ignore', 0) != 0:
                instance['bbox_label'] = -1
                instance['ignore_flag'] = 1

            # Boxes are stored as (x, y, w, h) and converted to
            # (x1, y1, x2, y2).
            x1, y1, w, h = ann['fbox']
            bbox = [x1, y1, x1 + w, y1 + h]
            instance['bbox'] = bbox

            # Record the full bbox(fbox), head bbox(hbox) and visible
            # bbox(vbox) as additional information. If you need to use
            # this information, you just need to design the pipeline
            # instead of overriding the CrowdHumanDataset.
            # ``fbox`` gets its own list so an in-place edit of ``bbox``
            # cannot silently change it as well.
            instance['fbox'] = list(bbox)
            hbox = ann['hbox']
            instance['hbox'] = [
                hbox[0], hbox[1], hbox[0] + hbox[2], hbox[1] + hbox[3]
            ]
            vbox = ann['vbox']
            instance['vbox'] = [
                vbox[0], vbox[1], vbox[0] + vbox[2], vbox[1] + vbox[3]
            ]

            instances.append(instance)

        data_info['instances'] = instances
        return data_info
mmdet/datasets/dataset_wrappers.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import collections
3
+ import copy
4
+ from typing import Sequence, Union
5
+
6
+ from mmengine.dataset import BaseDataset, force_full_init
7
+
8
+ from mmdet.registry import DATASETS, TRANSFORMS
9
+
10
+
11
@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup. For the augmentation pipeline of mixed image data,
    the `get_indexes` method needs to be provided to obtain the image
    indexes, and you can set `skip_type_keys` to change the pipeline
    running process.

    Args:
        dataset (:obj:`BaseDataset` or dict): The dataset to be mixed,
            either an instance or a config dict built through
            ``DATASETS``.
        pipeline (Sequence[dict]): Sequence of transform config dicts to
            be built through ``TRANSFORMS`` and applied in order.
        skip_type_keys (list[str], optional): Sequence of type string to
            be skip pipeline. Default to None.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            greater than `max_refetch`, but results is still None, then the
            iteration is terminated and raise the error. Default: 15.
        lazy_init (bool): If True, ``full_init`` is not called by the
            constructor. Default: False.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 pipeline: Sequence[dict],
                 skip_type_keys: Union[Sequence[str], None] = None,
                 max_refetch: int = 15,
                 lazy_init: bool = False) -> None:
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all(
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys)
        self._skip_type_keys = skip_type_keys

        # ``pipeline_types`` mirrors ``pipeline`` index by index so that a
        # transform can be skipped by its configured type name.
        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform['type'])
                transform = TRANSFORMS.build(transform)
                self.pipeline.append(transform)
            else:
                raise TypeError('pipeline must be a dict')

        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`BaseDataset` instance, but got {type(dataset)}')

        self._metainfo = self.dataset.metainfo
        if hasattr(self.dataset, 'flag'):
            self.flag = self.dataset.flag
        self.num_samples = len(self.dataset)
        self.max_refetch = max_refetch

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Returns:
            dict: A deep copy of the meta information of the wrapped
            dataset.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Fully initialize the wrapped dataset (no-op when already done)."""
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index into the wrapped dataset.

        Returns:
            dict: The idx-th annotation of the wrapped dataset.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Fetch sample ``idx`` and run it through the wrapper pipeline.

        Transforms whose type is listed in the skip keys are bypassed.
        For a transform exposing ``get_indexes``, extra samples are drawn
        from the wrapped dataset and attached as ``mix_results`` before
        the transform is applied; that key is removed again afterwards.

        Raises:
            RuntimeError: If the extra samples or the transform output
                are still None after ``max_refetch`` attempts.
        """
        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            if hasattr(transform, 'get_indexes'):
                for _ in range(self.max_refetch):
                    # Make sure the results passed the loading pipeline
                    # of the original dataset is not None.
                    indexes = transform.get_indexes(self.dataset)
                    if not isinstance(indexes, collections.abc.Sequence):
                        indexes = [indexes]
                    mix_results = [
                        copy.deepcopy(self.dataset[index]) for index in indexes
                    ]
                    if None not in mix_results:
                        results['mix_results'] = mix_results
                        break
                else:
                    raise RuntimeError(
                        'The loading pipeline of the original dataset'
                        ' always return None. Please check the correctness '
                        'of the dataset and its pipeline.')

            for _ in range(self.max_refetch):
                # To confirm the results passed the training pipeline
                # of the wrapper is not None.
                # A copy is passed in so a failed attempt cannot leave
                # ``results`` half-modified before the retry.
                updated_results = transform(copy.deepcopy(results))
                if updated_results is not None:
                    results = updated_results
                    break
            else:
                raise RuntimeError(
                    'The training pipeline of the dataset wrapper'
                    ' always return None. Please check the correctness '
                    'of the dataset and its pipeline.')

            if 'mix_results' in results:
                results.pop('mix_results')

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of type
                string to be skip pipeline. ``None`` disables skipping,
                matching the constructor default.
        """
        # ``None`` is the constructor's "skip nothing" value; accept it
        # here too instead of failing while iterating over it.
        if skip_type_keys is not None:
            assert all(
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys)
        self._skip_type_keys = skip_type_keys
mmdet/datasets/deepfashion.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmdet.registry import DATASETS
3
+ from .coco import CocoDataset
4
+
5
+
6
@DATASETS.register_module()
class DeepFashionDataset(CocoDataset):
    """Dataset for DeepFashion.

    The class only supplies meta information on top of ``CocoDataset``:
    15 category names and one color tuple per category.
    """

    METAINFO = {
        'classes': (
            'top', 'skirt', 'leggings', 'dress', 'outer', 'pants', 'bag',
            'neckwear', 'headwear', 'eyeglass', 'belt', 'footwear', 'hair',
            'skin', 'face',
        ),
        # palette is a list of color tuples, which is used for visualization.
        # Entries are listed in the same order as ``classes``.
        'palette': [
            (0, 192, 64), (0, 64, 96), (128, 192, 192), (0, 64, 64),
            (0, 192, 224), (0, 192, 192), (128, 192, 64), (0, 192, 96),
            (128, 32, 192), (0, 0, 224), (0, 0, 64), (0, 160, 192),
            (128, 0, 96), (128, 0, 192), (0, 32, 192),
        ],
    }
mmdet/datasets/lvis.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import warnings
4
+ from typing import List
5
+
6
+ from mmengine.fileio import get_local_path
7
+
8
+ from mmdet.registry import DATASETS
9
+ from .coco import CocoDataset
10
+
11
+
12
+ @DATASETS.register_module()
13
+ class LVISV05Dataset(CocoDataset):
14
+ """LVIS v0.5 dataset for detection."""
15
+
16
+ METAINFO = {
17
+ 'classes':
18
+ ('acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
19
+ 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
20
+ 'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
21
+ 'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
22
+ 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
23
+ 'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
24
+ 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
25
+ 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
26
+ 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
27
+ 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
28
+ 'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
29
+ 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
30
+ 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball',
31
+ 'bead', 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
32
+ 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle',
33
+ 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle',
34
+ 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder',
35
+ 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage',
36
+ 'birdhouse', 'birthday_cake', 'birthday_card', 'biscuit_(bread)',
37
+ 'pirate_flag', 'black_sheep', 'blackboard', 'blanket', 'blazer',
38
+ 'blender', 'blimp', 'blinker', 'blueberry', 'boar', 'gameboard',
39
+ 'boat', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt',
40
+ 'bolt', 'bonnet', 'book', 'book_bag', 'bookcase', 'booklet',
41
+ 'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener',
42
+ 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie',
43
+ 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
44
+ 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
45
+ 'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
46
+ 'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
47
+ 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
48
+ 'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
49
+ 'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
50
+ 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
51
+ 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
52
+ 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
53
+ 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
54
+ 'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
55
+ 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
56
+ 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
57
+ 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
58
+ 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
59
+ 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
60
+ 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
61
+ 'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
62
+ 'celery', 'cellular_telephone', 'chain_mail', 'chair',
63
+ 'chaise_longue', 'champagne', 'chandelier', 'chap', 'checkbook',
64
+ 'checkerboard', 'cherry', 'chessboard',
65
+ 'chest_of_drawers_(furniture)', 'chicken_(animal)', 'chicken_wire',
66
+ 'chickpea', 'Chihuahua', 'chili_(vegetable)', 'chime', 'chinaware',
67
+ 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
68
+ 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
69
+ 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
70
+ 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
71
+ 'clasp', 'cleansing_agent', 'clementine', 'clip', 'clipboard',
72
+ 'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag',
73
+ 'coaster', 'coat', 'coat_hanger', 'coatrack', 'cock', 'coconut',
74
+ 'coffee_filter', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
75
+ 'coin', 'colander', 'coleslaw', 'coloring_material',
76
+ 'combination_lock', 'pacifier', 'comic_book', 'computer_keyboard',
77
+ 'concrete_mixer', 'cone', 'control', 'convertible_(automobile)',
78
+ 'sofa_bed', 'cookie', 'cookie_jar', 'cooking_utensil',
79
+ 'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew',
80
+ 'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
81
+ 'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
82
+ 'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
83
+ 'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
84
+ 'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
85
+ 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
86
+ 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
87
+ 'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
88
+ 'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
89
+ 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
90
+ 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
91
+ 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
92
+ 'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
93
+ 'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
94
+ 'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
95
+ 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
96
+ 'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
97
+ 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
98
+ 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
99
+ 'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
100
+ 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
101
+ 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
102
+ 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
103
+ 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
104
+ 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)',
105
+ 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose',
106
+ 'fireplace', 'fireplug', 'fish', 'fish_(food)', 'fishbowl',
107
+ 'fishing_boat', 'fishing_rod', 'flag', 'flagpole', 'flamingo',
108
+ 'flannel', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
109
+ 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
110
+ 'folding_chair', 'food_processor', 'football_(American)',
111
+ 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
112
+ 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
113
+ 'fruit_salad', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag',
114
+ 'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle',
115
+ 'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
116
+ 'gift_wrap', 'ginger', 'giraffe', 'cincture',
117
+ 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
118
+ 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
119
+ 'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
120
+ 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
121
+ 'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
122
+ 'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
123
+ 'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
124
+ 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
125
+ 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
126
+ 'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
127
+ 'headband', 'headboard', 'headlight', 'headscarf', 'headset',
128
+ 'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
129
+ 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
130
+ 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
131
+ 'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
132
+ 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
133
+ 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
134
+ 'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
135
+ 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
136
+ 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
137
+ 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
138
+ 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
139
+ 'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
140
+ 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
141
+ 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
142
+ 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
143
+ 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
144
+ 'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
145
+ 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
146
+ 'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
147
+ 'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
148
+ 'speaker_(stereo_equipment)', 'loveseat', 'machine_gun', 'magazine',
149
+ 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
150
+ 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
151
+ 'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
152
+ 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
153
+ 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
154
+ 'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
155
+ 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
156
+ 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
157
+ 'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
158
+ 'mound_(baseball)', 'mouse_(animal_rodent)',
159
+ 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
160
+ 'music_stool', 'musical_instrument', 'nailfile', 'nameplate',
161
+ 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest',
162
+ 'newsstand', 'nightshirt', 'nosebag_(for_animals)',
163
+ 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
164
+ 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
165
+ 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'oregano',
166
+ 'ostrich', 'ottoman', 'overalls_(clothing)', 'owl', 'packet',
167
+ 'inkpad', 'pad', 'paddle', 'padlock', 'paintbox', 'paintbrush',
168
+ 'painting', 'pajamas', 'palette', 'pan_(for_cooking)',
169
+ 'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya',
170
+ 'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
171
+ 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
172
+ 'parchment', 'parka', 'parking_meter', 'parrot',
173
+ 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
174
+ 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
175
+ 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
176
+ 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
177
+ 'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
178
+ 'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
179
+ 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
180
+ 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
181
+ 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
182
+ 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
183
+ 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
184
+ 'plate', 'platter', 'playing_card', 'playpen', 'pliers',
185
+ 'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
186
+ 'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
187
+ 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
188
+ 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
189
+ 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
190
+ 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
191
+ 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
192
+ 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
193
+ 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
194
+ 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
195
+ 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
196
+ 'recliner', 'record_player', 'red_cabbage', 'reflector',
197
+ 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
198
+ 'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
199
+ 'Rollerblade', 'rolling_pin', 'root_beer',
200
+ 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
201
+ 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
202
+ 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
203
+ 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
204
+ 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
205
+ 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
206
+ 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
207
+ 'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
208
+ 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
209
+ 'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
210
+ 'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
211
+ 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
212
+ 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag',
213
+ 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag',
214
+ 'shovel', 'shower_head', 'shower_curtain', 'shredder_(for_paper)',
215
+ 'sieve', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski',
216
+ 'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'sled', 'sleeping_bag',
217
+ 'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake',
218
+ 'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock',
219
+ 'soda_fountain', 'carbonated_water', 'sofa', 'softball',
220
+ 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
221
+ 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
222
+ 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'sponge',
223
+ 'spoon', 'sportswear', 'spotlight', 'squirrel',
224
+ 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)',
225
+ 'steak_(food)', 'steak_knife', 'steamer_(kitchen_appliance)',
226
+ 'steering_wheel', 'stencil', 'stepladder', 'step_stool',
227
+ 'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup',
228
+ 'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light', 'stove',
229
+ 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
230
+ 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
231
+ 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
232
+ 'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
233
+ 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
234
+ 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
235
+ 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
236
+ 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
237
+ 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
238
+ 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
239
+ 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
240
+ 'telephone_pole', 'telephoto_lens', 'television_camera',
241
+ 'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
242
+ 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
243
+ 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
244
+ 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
245
+ 'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
246
+ 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
247
+ 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
248
+ 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
249
+ 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
250
+ 'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
251
+ 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
252
+ 'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
253
+ 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
254
+ 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
255
+ 'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
256
+ 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
257
+ 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
258
+ 'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
259
+ 'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
260
+ 'water_heater', 'water_jug', 'water_gun', 'water_scooter',
261
+ 'water_ski', 'water_tower', 'watering_can', 'watermelon',
262
+ 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit',
263
+ 'wheel', 'wheelchair', 'whipped_cream', 'whiskey', 'whistle', 'wick',
264
+ 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
265
+ 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
266
+ 'wineglass', 'wing_chair', 'blinder_(for_horses)', 'wok', 'wolf',
267
+ 'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht',
268
+ 'yak', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini'),
269
+ 'palette':
270
+ None
271
+ }
272
+
273
def load_data_list(self) -> List[dict]:
    """Load annotations from the annotation file ``self.ann_file``.

    The file is read through the ``LVIS`` class of the ``lvis`` package.
    As side effects ``self.cat_ids``, ``self.cat2label`` and
    ``self.cat_img_map`` are (re)assigned; the ``self.lvis`` handle only
    lives for the duration of this call and is deleted before returning.

    Returns:
        List[dict]: One entry per image id, each being the value returned
        by ``self.parse_data_info`` for that image.

    Raises:
        ImportError: If the ``lvis`` package cannot be imported.
        AssertionError: If ``self.ANN_ID_UNIQUE`` is truthy and the
            annotation ids gathered from the file are not unique.
    """  # noqa: E501
    import re

    try:
        import lvis

        # Compare the release numerically. Comparing the raw strings is
        # lexicographic and would, e.g., rank '2.0.0' above '10.5.3'.
        # NOTE(review): 10.5.3 is presumably the version reported by the
        # deprecated mmlvis fork -- confirm.
        version = str(getattr(lvis, '__version__', '0'))
        release = tuple(int(num) for num in re.findall(r'\d+', version)[:3])
        if release >= (10, 5, 3):
            warnings.warn(
                'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"',  # noqa: E501
                UserWarning)
        from lvis import LVIS
    except ImportError as err:
        raise ImportError(
            'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".'  # noqa: E501
        ) from err

    with get_local_path(
            self.ann_file, backend_args=self.backend_args) as local_path:
        self.lvis = LVIS(local_path)
    self.cat_ids = self.lvis.get_cat_ids()
    # Map each category id to its position in ``self.cat_ids``.
    self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
    # Deep copy so the mapping survives ``del self.lvis`` below.
    self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)

    img_ids = self.lvis.get_img_ids()
    data_list = []
    total_ann_ids = []
    for img_id in img_ids:
        raw_img_info = self.lvis.load_imgs([img_id])[0]
        raw_img_info['img_id'] = img_id
        if raw_img_info['file_name'].startswith('COCO'):
            # Convert from the COCO 2014 file naming convention of
            # COCO_[train/val/test]2014_000000000000.jpg to the 2017
            # naming convention of 000000000000.jpg
            # (LVIS v1 will fix this naming issue)
            raw_img_info['file_name'] = raw_img_info['file_name'][-16:]
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        raw_ann_info = self.lvis.load_anns(ann_ids)
        total_ann_ids.extend(ann_ids)

        parsed_data_info = self.parse_data_info({
            'raw_ann_info': raw_ann_info,
            'raw_img_info': raw_img_info
        })
        data_list.append(parsed_data_info)

    if self.ANN_ID_UNIQUE:
        # Every annotation id must appear exactly once across all images.
        assert len(set(total_ann_ids)) == len(
            total_ann_ids
        ), f"Annotation ids in '{self.ann_file}' are not unique!"

    del self.lvis

    return data_list
328
+
329
+
330
# Expose ``LVISV05Dataset`` under the name ``LVISDataset`` as well, and
# register that alias in ``DATASETS`` under the name 'LVISDataset'.
LVISDataset = LVISV05Dataset
DATASETS.register_module(module=LVISDataset, name='LVISDataset')
332
+
333
+
334
+ @DATASETS.register_module()
335
+ class LVISV1Dataset(LVISDataset):
336
+ """LVIS v1 dataset for detection."""
337
+
338
+ METAINFO = {
339
+ 'classes':
340
+ ('aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
341
+ 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
342
+ 'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
343
+ 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
344
+ 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
345
+ 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
346
+ 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
347
+ 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
348
+ 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
349
+ 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
350
+ 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
351
+ 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
352
+ 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
353
+ 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
354
+ 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
355
+ 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
356
+ 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
357
+ 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
358
+ 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
359
+ 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
360
+ 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
361
+ 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
362
+ 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
363
+ 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
364
+ 'bottle_opener', 'bouquet', 'bow_(weapon)',
365
+ 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl',
366
+ 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders',
367
+ 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread',
368
+ 'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach',
369
+ 'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket',
370
+ 'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train',
371
+ 'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed',
372
+ 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter',
373
+ 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet',
374
+ 'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder',
375
+ 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can',
376
+ 'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane',
377
+ 'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen',
378
+ 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
379
+ 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
380
+ 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
381
+ 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
382
+ 'cash_register', 'casserole', 'cassette', 'cast', 'cat',
383
+ 'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery',
384
+ 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
385
+ 'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard',
386
+ 'cherry', 'chessboard', 'chicken_(animal)', 'chickpea',
387
+ 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
388
+ 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
389
+ 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
390
+ 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
391
+ 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
392
+ 'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard',
393
+ 'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower',
394
+ 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
395
+ 'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)',
396
+ 'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
397
+ 'coin', 'colander', 'coleslaw', 'coloring_material',
398
+ 'combination_lock', 'pacifier', 'comic_book', 'compass',
399
+ 'computer_keyboard', 'condiment', 'cone', 'control',
400
+ 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
401
+ 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
402
+ 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
403
+ 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
404
+ 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
405
+ 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
406
+ 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
407
+ 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
408
+ 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
409
+ 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
410
+ 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
411
+ 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
412
+ 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
413
+ 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
414
+ 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
415
+ 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
416
+ 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove',
417
+ 'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat',
418
+ 'dress_suit', 'dresser', 'drill', 'drone', 'dropper',
419
+ 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
420
+ 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle',
421
+ 'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg',
422
+ 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair',
423
+ 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot',
424
+ 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret',
425
+ 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine',
426
+ 'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine',
427
+ 'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug',
428
+ 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod',
429
+ 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash',
430
+ 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
431
+ 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
432
+ 'food_processor', 'football_(American)', 'football_helmet',
433
+ 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
434
+ 'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge',
435
+ 'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose',
436
+ 'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin',
437
+ 'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger',
438
+ 'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove',
439
+ 'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart',
440
+ 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater',
441
+ 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
442
+ 'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun',
443
+ 'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger',
444
+ 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass',
445
+ 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle',
446
+ 'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil',
447
+ 'headband', 'headboard', 'headlight', 'headscarf', 'headset',
448
+ 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
449
+ 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
450
+ 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
451
+ 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
452
+ 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
453
+ 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
454
+ 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
455
+ 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
456
+ 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
457
+ 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
458
+ 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
459
+ 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
460
+ 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
461
+ 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
462
+ 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
463
+ 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade',
464
+ 'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
465
+ 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
466
+ 'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat',
467
+ 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
468
+ 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange',
469
+ 'manger', 'manhole', 'map', 'marker', 'martini', 'mascot',
470
+ 'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)',
471
+ 'matchbox', 'mattress', 'measuring_cup', 'measuring_stick',
472
+ 'meatball', 'medicine', 'melon', 'microphone', 'microscope',
473
+ 'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake',
474
+ 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)',
475
+ 'money', 'monitor_(computer_equipment) computer_monitor', 'monkey',
476
+ 'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle',
477
+ 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad',
478
+ 'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument',
479
+ 'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle',
480
+ 'nest', 'newspaper', 'newsstand', 'nightshirt',
481
+ 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook',
482
+ 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
483
+ 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
484
+ 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven',
485
+ 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
486
+ 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette',
487
+ 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
488
+ 'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
489
+ 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
490
+ 'parasol', 'parchment', 'parka', 'parking_meter', 'parrot',
491
+ 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
492
+ 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
493
+ 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
494
+ 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
495
+ 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
496
+ 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
497
+ 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
498
+ 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
499
+ 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
500
+ 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
501
+ 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
502
+ 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
503
+ 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
504
+ 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
505
+ 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
506
+ 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
507
+ 'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller',
508
+ 'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin',
509
+ 'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt',
510
+ 'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver',
511
+ 'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry',
512
+ 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
513
+ 'recliner', 'record_player', 'reflector', 'remote_control',
514
+ 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
515
+ 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
516
+ 'rolling_pin', 'root_beer', 'router_(computer_equipment)',
517
+ 'rubber_band', 'runner_(carpet)', 'plastic_bag',
518
+ 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
519
+ 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
520
+ 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
521
+ 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
522
+ 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
523
+ 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
524
+ 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
525
+ 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
526
+ 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
527
+ 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
528
+ 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
529
+ 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
530
+ 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
531
+ 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
532
+ 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
533
+ 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
534
+ 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
535
+ 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
536
+ 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
537
+ 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
538
+ 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
539
+ 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
540
+ 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
541
+ 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew',
542
+ 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove',
543
+ 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
544
+ 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
545
+ 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
546
+ 'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants',
547
+ 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit',
548
+ 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
549
+ 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
550
+ 'tambourine', 'army_tank', 'tank_(storage_vessel)',
551
+ 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
552
+ 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
553
+ 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
554
+ 'telephone_pole', 'telephoto_lens', 'television_camera',
555
+ 'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
556
+ 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
557
+ 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
558
+ 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
559
+ 'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
560
+ 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
561
+ 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
562
+ 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
563
+ 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
564
+ 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod',
565
+ 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban',
566
+ 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
567
+ 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
568
+ 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
569
+ 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
570
+ 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
571
+ 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
572
+ 'washbasin', 'automatic_washer', 'watch', 'water_bottle',
573
+ 'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
574
+ 'water_gun', 'water_scooter', 'water_ski', 'water_tower',
575
+ 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
576
+ 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
577
+ 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
578
+ 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
579
+ 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
580
+ 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
581
+ 'yoke_(animal_equipment)', 'zebra', 'zucchini'),
582
+ 'palette':
583
+ None
584
+ }
585
+
586
def load_data_list(self) -> List[dict]:
    """Load annotations from the annotation file ``self.ann_file``.

    The file is read through the ``LVIS`` class of the ``lvis`` package.
    As side effects ``self.cat_ids``, ``self.cat2label`` and
    ``self.cat_img_map`` are (re)assigned; the ``self.lvis`` handle only
    lives for the duration of this call and is deleted before returning.

    Each image's ``file_name`` is derived from its ``coco_url`` by removing
    the ``http://images.cocodataset.org/`` prefix.

    Returns:
        List[dict]: One entry per image id, each being the value returned
        by ``self.parse_data_info`` for that image.

    Raises:
        ImportError: If the ``lvis`` package cannot be imported.
        AssertionError: If ``self.ANN_ID_UNIQUE`` is truthy and the
            annotation ids gathered from the file are not unique.
    """  # noqa: E501
    import re

    try:
        import lvis

        # Compare the release numerically. Comparing the raw strings is
        # lexicographic and would, e.g., rank '2.0.0' above '10.5.3'.
        # NOTE(review): 10.5.3 is presumably the version reported by the
        # deprecated mmlvis fork -- confirm.
        version = str(getattr(lvis, '__version__', '0'))
        release = tuple(int(num) for num in re.findall(r'\d+', version)[:3])
        if release >= (10, 5, 3):
            warnings.warn(
                'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"',  # noqa: E501
                UserWarning)
        from lvis import LVIS
    except ImportError as err:
        raise ImportError(
            'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".'  # noqa: E501
        ) from err

    with get_local_path(
            self.ann_file, backend_args=self.backend_args) as local_path:
        self.lvis = LVIS(local_path)
    self.cat_ids = self.lvis.get_cat_ids()
    # Map each category id to its position in ``self.cat_ids``.
    self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
    # Deep copy so the mapping survives ``del self.lvis`` below.
    self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)

    img_ids = self.lvis.get_img_ids()
    data_list = []
    total_ann_ids = []
    for img_id in img_ids:
        raw_img_info = self.lvis.load_imgs([img_id])[0]
        raw_img_info['img_id'] = img_id
        # coco_url is used in LVISv1 instead of file_name
        # e.g. http://images.cocodataset.org/train2017/000000391895.jpg
        # the train/val split is specified in the url
        raw_img_info['file_name'] = raw_img_info['coco_url'].replace(
            'http://images.cocodataset.org/', '')
        ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
        raw_ann_info = self.lvis.load_anns(ann_ids)
        total_ann_ids.extend(ann_ids)
        parsed_data_info = self.parse_data_info({
            'raw_ann_info': raw_ann_info,
            'raw_img_info': raw_img_info
        })
        data_list.append(parsed_data_info)

    if self.ANN_ID_UNIQUE:
        # Every annotation id must appear exactly once across all images.
        assert len(set(total_ann_ids)) == len(
            total_ann_ids
        ), f"Annotation ids in '{self.ann_file}' are not unique!"

    del self.lvis

    return data_list
mmdet/datasets/objects365.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import os.path as osp
4
+ from typing import List
5
+
6
+ from mmengine.fileio import get_local_path
7
+
8
+ from mmdet.registry import DATASETS
9
+ from .api_wrappers import COCO
10
+ from .coco import CocoDataset
11
+
12
# Images that are listed in the annotations but missing from the image folder.
objv2_ignore_list = [
    osp.join(patch_dir, image_name) for patch_dir, image_name in (
        ('patch16', 'objects365_v2_00908726.jpg'),
        ('patch6', 'objects365_v1_00320532.jpg'),
        ('patch6', 'objects365_v1_00320534.jpg'),
    )
]
18
+
19
+
20
+ @DATASETS.register_module()
21
+ class Objects365V1Dataset(CocoDataset):
22
+ """Objects365 v1 dataset for detection."""
23
+
24
+ METAINFO = {
25
+ 'classes':
26
+ ('person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle',
27
+ 'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk',
28
+ 'handbag', 'street lights', 'book', 'plate', 'helmet',
29
+ 'leather shoes', 'pillow', 'glove', 'potted plant', 'bracelet',
30
+ 'flower', 'tv', 'storage box', 'vase', 'bench', 'wine glass', 'boots',
31
+ 'bowl', 'dining table', 'umbrella', 'boat', 'flag', 'speaker',
32
+ 'trash bin/can', 'stool', 'backpack', 'couch', 'belt', 'carpet',
33
+ 'basket', 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table',
34
+ 'suv', 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil',
35
+ 'microphone', 'sandals', 'canned', 'necklace', 'mirror', 'faucet',
36
+ 'bicycle', 'bread', 'high heels', 'ring', 'van', 'watch', 'sink',
37
+ 'horse', 'fish', 'apple', 'camera', 'candle', 'teddy bear', 'cake',
38
+ 'motorcycle', 'wild bird', 'laptop', 'knife', 'traffic sign',
39
+ 'cell phone', 'paddle', 'truck', 'cow', 'power outlet', 'clock',
40
+ 'drum', 'fork', 'bus', 'hanger', 'nightstand', 'pot/pan', 'sheep',
41
+ 'guitar', 'traffic cone', 'tea pot', 'keyboard', 'tripod', 'hockey',
42
+ 'fan', 'dog', 'spoon', 'blackboard/whiteboard', 'balloon',
43
+ 'air conditioner', 'cymbal', 'mouse', 'telephone', 'pickup truck',
44
+ 'orange', 'banana', 'airplane', 'luggage', 'skis', 'soccer',
45
+ 'trolley', 'oven', 'remote', 'baseball glove', 'paper towel',
46
+ 'refrigerator', 'train', 'tomato', 'machinery vehicle', 'tent',
47
+ 'shampoo/shower gel', 'head phone', 'lantern', 'donut',
48
+ 'cleaning products', 'sailboat', 'tangerine', 'pizza', 'kite',
49
+ 'computer box', 'elephant', 'toiletries', 'gas stove', 'broccoli',
50
+ 'toilet', 'stroller', 'shovel', 'baseball bat', 'microwave',
51
+ 'skateboard', 'surfboard', 'surveillance camera', 'gun', 'life saver',
52
+ 'cat', 'lemon', 'liquid soap', 'zebra', 'duck', 'sports car',
53
+ 'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator', 'converter',
54
+ 'tissue ', 'carrot', 'washing machine', 'vent', 'cookies',
55
+ 'cutting/chopping board', 'tennis racket', 'candy',
56
+ 'skating and skiing shoes', 'scissors', 'folder', 'baseball',
57
+ 'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine',
58
+ 'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear',
59
+ 'american football', 'basketball', 'potato', 'paint brush', 'printer',
60
+ 'billiards', 'fire hydrant', 'goose', 'projector', 'sausage',
61
+ 'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball',
62
+ 'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee',
63
+ 'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender',
64
+ 'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango',
65
+ 'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion',
66
+ 'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale',
67
+ 'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple',
68
+ 'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle',
69
+ 'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar',
70
+ 'earphone', 'penguin', 'hurdle', 'swing', 'radio', 'CD',
71
+ 'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado',
72
+ 'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear',
73
+ 'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn',
74
+ 'key', 'screwdriver', 'globe', 'broom', 'pliers', 'volleyball',
75
+ 'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice',
76
+ 'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel',
77
+ 'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste',
78
+ 'antelope', 'shrimp', 'rickshaw', 'trombone', 'pomegranate',
79
+ 'coconut', 'jellyfish', 'mushroom', 'calculator', 'treadmill',
80
+ 'butterfly', 'egg tart', 'cheese', 'pig', 'pomelo', 'race car',
81
+ 'rice cooker', 'tuba', 'crosswalk sign', 'papaya', 'hair drier',
82
+ 'green onion', 'chips', 'dolphin', 'sushi', 'urinal', 'donkey',
83
+ 'electric drill', 'spring rolls', 'tortoise/turtle', 'parrot',
84
+ 'flute', 'measuring cup', 'shark', 'steak', 'poker card',
85
+ 'binoculars', 'llama', 'radish', 'noodles', 'yak', 'mop', 'crab',
86
+ 'microscope', 'barbell', 'bread/bun', 'baozi', 'lion', 'red cabbage',
87
+ 'polar bear', 'lighter', 'seal', 'mangosteen', 'comb', 'eraser',
88
+ 'pitaya', 'scallop', 'pencil case', 'saw', 'table tennis paddle',
89
+ 'okra', 'starfish', 'eagle', 'monkey', 'durian', 'game board',
90
+ 'rabbit', 'french horn', 'ambulance', 'asparagus', 'hoverboard',
91
+ 'pasta', 'target', 'hotair balloon', 'chainsaw', 'lobster', 'iron',
92
+ 'flashlight'),
93
+ 'palette':
94
+ None
95
+ }
96
+
97
+ COCOAPI = COCO
98
+ # ann_id is unique in coco dataset.
99
+ ANN_ID_UNIQUE = True
100
+
101
def load_data_list(self) -> List[dict]:
    """Load annotations from an annotation file named as ``self.ann_file``

    Both category containers of the COCO API object are re-ordered by
    category id before the category ids are queried, because the
    'categories' list in objects365_train.json and objects365_val.json
    is inconsistent.

    Returns:
        List[dict]: A list of annotation.
    """  # noqa: E501
    with get_local_path(
            self.ann_file, backend_args=self.backend_args) as ann_path:
        self.coco = self.COCOAPI(ann_path)

    coco_api = self.coco

    # Put the id -> category mapping and the raw category list into
    # ascending id order so train and val files yield the same cat_ids.
    coco_api.cats = {
        cat_id: coco_api.cats[cat_id]
        for cat_id in sorted(coco_api.cats)
    }
    ordered_categories = list(coco_api.dataset['categories'])
    ordered_categories.sort(key=lambda category: category['id'])
    coco_api.dataset['categories'] = ordered_categories

    # The order of returned `cat_ids` will not
    # change with the order of the `classes`
    self.cat_ids = coco_api.get_cat_ids(cat_names=self.metainfo['classes'])
    cat2label = {}
    for label, cat_id in enumerate(self.cat_ids):
        cat2label[cat_id] = label
    self.cat2label = cat2label
    self.cat_img_map = copy.deepcopy(coco_api.cat_img_map)

    data_list = []
    seen_ann_ids = []
    for img_id in coco_api.get_img_ids():
        img_info = coco_api.load_imgs([img_id])[0]
        img_info['img_id'] = img_id

        ann_ids = coco_api.get_ann_ids(img_ids=[img_id])
        ann_info = coco_api.load_anns(ann_ids)
        seen_ann_ids.extend(ann_ids)

        data_list.append(
            self.parse_data_info(
                dict(raw_ann_info=ann_info, raw_img_info=img_info)))

    if self.ANN_ID_UNIQUE:
        # Every annotation id may be used by one image only.
        assert len(set(seen_ann_ids)) == len(
            seen_ann_ids
        ), f"Annotation ids in '{self.ann_file}' are not unique!"

    # The COCO API object is only needed while building the list.
    del self.coco

    return data_list
152
+
153
+
154
+ @DATASETS.register_module()
155
+ class Objects365V2Dataset(CocoDataset):
156
+ """Objects365 v2 dataset for detection."""
157
+ METAINFO = {
158
+ 'classes':
159
+ ('Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp',
160
+ 'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf',
161
+ 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet',
162
+ 'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower',
163
+ 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots',
164
+ 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt',
165
+ 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker',
166
+ 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool',
167
+ 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum',
168
+ 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle',
169
+ 'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned',
170
+ 'Truck', 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel',
171
+ 'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed',
172
+ 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple',
173
+ 'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck',
174
+ 'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock',
175
+ 'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger',
176
+ 'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine',
177
+ 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle',
178
+ 'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane',
179
+ 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage',
180
+ 'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone',
181
+ 'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane',
182
+ 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat',
183
+ 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza',
184
+ 'Elephant', 'Skateboard', 'Surfboard', 'Gun',
185
+ 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot',
186
+ 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper',
187
+ 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks',
188
+ 'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board',
189
+ 'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder',
190
+ 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball',
191
+ 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle',
192
+ 'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck',
193
+ 'Billards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club',
194
+ 'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear',
195
+ 'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong',
196
+ 'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask',
197
+ 'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide',
198
+ 'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee',
199
+ 'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon',
200
+ 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon',
201
+ 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog',
202
+ 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer',
203
+ 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple',
204
+ 'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle',
205
+ 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone',
206
+ 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion',
207
+ 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom',
208
+ 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit',
209
+ 'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese',
210
+ 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue',
211
+ 'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap',
212
+ 'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut',
213
+ 'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak',
214
+ 'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate',
215
+ 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker',
216
+ 'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal',
217
+ 'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin',
218
+ 'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill',
219
+ 'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi',
220
+ 'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case',
221
+ 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop',
222
+ 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle',
223
+ 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster',
224
+ 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling',
225
+ 'Table Tennis '),
226
+ 'palette':
227
+ None
228
+ }
229
+
230
+ COCOAPI = COCO
231
+ # ann_id is unique in coco dataset.
232
+ ANN_ID_UNIQUE = True
233
+
234
def load_data_list(self) -> List[dict]:
    """Load annotations from an annotation file named as ``self.ann_file``

    Image file names are reduced to their last directory plus base name
    (``patchX/xxx.jpg``); images whose reduced name is contained in
    ``objv2_ignore_list`` are left out of the result.

    Returns:
        List[dict]: A list of annotation.
    """  # noqa: E501
    with get_local_path(
            self.ann_file, backend_args=self.backend_args) as ann_path:
        self.coco = self.COCOAPI(ann_path)
    coco_api = self.coco

    # The order of returned `cat_ids` will not
    # change with the order of the `classes`
    self.cat_ids = coco_api.get_cat_ids(cat_names=self.metainfo['classes'])
    cat2label = {}
    for label, cat_id in enumerate(self.cat_ids):
        cat2label[cat_id] = label
    self.cat2label = cat2label
    self.cat_img_map = copy.deepcopy(coco_api.cat_img_map)

    data_list = []
    seen_ann_ids = []
    for img_id in coco_api.get_img_ids():
        img_info = coco_api.load_imgs([img_id])[0]
        img_info['img_id'] = img_id

        ann_ids = coco_api.get_ann_ids(img_ids=[img_id])
        ann_info = coco_api.load_anns(ann_ids)
        # Ids are recorded even for images that end up being skipped.
        seen_ann_ids.extend(ann_ids)

        # file_name should be `patchX/xxx.jpg`
        original_name = img_info['file_name']
        parent_dir = osp.split(osp.split(original_name)[0])[-1]
        file_name = osp.join(parent_dir, osp.split(original_name)[-1])

        if file_name in objv2_ignore_list:
            continue

        img_info['file_name'] = file_name
        data_list.append(
            self.parse_data_info(
                dict(raw_ann_info=ann_info, raw_img_info=img_info)))

    if self.ANN_ID_UNIQUE:
        # Every annotation id may be used by one image only.
        assert len(set(seen_ann_ids)) == len(
            seen_ann_ids
        ), f"Annotation ids in '{self.ann_file}' are not unique!"

    # The COCO API object is only needed while building the list.
    del self.coco

    return data_list
mmdet/datasets/openimages.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import csv
3
+ import os.path as osp
4
+ from collections import defaultdict
5
+ from typing import Dict, List, Optional
6
+
7
+ import numpy as np
8
+ from mmengine.fileio import get_local_path, load
9
+ from mmengine.utils import is_abs
10
+
11
+ from mmdet.registry import DATASETS
12
+ from .base_det_dataset import BaseDetDataset
13
+
14
+
15
+ @DATASETS.register_module()
16
+ class OpenImagesDataset(BaseDetDataset):
17
+ """Open Images dataset for detection.
18
+
19
+ Args:
20
+ ann_file (str): Annotation file path.
21
+ label_file (str): File path of the label description file that
22
+ maps the classes names in MID format to their short
23
+ descriptions.
24
+ meta_file (str): File path to get image metas.
25
+ hierarchy_file (str): The file path of the class hierarchy.
26
+ image_level_ann_file (str): Human-verified image level annotation,
27
+ which is used in evaluation.
28
+ backend_args (dict, optional): Arguments to instantiate the
29
+ corresponding backend. Defaults to None.
30
+ """
31
+
32
+ METAINFO: dict = dict(dataset_type='oid_v6')
33
+
34
def __init__(self,
             label_file: str,
             meta_file: str,
             hierarchy_file: str,
             image_level_ann_file: Optional[str] = None,
             **kwargs) -> None:
    """Store the auxiliary file paths, then run the base initialisation.

    Args:
        label_file (str): File path of the label description file.
        meta_file (str): File path to get image metas.
        hierarchy_file (str): The file path of the class hierarchy.
        image_level_ann_file (str, optional): Image level annotation
            file path. Defaults to None.
        **kwargs: Forwarded unchanged to the parent ``__init__``.
    """
    # The paths are set before the parent initialiser is invoked.
    # NOTE(review): presumably the base class already uses them during
    # its own set-up (e.g. via ``_join_prefix``) -- confirm.
    (self.label_file, self.meta_file, self.hierarchy_file,
     self.image_level_ann_file) = (label_file, meta_file, hierarchy_file,
                                   image_level_ann_file)
    super().__init__(**kwargs)
45
+
46
def load_data_list(self) -> List[dict]:
    """Load annotations from an annotation file named as ``self.ann_file``

    Rows of the CSV box annotation file are grouped by consecutive image
    id into one record per image. Each record is then completed with the
    image size read from ``self.meta_file`` (the normalized boxes are
    scaled to pixels) and, if ``self.image_level_ann_file`` is set, with
    the image-level labels and confidences.

    Returns:
        List[dict]: A list of annotation.
    """
    classes_names, label_id_mapping = self._parse_label_file(
        self.label_file)
    self._metainfo['classes'] = classes_names
    self.label_id_mapping = label_id_mapping

    if self.image_level_ann_file is not None:
        img_level_anns = self._parse_img_level_ann(
            self.image_level_ann_file)
    else:
        img_level_anns = None

    # OpenImagesMetric can get the relation matrix from the dataset meta
    relation_matrix = self._get_relation_matrix(self.hierarchy_file)
    self._metainfo['RELATION_MATRIX'] = relation_matrix

    def make_record(record_img_id, record_instances):
        # The path is always derived from the id stored in the same
        # record, so the two can never refer to different images.
        return dict(
            img_path=osp.join(self.data_prefix['img'],
                              f'{record_img_id}.jpg'),
            img_id=record_img_id,
            instances=record_instances,
        )

    data_list = []
    with get_local_path(
            self.ann_file, backend_args=self.backend_args) as local_path:
        with open(local_path, 'r') as f:
            reader = csv.reader(f)
            last_img_id = None
            instances = []
            for i, line in enumerate(reader):
                if i == 0:
                    # the first row is skipped (header)
                    continue
                img_id = line[0]
                label_id = line[2]
                assert label_id in self.label_id_mapping
                label = int(self.label_id_mapping[label_id])
                bbox = [
                    float(line[4]),  # xmin
                    float(line[6]),  # ymin
                    float(line[5]),  # xmax
                    float(line[7])  # ymax
                ]
                instance = dict(
                    bbox=bbox,
                    bbox_label=label,
                    ignore_flag=0,
                    is_occluded=int(line[8]) == 1,
                    is_truncated=int(line[9]) == 1,
                    is_group_of=int(line[10]) == 1,
                    is_depiction=int(line[11]) == 1,
                    is_inside=int(line[12]) == 1)
                if last_img_id is not None and img_id != last_img_id:
                    # switch to a new image, record previous image's data.
                    data_list.append(make_record(last_img_id, instances))
                    instances = []
                instances.append(instance)
                last_img_id = img_id
            # Record the final image. Nothing is added when the file has
            # no data rows at all.
            if last_img_id is not None:
                data_list.append(make_record(last_img_id, instances))

    # add image metas to data list
    # NOTE(review): the meta file is read as pickle; only use trusted files.
    img_metas = load(
        self.meta_file, file_format='pkl', backend_args=self.backend_args)
    assert len(img_metas) == len(data_list)
    for i, meta in enumerate(img_metas):
        data_info = data_list[i]
        img_id = data_info['img_id']
        # metas are expected in the same order as the annotation records
        assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1]
        h, w = meta['ori_shape'][:2]
        data_info['height'] = h
        data_info['width'] = w
        # denormalize bboxes
        for instance in data_info['instances']:
            bbox = instance['bbox']
            bbox[0] *= w
            bbox[2] *= w
            bbox[1] *= h
            bbox[3] *= h
        # add image-level annotation
        if img_level_anns is not None:
            img_ann_list = img_level_anns.get(img_id, [])
            data_info['image_level_labels'] = np.array(
                [int(ann['image_level_label']) for ann in img_ann_list],
                dtype=np.int64)
            data_info['confidences'] = np.array(
                [float(ann['confidence']) for ann in img_ann_list],
                dtype=np.float32)
    return data_list
153
+
154
def _parse_label_file(self, label_file: str) -> tuple:
    """Get classes name and index mapping from cls-label-description file.

    Args:
        label_file (str): File path of the label description file that
            maps the classes names in MID format to their short
            descriptions.

    Returns:
        tuple: ``(classes_names, index_mapping)``. ``classes_names`` is
        the list of second-column values in file order and
        ``index_mapping`` maps each first-column value to its row
        position.
    """
    with get_local_path(
            label_file, backend_args=self.backend_args) as local_path:
        with open(local_path, 'r') as f:
            rows = list(csv.reader(f))

    classes_names = [row[1] for row in rows]
    index_mapping = {}
    for position, row in enumerate(rows):
        index_mapping[row[0]] = position
    return classes_names, index_mapping
178
+
179
def _parse_img_level_ann(self,
                         img_level_ann_file: str) -> Dict[str, List[dict]]:
    """Parse image level annotations from csv style ann_file.

    The first row of the file is skipped. For every other row, column 0
    is used as the image id, column 2 is translated through
    ``self.label_id_mapping`` and column 3 is read as the confidence.

    Args:
        img_level_ann_file (str): CSV style image level annotation
            file path.

    Returns:
        Dict[str, List[dict]]: Annotations where item of the defaultdict
        indicates an image, each of which has (n) dicts.
        Keys of dicts are:

            - `image_level_label` (int): Label id.
            - `confidence` (float): Labels that are human-verified to be
              present in an image have confidence = 1 (positive labels).
              Labels that are human-verified to be absent from an image
              have confidence = 0 (negative labels). Machine-generated
              labels have fractional confidences, generally >= 0.5.
              The higher the confidence, the smaller the chance for
              the label to be a false positive.
    """
    anns_per_image = defaultdict(list)
    with get_local_path(
            img_level_ann_file,
            backend_args=self.backend_args) as local_path:
        with open(local_path, 'r') as f:
            rows = csv.reader(f)
            next(rows, None)  # skip the first row
            for row in rows:
                image_anns = anns_per_image[row[0]]
                label = int(self.label_id_mapping[row[2]])
                image_anns.append(
                    dict(image_level_label=label, confidence=float(row[3])))
    return anns_per_image
218
+
219
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
    """Get the matrix of class hierarchy from the hierarchy file.

    Hierarchy for 600 classes can be found at
    https://storage.googleapis.com/openimages/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.

    Args:
        hierarchy_file (str): File path to the hierarchy for classes.

    Returns:
        np.ndarray: The matrix of the corresponding relationship between
        the parent class and the child class, of shape
        (class_num, class_num).
    """  # noqa
    class_hierarchy = load(
        hierarchy_file, file_format='json', backend_args=self.backend_args)
    # Start from the identity matrix; the tree walk then adds the
    # child -> parent entries.
    num_classes = len(self._metainfo['classes'])
    return self._convert_hierarchy_tree(class_hierarchy,
                                        np.eye(num_classes))
240
+
241
+ def _convert_hierarchy_tree(self,
242
+ hierarchy_map: dict,
243
+ relation_matrix: np.ndarray,
244
+ parents: list = [],
245
+ get_all_parents: bool = True) -> np.ndarray:
246
+ """Get matrix of the corresponding relationship between the parent
247
+ class and the child class.
248
+
249
+ Args:
250
+ hierarchy_map (dict): Including label name and corresponding
251
+ subcategory. Keys of dicts are:
252
+
253
+ - `LabeName` (str): Name of the label.
254
+ - `Subcategory` (dict | list): Corresponding subcategory(ies).
255
+ relation_matrix (ndarray): The matrix of the corresponding
256
+ relationship between the parent class and the child class,
257
+ of shape (class_num, class_num).
258
+ parents (list): Corresponding parent class.
259
+ get_all_parents (bool): Whether get all parent names.
260
+ Default: True
261
+
262
+ Returns:
263
+ ndarray: The matrix of the corresponding relationship between
264
+ the parent class and the child class, of shape
265
+ (class_num, class_num).
266
+ """
267
+
268
+ if 'Subcategory' in hierarchy_map:
269
+ for node in hierarchy_map['Subcategory']:
270
+ if 'LabelName' in node:
271
+ children_name = node['LabelName']
272
+ children_index = self.label_id_mapping[children_name]
273
+ children = [children_index]
274
+ else:
275
+ continue
276
+ if len(parents) > 0:
277
+ for parent_index in parents:
278
+ if get_all_parents:
279
+ children.append(parent_index)
280
+ relation_matrix[children_index, parent_index] = 1
281
+ relation_matrix = self._convert_hierarchy_tree(
282
+ node, relation_matrix, parents=children)
283
+ return relation_matrix
284
+
285
def _join_prefix(self):
    """Join ``self.data_root`` with annotation path.

    After the parent implementation has run, every non-empty, relative
    path among ``label_file``, ``meta_file``, ``hierarchy_file`` and
    ``image_level_ann_file`` is prefixed with ``self.data_root``.
    """
    super()._join_prefix()
    for attr_name in ('label_file', 'meta_file', 'hierarchy_file'):
        path = getattr(self, attr_name)
        if not is_abs(path) and path:
            setattr(self, attr_name, osp.join(self.data_root, path))
    # The image-level file is optional, so emptiness is tested first.
    img_level_file = self.image_level_ann_file
    if img_level_file and not is_abs(img_level_file):
        self.image_level_ann_file = osp.join(self.data_root, img_level_file)
297
+
298
+
299
+ @DATASETS.register_module()
300
+ class OpenImagesChallengeDataset(OpenImagesDataset):
301
+ """Open Images Challenge dataset for detection.
302
+
303
+ Args:
304
+ ann_file (str): Open Images Challenge box annotation in txt format.
305
+ """
306
+
307
+ METAINFO: dict = dict(dataset_type='oid_challenge')
308
+
309
+ def __init__(self, ann_file: str, **kwargs) -> None:
310
+ if not ann_file.endswith('txt'):
311
+ raise TypeError('The annotation file of Open Images Challenge '
312
+ 'should be a txt file.')
313
+
314
+ super().__init__(ann_file=ann_file, **kwargs)
315
+
316
def load_data_list(self) -> List[dict]:
    """Load annotations from an annotation file named as ``self.ann_file``

    The txt file is read image by image: a file name line, one line that
    is not read, a line holding the number of boxes, and then that many
    box lines. In a box line, field 0 is the 1-based label, fields 1-4
    are the bbox values and field 5 is the group-of flag. Each record is
    afterwards completed with the image size from ``self.meta_file``
    and, if configured, the image-level annotations.

    Returns:
        List[dict]: A list of annotation.
    """
    classes_names, label_id_mapping = self._parse_label_file(
        self.label_file)
    self._metainfo['classes'] = classes_names
    self.label_id_mapping = label_id_mapping

    img_level_anns = None
    if self.image_level_ann_file is not None:
        img_level_anns = self._parse_img_level_ann(
            self.image_level_ann_file)

    # OpenImagesMetric can get the relation matrix from the dataset meta
    self._metainfo['RELATION_MATRIX'] = self._get_relation_matrix(
        self.hierarchy_file)

    with get_local_path(
            self.ann_file, backend_args=self.backend_args) as local_path:
        with open(local_path, 'r') as f:
            lines = f.readlines()

    data_list = []
    cursor = 0
    while cursor < len(lines):
        filename = lines[cursor].rstrip()
        # The line right after the file name is not used.
        count_pos = cursor + 2
        num_boxes = int(lines[count_pos])
        first_box = count_pos + 1
        instances = []
        for offset in range(num_boxes):
            fields = lines[first_box + offset].split()
            bbox = [
                float(fields[1]),
                float(fields[2]),
                float(fields[3]),
                float(fields[4])
            ]
            instances.append(
                dict(
                    bbox=bbox,
                    bbox_label=int(fields[0]) - 1,  # labels begin from 1
                    ignore_flag=0,
                    is_group_ofs=int(fields[5]) == 1))
        cursor = first_box + num_boxes
        data_list.append(
            dict(
                img_path=osp.join(self.data_prefix['img'], filename),
                instances=instances,
            ))

    # add image metas to data list
    img_metas = load(
        self.meta_file, file_format='pkl', backend_args=self.backend_args)
    assert len(img_metas) == len(data_list)
    for i, meta in enumerate(img_metas):
        data_info = data_list[i]
        # The image id is the file name without its last four characters.
        img_id = osp.split(data_info['img_path'])[-1][:-4]
        assert img_id == osp.split(meta['filename'])[-1][:-4]
        h, w = meta['ori_shape'][:2]
        data_info['height'] = h
        data_info['width'] = w
        data_info['img_id'] = img_id
        # denormalize bboxes
        for instance in data_info['instances']:
            bbox = instance['bbox']
            bbox[0] *= w
            bbox[2] *= w
            bbox[1] *= h
            bbox[3] *= h
        # add image-level annotation
        if img_level_anns is not None:
            img_ann_list = img_level_anns.get(img_id, [])
            img_labels = [
                int(ann['image_level_label']) for ann in img_ann_list
            ]
            confidences = [float(ann['confidence']) for ann in img_ann_list]
            data_info['image_level_labels'] = np.array(
                img_labels, dtype=np.int64)
            data_info['confidences'] = np.array(
                confidences, dtype=np.float32)
    return data_list
399
+
400
def _parse_label_file(self, label_file: str) -> tuple:
    """Get classes name and index mapping from cls-label-description file.

    Args:
        label_file (str): File path of the label description file that
            maps the classes names in MID format to their short
            descriptions.

    Returns:
        tuple: ``(classes_names, index_mapping)``. ``classes_names``
        holds the second-column values ordered by the integer id in the
        third column; ``index_mapping`` maps each first-column value to
        that id minus one.
    """
    descriptions = []
    numeric_ids = []
    index_mapping = {}
    with get_local_path(
            label_file, backend_args=self.backend_args) as local_path:
        with open(local_path, 'r') as f:
            for row in csv.reader(f):
                numeric_id = int(row[2])
                descriptions.append(row[1])
                numeric_ids.append(numeric_id)
                # ids in the file start at 1, labels start at 0
                index_mapping[row[0]] = numeric_id - 1

    classes_names = [
        descriptions[position] for position in np.argsort(numeric_ids)
    ]
    return classes_names, index_mapping
429
+
430
def _parse_img_level_ann(self, image_level_ann_file):
    """Parse image level annotations from csv style ann_file.

    The first row of the file is skipped. For every other row, column 0
    is used as the image id, column 1 is translated through
    ``self.label_id_mapping`` and column 2 is read as the confidence.

    Args:
        image_level_ann_file (str): CSV style image level annotation
            file path.

    Returns:
        defaultdict[list[dict]]: Annotations where item of the defaultdict
        indicates an image, each of which has (n) dicts.
        Keys of dicts are:

            - `image_level_label` (int): of shape 1.
            - `confidence` (float): of shape 1.
    """
    anns_per_image = defaultdict(list)
    with get_local_path(
            image_level_ann_file,
            backend_args=self.backend_args) as local_path:
        with open(local_path, 'r') as f:
            for row_number, row in enumerate(csv.reader(f)):
                if row_number == 0:
                    continue
                img_id = row[0]
                label_id = row[1]
                assert label_id in self.label_id_mapping
                label = int(self.label_id_mapping[label_id])
                score = float(row[2])
                anns_per_image[img_id].append(
                    dict(image_level_label=label, confidence=score))
    return anns_per_image
469
+
470
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
    """Get the matrix of class hierarchy from the hierarchy file.

    Args:
        hierarchy_file (str): File path to the hierarchy for classes.

    Returns:
        np.ndarray: The matrix of the corresponding
        relationship between the parent class and the child class,
        of shape (class_num, class_num).
    """
    with get_local_path(
            hierarchy_file, backend_args=self.backend_args) as local_path:
        # NOTE(review): ``allow_pickle=True`` can execute code embedded
        # in the file -- only load hierarchy files from trusted sources.
        # The first row and the first column of the stored array are
        # dropped.
        relation_matrix = np.load(local_path, allow_pickle=True)[1:, 1:]
    return relation_matrix
mmdet/datasets/samplers/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .batch_sampler import AspectRatioBatchSampler
3
+ from .class_aware_sampler import ClassAwareSampler
4
+ from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
5
+
6
+ __all__ = [
7
+ 'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
8
+ 'GroupMultiSourceSampler'
9
+ ]
mmdet/datasets/samplers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (421 Bytes). View file
 
mmdet/datasets/samplers/__pycache__/batch_sampler.cpython-310.pyc ADDED
Binary file (2.56 kB). View file
 
mmdet/datasets/samplers/__pycache__/class_aware_sampler.cpython-310.pyc ADDED
Binary file (6.95 kB). View file
 
mmdet/datasets/samplers/__pycache__/multi_source_sampler.cpython-310.pyc ADDED
Binary file (8.51 kB). View file
 
mmdet/datasets/samplers/batch_sampler.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Sequence
3
+
4
+ from torch.utils.data import BatchSampler, Sampler
5
+
6
+ from mmdet.registry import DATA_SAMPLERS
7
+
8
+
9
+ # TODO: maybe replace with a data_loader wrapper
10
+ @DATA_SAMPLERS.register_module()
11
+ class AspectRatioBatchSampler(BatchSampler):
12
+ """A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
13
+
14
+ >= 1) into a same batch.
15
+
16
+ Args:
17
+ sampler (Sampler): Base sampler.
18
+ batch_size (int): Size of mini-batch.
19
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
20
+ its size would be less than ``batch_size``.
21
+ """
22
+
23
+ def __init__(self,
24
+ sampler: Sampler,
25
+ batch_size: int,
26
+ drop_last: bool = False) -> None:
27
+ if not isinstance(sampler, Sampler):
28
+ raise TypeError('sampler should be an instance of ``Sampler``, '
29
+ f'but got {sampler}')
30
+ if not isinstance(batch_size, int) or batch_size <= 0:
31
+ raise ValueError('batch_size should be a positive integer value, '
32
+ f'but got batch_size={batch_size}')
33
+ self.sampler = sampler
34
+ self.batch_size = batch_size
35
+ self.drop_last = drop_last
36
+ # two groups for w < h and w >= h
37
+ self._aspect_ratio_buckets = [[] for _ in range(2)]
38
+
39
+ def __iter__(self) -> Sequence[int]:
40
+ for idx in self.sampler:
41
+ data_info = self.sampler.dataset.get_data_info(idx)
42
+ width, height = data_info['width'], data_info['height']
43
+ bucket_id = 0 if width < height else 1
44
+ bucket = self._aspect_ratio_buckets[bucket_id]
45
+ bucket.append(idx)
46
+ # yield a batch of indices in the same aspect ratio group
47
+ if len(bucket) == self.batch_size:
48
+ yield bucket[:]
49
+ del bucket[:]
50
+
51
+ # yield the rest data and reset the bucket
52
+ left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
53
+ 1]
54
+ self._aspect_ratio_buckets = [[] for _ in range(2)]
55
+ while len(left_data) > 0:
56
+ if len(left_data) <= self.batch_size:
57
+ if not self.drop_last:
58
+ yield left_data[:]
59
+ left_data = []
60
+ else:
61
+ yield left_data[:self.batch_size]
62
+ left_data = left_data[self.batch_size:]
63
+
64
+ def __len__(self) -> int:
65
+ if self.drop_last:
66
+ return len(self.sampler) // self.batch_size
67
+ else:
68
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
mmdet/datasets/samplers/class_aware_sampler.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from typing import Dict, Iterator, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from mmengine.dataset import BaseDataset
8
+ from mmengine.dist import get_dist_info, sync_random_seed
9
+ from torch.utils.data import Sampler
10
+
11
+ from mmdet.registry import DATA_SAMPLERS
12
+
13
+
14
@DATA_SAMPLERS.register_module()
class ClassAwareSampler(Sampler):
    r"""Sampler that restricts data loading to the label of the dataset.

    A class-aware sampling strategy to effectively tackle the
    non-uniform class distribution. The length of the training data is
    consistent with source data. Simple improvements based on `Relay
    Backpropagation for Effective Learning of Deep Convolutional
    Neural Networks <https://arxiv.org/abs/1512.05830>`_

    The implementation logic is referred to
    https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py

    Args:
        dataset: Dataset used for sampling. It must provide ``metainfo``
            with a ``classes`` entry and a ``get_cat_ids(idx)`` method.
        seed (int, optional): random seed used to shuffle the sampler.
            This number should be identical across all
            processes in the distributed group. Defaults to None.
        num_sample_class (int): The number of samples taken from each
            per-label list. Defaults to 1.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 seed: Optional[int] = None,
                 num_sample_class: int = 1) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.epoch = 0
        # Must be the same across all workers. If None, will use a
        # random seed shared among workers
        # (require synchronization among all workers)
        if seed is None:
            seed = sync_random_seed()
        self.seed = seed

        # The number of samples taken from each per-label list.
        # The type is checked first so that a non-numeric value fails the
        # assertion instead of raising from the comparison.
        assert isinstance(num_sample_class, int) and num_sample_class > 0
        self.num_sample_class = num_sample_class
        # Get per-label image list from dataset
        self.cat_dict = self.get_cat2imgs()

        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
        self.total_size = self.num_samples * self.world_size

        # get number of images containing each category
        self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
        # filter labels without images
        self.valid_cat_inds = [
            i for i, length in enumerate(self.num_cat_imgs) if length != 0
        ]
        self.num_classes = len(self.valid_cat_inds)

    def get_cat2imgs(self) -> Dict[int, list]:
        """Get a dict with class as key and img_ids as values.

        Returns:
            dict[int, list]: A dict of per-label image list,
            the item of the dict indicates a label index,
            corresponds to the image index that contains the label.

        Raises:
            ValueError: If the dataset metainfo has no ``classes`` entry.
        """
        classes = self.dataset.metainfo.get('classes', None)
        if classes is None:
            raise ValueError('dataset metainfo must contain `classes`')
        # Keys are created in ascending label order; ``__init__`` relies on
        # this when it enumerates the per-label list lengths.
        cat2imgs = {i: [] for i in range(len(classes))}
        for img_idx in range(len(self.dataset)):
            # De-duplicate so an image is listed once per label.
            for cat in set(self.dataset.get_cat_ids(img_idx)):
                cat2imgs[cat].append(img_idx)
        return cat2imgs

    def __iter__(self) -> Iterator[int]:
        """Build the index sequence of this rank for the current epoch.

        Labels are visited in a reshuffled cyclic order and, for each
        visited label, ``num_sample_class`` image indices are drawn from
        that label's own reshuffled cyclic list.

        Returns:
            Iterator[int]: Iterator over ``num_samples`` dataset indices.

        Raises:
            ValueError: If the dataset is not empty but no label has any
                image, so nothing can be sampled.
        """
        if self.total_size == 0:
            # Empty dataset: nothing to sample.
            return iter([])
        if self.num_classes == 0:
            raise ValueError(
                'ClassAwareSampler cannot sample: no category of the '
                'dataset contains any image')

        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)

        # initialize label list
        label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
        # initialize each per-label image list
        data_iter_dict = {
            cat: RandomCycleIter(self.cat_dict[cat], generator=g)
            for cat in self.valid_cat_inds
        }

        # One bin is a full pass over the valid labels with
        # ``num_sample_class`` draws per label.
        num_bins = int(
            math.ceil(self.total_size * 1.0 / self.num_classes /
                      self.num_sample_class))
        indices = []
        for _ in range(num_bins):
            for _ in range(len(label_iter_list)):
                cls_idx = next(label_iter_list)
                for _ in range(self.num_sample_class):
                    indices.append(next(data_iter_dict[cls_idx]))

        # ``num_bins`` is rounded up, so at least ``total_size`` indices
        # were generated; drop the surplus to make it evenly divisible.
        indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        The epoch is added to the seed in ``__iter__``, so a different
        epoch gives a different ordering. Without updating it, the next
        iteration of this sampler yields the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
150
+
151
+
152
class RandomCycleIter:
    """Shuffle the list and do it again after the list has been traversed.

    The implementation logic is referred to
    https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py

    Iterating never stops for non-empty ``data``: once every element has
    been returned, a new random permutation is drawn. For empty ``data``
    iteration ends immediately.

    Example:
        >>> label_list = [0, 1, 2, 4, 5]
        >>> g = torch.Generator()
        >>> g.manual_seed(0)
        >>> label_iter_list = RandomCycleIter(label_list, generator=g)
        >>> index = next(label_iter_list)

    Args:
        data (list or ndarray): The data that needs to be shuffled.
        generator: A torch.Generator object, which is used in setting the
            seed for generating random numbers. Defaults to None.
    """

    def __init__(self,
                 data: Union[list, np.ndarray],
                 generator: Optional[torch.Generator] = None) -> None:
        self.data = data
        self.length = len(data)
        # Current random visiting order, as positions into ``data``.
        self.index = torch.randperm(self.length, generator=generator).numpy()
        # Position of the next element inside ``self.index``.
        self.i = 0
        self.generator = generator

    def __iter__(self) -> Iterator:
        return self

    def __len__(self) -> int:
        return len(self.data)

    def __next__(self):
        """Return the next element, reshuffling after a full pass."""
        if self.length == 0:
            # Nothing to cycle over; stop instead of indexing an empty
            # permutation.
            raise StopIteration
        if self.i == self.length:
            self.index = torch.randperm(
                self.length, generator=self.generator).numpy()
            self.i = 0
        item = self.data[self.index[self.i]]
        self.i += 1
        return item
mmdet/datasets/samplers/multi_source_sampler.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import itertools
3
+ from typing import Iterator, List, Optional, Sized, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from mmengine.dataset import BaseDataset
8
+ from mmengine.dist import get_dist_info, sync_random_seed
9
+ from torch.utils.data import Sampler
10
+
11
+ from mmdet.registry import DATA_SAMPLERS
12
+
13
+
14
@DATA_SAMPLERS.register_module()
class MultiSourceSampler(Sampler):
    r"""Multi-Source Infinite Sampler.

    According to the sampling ratio, sample data from different
    datasets to form batches.

    Args:
        dataset (Sized): The dataset.
        batch_size (int): Size of mini-batch.
        source_ratio (list[int | float]): The sampling ratio of different
            source datasets in a mini-batch.
        shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
        seed (int, optional): Random seed. If None, set a random seed.
            Defaults to None.

    Examples:
        >>> dataset_type = 'ConcatDataset'
        >>> sub_dataset_type = 'CocoDataset'
        >>> data_root = 'data/coco/'
        >>> sup_ann = '../coco_semi_annos/instances_train2017.1@10.json'
        >>> unsup_ann = '../coco_semi_annos/' \
        >>>             'instances_train2017.1@10-unlabeled.json'
        >>> dataset = dict(type=dataset_type,
        >>>     datasets=[
        >>>         dict(
        >>>             type=sub_dataset_type,
        >>>             data_root=data_root,
        >>>             ann_file=sup_ann,
        >>>             data_prefix=dict(img='train2017/'),
        >>>             filter_cfg=dict(filter_empty_gt=True, min_size=32),
        >>>             pipeline=sup_pipeline),
        >>>         dict(
        >>>             type=sub_dataset_type,
        >>>             data_root=data_root,
        >>>             ann_file=unsup_ann,
        >>>             data_prefix=dict(img='train2017/'),
        >>>             filter_cfg=dict(filter_empty_gt=True, min_size=32),
        >>>             pipeline=unsup_pipeline),
        >>>         ])
        >>>     train_dataloader = dict(
        >>>         batch_size=5,
        >>>         num_workers=5,
        >>>         persistent_workers=True,
        >>>         sampler=dict(type='MultiSourceSampler',
        >>>             batch_size=5, source_ratio=[1, 4]),
        >>>         batch_sampler=None,
        >>>         dataset=dataset)
    """

    def __init__(self,
                 dataset: Sized,
                 batch_size: int,
                 source_ratio: List[Union[int, float]],
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:

        assert hasattr(dataset, 'cumulative_sizes'),\
            f'The dataset must be ConcatDataset, but get {dataset}'
        assert isinstance(batch_size, int) and batch_size > 0, \
            'batch_size must be a positive integer value, ' \
            f'but got batch_size={batch_size}'
        assert isinstance(source_ratio, list), \
            f'source_ratio must be a list, but got source_ratio={source_ratio}'
        assert len(source_ratio) == len(dataset.cumulative_sizes), \
            'The length of source_ratio must be equal to ' \
            f'the number of datasets, but got source_ratio={source_ratio}'

        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        # Prepend 0 so that ``cumulative_sizes[source]`` is the offset of
        # that source inside the concatenated dataset.
        self.cumulative_sizes = [0] + dataset.cumulative_sizes
        self.batch_size = batch_size
        self.source_ratio = source_ratio

        # Per-source share of a batch, rounded down; the first source
        # absorbs the rounding remainder so the shares sum to batch_size.
        self.num_per_source = [
            int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
        ]
        self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])

        assert sum(self.num_per_source) == batch_size, \
            'The sum of num_per_source must be equal to ' \
            f'batch_size, but get {self.num_per_source}'

        self.seed = sync_random_seed() if seed is None else seed
        self.shuffle = shuffle
        # One lazily evaluated, never-ending index stream per source.
        self.source2inds = {
            source: self._indices_of_rank(len(ds))
            for source, ds in enumerate(dataset.datasets)
        }

    def _infinite_indices(self, sample_size: int) -> Iterator[int]:
        """Infinitely yield a sequence of indices.

        Args:
            sample_size (int): Number of items to index; must be positive.

        Raises:
            ValueError: On first use if ``sample_size`` is not positive.
                Cycling over nothing would otherwise spin forever
                without yielding a single index.
        """
        if sample_size <= 0:
            raise ValueError(
                'Cannot sample indices from an empty source, '
                f'but got sample_size={sample_size}')
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(sample_size, generator=g).tolist()
            else:
                yield from torch.arange(sample_size).tolist()

    def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
        """Slice the infinite indices by rank."""
        yield from itertools.islice(
            self._infinite_indices(sample_size), self.rank, None,
            self.world_size)

    def __iter__(self) -> Iterator[int]:
        """Endlessly yield indices into the concatenated dataset.

        Indices come in runs of ``batch_size``; within a run each source
        contributes ``num_per_source[source]`` consecutive indices.
        """
        while True:
            batch_buffer = []
            for source, num in enumerate(self.num_per_source):
                offset = self.cumulative_sizes[source]
                # ``islice`` takes exactly ``num`` indices and takes none
                # when the share of this source rounded down to zero.
                batch_buffer.extend(
                    idx + offset for idx in itertools.islice(
                        self.source2inds[source], num))
            yield from batch_buffer

    def __len__(self) -> int:
        return len(self.dataset)

    def set_epoch(self, epoch: int) -> None:
        """No-op: epochs are not supported by this infinite sampler."""
        pass
143
+
144
+
145
@DATA_SAMPLERS.register_module()
class GroupMultiSourceSampler(MultiSourceSampler):
    r"""Group Multi-Source Infinite Sampler.

    According to the sampling ratio, sample data from different
    datasets but the same group to form batches.

    Images are split into two groups by aspect ratio: group 0 holds
    images with ``width < height`` and group 1 all others.

    Args:
        dataset (Sized): The dataset.
        batch_size (int): Size of mini-batch.
        source_ratio (list[int | float]): The sampling ratio of different
            source datasets in a mini-batch.
        shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
        seed (int, optional): Random seed. If None, set a random seed.
            Defaults to None.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 batch_size: int,
                 source_ratio: List[Union[int, float]],
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            source_ratio=source_ratio,
            shuffle=shuffle,
            seed=seed)

        self._get_source_group_info()
        # group_source2inds[group][source] is a lazy, never-ending stream
        # of positions into group2inds_per_source[source][group].
        self.group_source2inds = [{
            source:
            self._indices_of_rank(self.group2size_per_source[source][group])
            for source in range(len(dataset.datasets))
        } for group in range(len(self.group_ratio))]

    def _get_source_group_info(self) -> None:
        """Collect, per source, the size and members of each group.

        Sets ``group2size_per_source`` and ``group2inds_per_source``
        (one ``{group: ...}`` dict per source dataset), ``group_sizes``
        (images per group summed over all sources) and ``group_ratio``
        (``group_sizes`` normalised to sum to one).
        """
        # One entry per source dataset, however many sources there are.
        num_sources = len(self.dataset.datasets)
        self.group2size_per_source = [{0: 0, 1: 0} for _ in range(num_sources)]
        self.group2inds_per_source = [{
            0: [],
            1: []
        } for _ in range(num_sources)]
        for source, dataset in enumerate(self.dataset.datasets):
            for idx in range(len(dataset)):
                data_info = dataset.get_data_info(idx)
                width, height = data_info['width'], data_info['height']
                group = 0 if width < height else 1
                self.group2size_per_source[source][group] += 1
                self.group2inds_per_source[source][group].append(idx)

        self.group_sizes = np.zeros(2, dtype=np.int64)
        for group2size in self.group2size_per_source:
            for group, size in group2size.items():
                self.group_sizes[group] += size
        self.group_ratio = self.group_sizes / sum(self.group_sizes)

    def __iter__(self) -> Iterator[int]:
        """Endlessly yield indices into the concatenated dataset.

        For every run of ``batch_size`` indices one group is drawn with
        probability ``group_ratio``; each source then contributes
        ``num_per_source[source]`` indices of images from that group.

        Raises:
            ValueError: If the drawn group has no image in a source that
                has to contribute to the batch. Waiting for an index
                from that source would never finish.
        """
        groups = list(range(len(self.group_ratio)))
        while True:
            # NOTE: the group is drawn from NumPy's global random state,
            # not from a generator seeded with ``self.seed``.
            group = np.random.choice(groups, p=self.group_ratio)
            batch_buffer = []
            for source, num in enumerate(self.num_per_source):
                if num == 0:
                    # The share of this source rounded down to zero.
                    continue
                group_inds = self.group2inds_per_source[source][group]
                if not group_inds:
                    raise ValueError(
                        f'Source dataset {source} has no image in '
                        f'aspect-ratio group {group}, cannot form a batch')
                offset = self.cumulative_sizes[source]
                batch_buffer.extend(
                    group_inds[pos] + offset for pos in itertools.islice(
                        self.group_source2inds[group][source], num))
            yield from batch_buffer
mmdet/datasets/transforms/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .augment_wrappers import AutoAugment, RandAugment
3
+ from .colorspace import (AutoContrast, Brightness, Color, ColorTransform,
4
+ Contrast, Equalize, Invert, Posterize, Sharpness,
5
+ Solarize, SolarizeAdd)
6
+ from .formatting import ImageToTensor, PackDetInputs, ToTensor, Transpose
7
+ from .geometric import (GeomTransform, Rotate, ShearX, ShearY, TranslateX,
8
+ TranslateY)
9
+ from .instaboost import InstaBoost
10
+ from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations,
11
+ LoadEmptyAnnotations, LoadImageFromNDArray,
12
+ LoadMultiChannelImageFromFiles, LoadPanopticAnnotations,
13
+ LoadProposals)
14
+ from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut,
15
+ Expand, FixShapeResize, MinIoURandomCrop, MixUp,
16
+ Mosaic, Pad, PhotoMetricDistortion, RandomAffine,
17
+ RandomCenterCropPad, RandomCrop, RandomErasing,
18
+ RandomFlip, RandomShift, Resize, SegRescale,
19
+ YOLOXHSVRandomAug)
20
+ from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder
21
+
22
+ __all__ = [
23
+ 'PackDetInputs', 'ToTensor', 'ImageToTensor', 'Transpose',
24
+ 'LoadImageFromNDArray', 'LoadAnnotations', 'LoadPanopticAnnotations',
25
+ 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'Resize', 'RandomFlip',
26
+ 'RandomCrop', 'SegRescale', 'MinIoURandomCrop', 'Expand',
27
+ 'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad',
28
+ 'AutoAugment', 'CutOut', 'ShearX', 'ShearY', 'Rotate', 'Color', 'Equalize',
29
+ 'Brightness', 'Contrast', 'TranslateX', 'TranslateY', 'RandomShift',
30
+ 'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste',
31
+ 'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform',
32
+ 'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize',
33
+ 'AutoContrast', 'Invert', 'MultiBranch', 'RandomErasing',
34
+ 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp',
35
+ 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader'
36
+ ]
mmdet/datasets/transforms/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.91 kB). View file
 
mmdet/datasets/transforms/__pycache__/augment_wrappers.cpython-310.pyc ADDED
Binary file (9.14 kB). View file
 
mmdet/datasets/transforms/__pycache__/colorspace.cpython-310.pyc ADDED
Binary file (16.2 kB). View file
 
mmdet/datasets/transforms/__pycache__/formatting.cpython-310.pyc ADDED
Binary file (8.89 kB). View file