Spaces:
Runtime error
Runtime error
Upload 787 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- mmdet/__init__.py +27 -0
- mmdet/__pycache__/__init__.cpython-310.pyc +0 -0
- mmdet/__pycache__/registry.cpython-310.pyc +0 -0
- mmdet/__pycache__/version.cpython-310.pyc +0 -0
- mmdet/apis/__init__.py +9 -0
- mmdet/apis/det_inferencer.py +590 -0
- mmdet/apis/inference.py +233 -0
- mmdet/datasets/__init__.py +27 -0
- mmdet/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/base_det_dataset.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/cityscapes.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/coco.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/coco_panoptic.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/crowdhuman.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/dataset_wrappers.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/deepfashion.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/lvis.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/objects365.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/openimages.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/utils.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/voc.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/wider_face.cpython-310.pyc +0 -0
- mmdet/datasets/__pycache__/xml_style.cpython-310.pyc +0 -0
- mmdet/datasets/api_wrappers/__init__.py +4 -0
- mmdet/datasets/api_wrappers/__pycache__/__init__.cpython-310.pyc +0 -0
- mmdet/datasets/api_wrappers/__pycache__/coco_api.cpython-310.pyc +0 -0
- mmdet/datasets/api_wrappers/coco_api.py +137 -0
- mmdet/datasets/base_det_dataset.py +120 -0
- mmdet/datasets/cityscapes.py +61 -0
- mmdet/datasets/coco.py +196 -0
- mmdet/datasets/coco_panoptic.py +287 -0
- mmdet/datasets/crowdhuman.py +159 -0
- mmdet/datasets/dataset_wrappers.py +169 -0
- mmdet/datasets/deepfashion.py +19 -0
- mmdet/datasets/lvis.py +638 -0
- mmdet/datasets/objects365.py +284 -0
- mmdet/datasets/openimages.py +484 -0
- mmdet/datasets/samplers/__init__.py +9 -0
- mmdet/datasets/samplers/__pycache__/__init__.cpython-310.pyc +0 -0
- mmdet/datasets/samplers/__pycache__/batch_sampler.cpython-310.pyc +0 -0
- mmdet/datasets/samplers/__pycache__/class_aware_sampler.cpython-310.pyc +0 -0
- mmdet/datasets/samplers/__pycache__/multi_source_sampler.cpython-310.pyc +0 -0
- mmdet/datasets/samplers/batch_sampler.py +68 -0
- mmdet/datasets/samplers/class_aware_sampler.py +192 -0
- mmdet/datasets/samplers/multi_source_sampler.py +214 -0
- mmdet/datasets/transforms/__init__.py +36 -0
- mmdet/datasets/transforms/__pycache__/__init__.cpython-310.pyc +0 -0
- mmdet/datasets/transforms/__pycache__/augment_wrappers.cpython-310.pyc +0 -0
- mmdet/datasets/transforms/__pycache__/colorspace.cpython-310.pyc +0 -0
- mmdet/datasets/transforms/__pycache__/formatting.cpython-310.pyc +0 -0
mmdet/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import mmcv
|
3 |
+
import mmengine
|
4 |
+
from mmengine.utils import digit_version
|
5 |
+
|
6 |
+
from .version import __version__, version_info
|
7 |
+
|
8 |
+
mmcv_minimum_version = '2.0.0rc4'
|
9 |
+
mmcv_maximum_version = '2.1.0'
|
10 |
+
mmcv_version = digit_version(mmcv.__version__)
|
11 |
+
|
12 |
+
mmengine_minimum_version = '0.7.0'
|
13 |
+
mmengine_maximum_version = '1.0.0'
|
14 |
+
mmengine_version = digit_version(mmengine.__version__)
|
15 |
+
|
16 |
+
assert (mmcv_version >= digit_version(mmcv_minimum_version)
|
17 |
+
and mmcv_version < digit_version(mmcv_maximum_version)), \
|
18 |
+
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
19 |
+
f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
|
20 |
+
|
21 |
+
assert (mmengine_version >= digit_version(mmengine_minimum_version)
|
22 |
+
and mmengine_version < digit_version(mmengine_maximum_version)), \
|
23 |
+
f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
|
24 |
+
f'Please install mmengine>={mmengine_minimum_version}, ' \
|
25 |
+
f'<{mmengine_maximum_version}.'
|
26 |
+
|
27 |
+
__all__ = ['__version__', 'version_info', 'digit_version']
|
mmdet/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (817 Bytes). View file
|
|
mmdet/__pycache__/registry.cpython-310.pyc
ADDED
Binary file (2.58 kB). View file
|
|
mmdet/__pycache__/version.cpython-310.pyc
ADDED
Binary file (803 Bytes). View file
|
|
mmdet/apis/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .det_inferencer import DetInferencer
|
3 |
+
from .inference import (async_inference_detector, inference_detector,
|
4 |
+
init_detector)
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'init_detector', 'async_inference_detector', 'inference_detector',
|
8 |
+
'DetInferencer'
|
9 |
+
]
|
mmdet/apis/det_inferencer.py
ADDED
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
import os.path as osp
|
4 |
+
import warnings
|
5 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Union
|
6 |
+
|
7 |
+
import mmcv
|
8 |
+
import mmengine
|
9 |
+
import numpy as np
|
10 |
+
import torch.nn as nn
|
11 |
+
from mmengine.dataset import Compose
|
12 |
+
from mmengine.fileio import (get_file_backend, isdir, join_path,
|
13 |
+
list_dir_or_file)
|
14 |
+
from mmengine.infer.infer import BaseInferencer, ModelType
|
15 |
+
from mmengine.model.utils import revert_sync_batchnorm
|
16 |
+
from mmengine.registry import init_default_scope
|
17 |
+
from mmengine.runner.checkpoint import _load_checkpoint_to_model
|
18 |
+
from mmengine.visualization import Visualizer
|
19 |
+
from rich.progress import track
|
20 |
+
|
21 |
+
from mmdet.evaluation import INSTANCE_OFFSET
|
22 |
+
from mmdet.registry import DATASETS
|
23 |
+
from mmdet.structures import DetDataSample
|
24 |
+
from mmdet.structures.mask import encode_mask_results, mask2bbox
|
25 |
+
from mmdet.utils import ConfigType
|
26 |
+
from ..evaluation import get_classes
|
27 |
+
|
28 |
+
try:
|
29 |
+
from panopticapi.evaluation import VOID
|
30 |
+
from panopticapi.utils import id2rgb
|
31 |
+
except ImportError:
|
32 |
+
id2rgb = None
|
33 |
+
VOID = None
|
34 |
+
|
35 |
+
InputType = Union[str, np.ndarray]
|
36 |
+
InputsType = Union[InputType, Sequence[InputType]]
|
37 |
+
PredType = List[DetDataSample]
|
38 |
+
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
|
39 |
+
|
40 |
+
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
|
41 |
+
'.tiff', '.webp')
|
42 |
+
|
43 |
+
|
44 |
+
class DetInferencer(BaseInferencer):
|
45 |
+
"""Object Detection Inferencer.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
model (str, optional): Path to the config file or the model name
|
49 |
+
defined in metafile. For example, it could be
|
50 |
+
"rtmdet-s" or 'rtmdet_s_8xb32-300e_coco' or
|
51 |
+
"configs/rtmdet/rtmdet_s_8xb32-300e_coco.py".
|
52 |
+
If model is not specified, user must provide the
|
53 |
+
`weights` saved by MMEngine which contains the config string.
|
54 |
+
Defaults to None.
|
55 |
+
weights (str, optional): Path to the checkpoint. If it is not specified
|
56 |
+
and model is a model name of metafile, the weights will be loaded
|
57 |
+
from metafile. Defaults to None.
|
58 |
+
device (str, optional): Device to run inference. If None, the available
|
59 |
+
device will be automatically used. Defaults to None.
|
60 |
+
scope (str, optional): The scope of the model. Defaults to mmdet.
|
61 |
+
palette (str): Color palette used for visualization. The order of
|
62 |
+
priority is palette -> config -> checkpoint. Defaults to 'none'.
|
63 |
+
"""
|
64 |
+
|
65 |
+
preprocess_kwargs: set = set()
|
66 |
+
forward_kwargs: set = set()
|
67 |
+
visualize_kwargs: set = {
|
68 |
+
'return_vis',
|
69 |
+
'show',
|
70 |
+
'wait_time',
|
71 |
+
'draw_pred',
|
72 |
+
'pred_score_thr',
|
73 |
+
'img_out_dir',
|
74 |
+
'no_save_vis',
|
75 |
+
}
|
76 |
+
postprocess_kwargs: set = {
|
77 |
+
'print_result',
|
78 |
+
'pred_out_dir',
|
79 |
+
'return_datasample',
|
80 |
+
'no_save_pred',
|
81 |
+
}
|
82 |
+
|
83 |
+
def __init__(self,
|
84 |
+
model: Optional[Union[ModelType, str]] = None,
|
85 |
+
weights: Optional[str] = None,
|
86 |
+
device: Optional[str] = None,
|
87 |
+
scope: Optional[str] = 'mmdet',
|
88 |
+
palette: str = 'none') -> None:
|
89 |
+
# A global counter tracking the number of images processed, for
|
90 |
+
# naming of the output images
|
91 |
+
self.num_visualized_imgs = 0
|
92 |
+
self.num_predicted_imgs = 0
|
93 |
+
self.palette = palette
|
94 |
+
init_default_scope(scope)
|
95 |
+
super().__init__(
|
96 |
+
model=model, weights=weights, device=device, scope=scope)
|
97 |
+
self.model = revert_sync_batchnorm(self.model)
|
98 |
+
|
99 |
+
def _load_weights_to_model(self, model: nn.Module,
|
100 |
+
checkpoint: Optional[dict],
|
101 |
+
cfg: Optional[ConfigType]) -> None:
|
102 |
+
"""Loading model weights and meta information from cfg and checkpoint.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
model (nn.Module): Model to load weights and meta information.
|
106 |
+
checkpoint (dict, optional): The loaded checkpoint.
|
107 |
+
cfg (Config or ConfigDict, optional): The loaded config.
|
108 |
+
"""
|
109 |
+
|
110 |
+
if checkpoint is not None:
|
111 |
+
_load_checkpoint_to_model(model, checkpoint)
|
112 |
+
checkpoint_meta = checkpoint.get('meta', {})
|
113 |
+
# save the dataset_meta in the model for convenience
|
114 |
+
if 'dataset_meta' in checkpoint_meta:
|
115 |
+
# mmdet 3.x, all keys should be lowercase
|
116 |
+
model.dataset_meta = {
|
117 |
+
k.lower(): v
|
118 |
+
for k, v in checkpoint_meta['dataset_meta'].items()
|
119 |
+
}
|
120 |
+
elif 'CLASSES' in checkpoint_meta:
|
121 |
+
# < mmdet 3.x
|
122 |
+
classes = checkpoint_meta['CLASSES']
|
123 |
+
model.dataset_meta = {'classes': classes}
|
124 |
+
else:
|
125 |
+
warnings.warn(
|
126 |
+
'dataset_meta or class names are not saved in the '
|
127 |
+
'checkpoint\'s meta data, use COCO classes by default.')
|
128 |
+
model.dataset_meta = {'classes': get_classes('coco')}
|
129 |
+
else:
|
130 |
+
warnings.warn('Checkpoint is not loaded, and the inference '
|
131 |
+
'result is calculated by the randomly initialized '
|
132 |
+
'model!')
|
133 |
+
warnings.warn('weights is None, use COCO classes by default.')
|
134 |
+
model.dataset_meta = {'classes': get_classes('coco')}
|
135 |
+
|
136 |
+
# Priority: args.palette -> config -> checkpoint
|
137 |
+
if self.palette != 'none':
|
138 |
+
model.dataset_meta['palette'] = self.palette
|
139 |
+
else:
|
140 |
+
test_dataset_cfg = copy.deepcopy(cfg.test_dataloader.dataset)
|
141 |
+
# lazy init. We only need the metainfo.
|
142 |
+
test_dataset_cfg['lazy_init'] = True
|
143 |
+
metainfo = DATASETS.build(test_dataset_cfg).metainfo
|
144 |
+
cfg_palette = metainfo.get('palette', None)
|
145 |
+
if cfg_palette is not None:
|
146 |
+
model.dataset_meta['palette'] = cfg_palette
|
147 |
+
else:
|
148 |
+
if 'palette' not in model.dataset_meta:
|
149 |
+
warnings.warn(
|
150 |
+
'palette does not exist, random is used by default. '
|
151 |
+
'You can also set the palette to customize.')
|
152 |
+
model.dataset_meta['palette'] = 'random'
|
153 |
+
|
154 |
+
def _init_pipeline(self, cfg: ConfigType) -> Compose:
|
155 |
+
"""Initialize the test pipeline."""
|
156 |
+
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
157 |
+
|
158 |
+
# For inference, the key of ``img_id`` is not used.
|
159 |
+
if 'meta_keys' in pipeline_cfg[-1]:
|
160 |
+
pipeline_cfg[-1]['meta_keys'] = tuple(
|
161 |
+
meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
|
162 |
+
if meta_key != 'img_id')
|
163 |
+
|
164 |
+
load_img_idx = self._get_transform_idx(pipeline_cfg,
|
165 |
+
'LoadImageFromFile')
|
166 |
+
if load_img_idx == -1:
|
167 |
+
raise ValueError(
|
168 |
+
'LoadImageFromFile is not found in the test pipeline')
|
169 |
+
pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader'
|
170 |
+
return Compose(pipeline_cfg)
|
171 |
+
|
172 |
+
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
|
173 |
+
"""Returns the index of the transform in a pipeline.
|
174 |
+
|
175 |
+
If the transform is not found, returns -1.
|
176 |
+
"""
|
177 |
+
for i, transform in enumerate(pipeline_cfg):
|
178 |
+
if transform['type'] == name:
|
179 |
+
return i
|
180 |
+
return -1
|
181 |
+
|
182 |
+
def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
|
183 |
+
"""Initialize visualizers.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
cfg (ConfigType): Config containing the visualizer information.
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
Visualizer or None: Visualizer initialized with config.
|
190 |
+
"""
|
191 |
+
visualizer = super()._init_visualizer(cfg)
|
192 |
+
visualizer.dataset_meta = self.model.dataset_meta
|
193 |
+
return visualizer
|
194 |
+
|
195 |
+
def _inputs_to_list(self, inputs: InputsType) -> list:
|
196 |
+
"""Preprocess the inputs to a list.
|
197 |
+
|
198 |
+
Preprocess inputs to a list according to its type:
|
199 |
+
|
200 |
+
- list or tuple: return inputs
|
201 |
+
- str:
|
202 |
+
- Directory path: return all files in the directory
|
203 |
+
- other cases: return a list containing the string. The string
|
204 |
+
could be a path to file, a url or other types of string according
|
205 |
+
to the task.
|
206 |
+
|
207 |
+
Args:
|
208 |
+
inputs (InputsType): Inputs for the inferencer.
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
list: List of input for the :meth:`preprocess`.
|
212 |
+
"""
|
213 |
+
if isinstance(inputs, str):
|
214 |
+
backend = get_file_backend(inputs)
|
215 |
+
if hasattr(backend, 'isdir') and isdir(inputs):
|
216 |
+
# Backends like HttpsBackend do not implement `isdir`, so only
|
217 |
+
# those backends that implement `isdir` could accept the inputs
|
218 |
+
# as a directory
|
219 |
+
filename_list = list_dir_or_file(
|
220 |
+
inputs, list_dir=False, suffix=IMG_EXTENSIONS)
|
221 |
+
inputs = [
|
222 |
+
join_path(inputs, filename) for filename in filename_list
|
223 |
+
]
|
224 |
+
|
225 |
+
if not isinstance(inputs, (list, tuple)):
|
226 |
+
inputs = [inputs]
|
227 |
+
|
228 |
+
return list(inputs)
|
229 |
+
|
230 |
+
def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
|
231 |
+
"""Process the inputs into a model-feedable format.
|
232 |
+
|
233 |
+
Customize your preprocess by overriding this method. Preprocess should
|
234 |
+
return an iterable object, of which each item will be used as the
|
235 |
+
input of ``model.test_step``.
|
236 |
+
|
237 |
+
``BaseInferencer.preprocess`` will return an iterable chunked data,
|
238 |
+
which will be used in __call__ like this:
|
239 |
+
|
240 |
+
.. code-block:: python
|
241 |
+
|
242 |
+
def __call__(self, inputs, batch_size=1, **kwargs):
|
243 |
+
chunked_data = self.preprocess(inputs, batch_size, **kwargs)
|
244 |
+
for batch in chunked_data:
|
245 |
+
preds = self.forward(batch, **kwargs)
|
246 |
+
|
247 |
+
Args:
|
248 |
+
inputs (InputsType): Inputs given by user.
|
249 |
+
batch_size (int): batch size. Defaults to 1.
|
250 |
+
|
251 |
+
Yields:
|
252 |
+
Any: Data processed by the ``pipeline`` and ``collate_fn``.
|
253 |
+
"""
|
254 |
+
chunked_data = self._get_chunk_data(inputs, batch_size)
|
255 |
+
yield from map(self.collate_fn, chunked_data)
|
256 |
+
|
257 |
+
def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
|
258 |
+
"""Get batch data from inputs.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
inputs (Iterable): An iterable dataset.
|
262 |
+
chunk_size (int): Equivalent to batch size.
|
263 |
+
|
264 |
+
Yields:
|
265 |
+
list: batch data.
|
266 |
+
"""
|
267 |
+
inputs_iter = iter(inputs)
|
268 |
+
while True:
|
269 |
+
try:
|
270 |
+
chunk_data = []
|
271 |
+
for _ in range(chunk_size):
|
272 |
+
inputs_ = next(inputs_iter)
|
273 |
+
chunk_data.append((inputs_, self.pipeline(inputs_)))
|
274 |
+
yield chunk_data
|
275 |
+
except StopIteration:
|
276 |
+
if chunk_data:
|
277 |
+
yield chunk_data
|
278 |
+
break
|
279 |
+
|
280 |
+
# TODO: Video and Webcam are currently not supported and
|
281 |
+
# may consume too much memory if your input folder has a lot of images.
|
282 |
+
# We will be optimized later.
|
283 |
+
def __call__(self,
|
284 |
+
inputs: InputsType,
|
285 |
+
batch_size: int = 1,
|
286 |
+
return_vis: bool = False,
|
287 |
+
show: bool = False,
|
288 |
+
wait_time: int = 0,
|
289 |
+
no_save_vis: bool = False,
|
290 |
+
draw_pred: bool = True,
|
291 |
+
pred_score_thr: float = 0.3,
|
292 |
+
return_datasample: bool = False,
|
293 |
+
print_result: bool = False,
|
294 |
+
no_save_pred: bool = True,
|
295 |
+
out_dir: str = '',
|
296 |
+
**kwargs) -> dict:
|
297 |
+
"""Call the inferencer.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
inputs (InputsType): Inputs for the inferencer.
|
301 |
+
batch_size (int): Inference batch size. Defaults to 1.
|
302 |
+
show (bool): Whether to display the visualization results in a
|
303 |
+
popup window. Defaults to False.
|
304 |
+
wait_time (float): The interval of show (s). Defaults to 0.
|
305 |
+
no_save_vis (bool): Whether to force not to save prediction
|
306 |
+
vis results. Defaults to False.
|
307 |
+
draw_pred (bool): Whether to draw predicted bounding boxes.
|
308 |
+
Defaults to True.
|
309 |
+
pred_score_thr (float): Minimum score of bboxes to draw.
|
310 |
+
Defaults to 0.3.
|
311 |
+
return_datasample (bool): Whether to return results as
|
312 |
+
:obj:`DetDataSample`. Defaults to False.
|
313 |
+
print_result (bool): Whether to print the inference result w/o
|
314 |
+
visualization to the console. Defaults to False.
|
315 |
+
no_save_pred (bool): Whether to force not to save prediction
|
316 |
+
results. Defaults to True.
|
317 |
+
out_file: Dir to save the inference results or
|
318 |
+
visualization. If left as empty, no file will be saved.
|
319 |
+
Defaults to ''.
|
320 |
+
|
321 |
+
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
|
322 |
+
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
|
323 |
+
Each key in kwargs should be in the corresponding set of
|
324 |
+
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
|
325 |
+
and ``postprocess_kwargs``.
|
326 |
+
|
327 |
+
Returns:
|
328 |
+
dict: Inference and visualization results.
|
329 |
+
"""
|
330 |
+
(
|
331 |
+
preprocess_kwargs,
|
332 |
+
forward_kwargs,
|
333 |
+
visualize_kwargs,
|
334 |
+
postprocess_kwargs,
|
335 |
+
) = self._dispatch_kwargs(**kwargs)
|
336 |
+
|
337 |
+
ori_inputs = self._inputs_to_list(inputs)
|
338 |
+
inputs = self.preprocess(
|
339 |
+
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
|
340 |
+
|
341 |
+
results_dict = {'predictions': [], 'visualization': []}
|
342 |
+
for ori_inputs, data in track(inputs, description='Inference'):
|
343 |
+
preds = self.forward(data, **forward_kwargs)
|
344 |
+
visualization = self.visualize(
|
345 |
+
ori_inputs,
|
346 |
+
preds,
|
347 |
+
return_vis=return_vis,
|
348 |
+
show=show,
|
349 |
+
wait_time=wait_time,
|
350 |
+
draw_pred=draw_pred,
|
351 |
+
pred_score_thr=pred_score_thr,
|
352 |
+
no_save_vis=no_save_vis,
|
353 |
+
img_out_dir=out_dir,
|
354 |
+
**visualize_kwargs)
|
355 |
+
results = self.postprocess(
|
356 |
+
preds,
|
357 |
+
visualization,
|
358 |
+
return_datasample=return_datasample,
|
359 |
+
print_result=print_result,
|
360 |
+
no_save_pred=no_save_pred,
|
361 |
+
pred_out_dir=out_dir,
|
362 |
+
**postprocess_kwargs)
|
363 |
+
results_dict['predictions'].extend(results['predictions'])
|
364 |
+
if results['visualization'] is not None:
|
365 |
+
results_dict['visualization'].extend(results['visualization'])
|
366 |
+
return results_dict
|
367 |
+
|
368 |
+
def visualize(self,
|
369 |
+
inputs: InputsType,
|
370 |
+
preds: PredType,
|
371 |
+
return_vis: bool = False,
|
372 |
+
show: bool = False,
|
373 |
+
wait_time: int = 0,
|
374 |
+
draw_pred: bool = True,
|
375 |
+
pred_score_thr: float = 0.3,
|
376 |
+
no_save_vis: bool = False,
|
377 |
+
img_out_dir: str = '',
|
378 |
+
**kwargs) -> Union[List[np.ndarray], None]:
|
379 |
+
"""Visualize predictions.
|
380 |
+
|
381 |
+
Args:
|
382 |
+
inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
|
383 |
+
preds (List[:obj:`DetDataSample`]): Predictions of the model.
|
384 |
+
return_vis (bool): Whether to return the visualization result.
|
385 |
+
Defaults to False.
|
386 |
+
show (bool): Whether to display the image in a popup window.
|
387 |
+
Defaults to False.
|
388 |
+
wait_time (float): The interval of show (s). Defaults to 0.
|
389 |
+
draw_pred (bool): Whether to draw predicted bounding boxes.
|
390 |
+
Defaults to True.
|
391 |
+
pred_score_thr (float): Minimum score of bboxes to draw.
|
392 |
+
Defaults to 0.3.
|
393 |
+
no_save_vis (bool): Whether to force not to save prediction
|
394 |
+
vis results. Defaults to False.
|
395 |
+
img_out_dir (str): Output directory of visualization results.
|
396 |
+
If left as empty, no file will be saved. Defaults to ''.
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
List[np.ndarray] or None: Returns visualization results only if
|
400 |
+
applicable.
|
401 |
+
"""
|
402 |
+
if no_save_vis is True:
|
403 |
+
img_out_dir = ''
|
404 |
+
|
405 |
+
if not show and img_out_dir == '' and not return_vis:
|
406 |
+
return None
|
407 |
+
|
408 |
+
if self.visualizer is None:
|
409 |
+
raise ValueError('Visualization needs the "visualizer" term'
|
410 |
+
'defined in the config, but got None.')
|
411 |
+
|
412 |
+
results = []
|
413 |
+
|
414 |
+
for single_input, pred in zip(inputs, preds):
|
415 |
+
if isinstance(single_input, str):
|
416 |
+
img_bytes = mmengine.fileio.get(single_input)
|
417 |
+
img = mmcv.imfrombytes(img_bytes)
|
418 |
+
img = img[:, :, ::-1]
|
419 |
+
img_name = osp.basename(single_input)
|
420 |
+
elif isinstance(single_input, np.ndarray):
|
421 |
+
img = single_input.copy()
|
422 |
+
img_num = str(self.num_visualized_imgs).zfill(8)
|
423 |
+
img_name = f'{img_num}.jpg'
|
424 |
+
else:
|
425 |
+
raise ValueError('Unsupported input type: '
|
426 |
+
f'{type(single_input)}')
|
427 |
+
|
428 |
+
out_file = osp.join(img_out_dir, 'vis',
|
429 |
+
img_name) if img_out_dir != '' else None
|
430 |
+
|
431 |
+
self.visualizer.add_datasample(
|
432 |
+
img_name,
|
433 |
+
img,
|
434 |
+
pred,
|
435 |
+
show=show,
|
436 |
+
wait_time=wait_time,
|
437 |
+
draw_gt=False,
|
438 |
+
draw_pred=draw_pred,
|
439 |
+
pred_score_thr=pred_score_thr,
|
440 |
+
out_file=out_file,
|
441 |
+
)
|
442 |
+
results.append(self.visualizer.get_image())
|
443 |
+
self.num_visualized_imgs += 1
|
444 |
+
|
445 |
+
return results
|
446 |
+
|
447 |
+
def postprocess(
|
448 |
+
self,
|
449 |
+
preds: PredType,
|
450 |
+
visualization: Optional[List[np.ndarray]] = None,
|
451 |
+
return_datasample: bool = False,
|
452 |
+
print_result: bool = False,
|
453 |
+
no_save_pred: bool = False,
|
454 |
+
pred_out_dir: str = '',
|
455 |
+
**kwargs,
|
456 |
+
) -> Dict:
|
457 |
+
"""Process the predictions and visualization results from ``forward``
|
458 |
+
and ``visualize``.
|
459 |
+
|
460 |
+
This method should be responsible for the following tasks:
|
461 |
+
|
462 |
+
1. Convert datasamples into a json-serializable dict if needed.
|
463 |
+
2. Pack the predictions and visualization results and return them.
|
464 |
+
3. Dump or log the predictions.
|
465 |
+
|
466 |
+
Args:
|
467 |
+
preds (List[:obj:`DetDataSample`]): Predictions of the model.
|
468 |
+
visualization (Optional[np.ndarray]): Visualized predictions.
|
469 |
+
return_datasample (bool): Whether to use Datasample to store
|
470 |
+
inference results. If False, dict will be used.
|
471 |
+
print_result (bool): Whether to print the inference result w/o
|
472 |
+
visualization to the console. Defaults to False.
|
473 |
+
no_save_pred (bool): Whether to force not to save prediction
|
474 |
+
results. Defaults to False.
|
475 |
+
pred_out_dir: Dir to save the inference results w/o
|
476 |
+
visualization. If left as empty, no file will be saved.
|
477 |
+
Defaults to ''.
|
478 |
+
|
479 |
+
Returns:
|
480 |
+
dict: Inference and visualization results with key ``predictions``
|
481 |
+
and ``visualization``.
|
482 |
+
|
483 |
+
- ``visualization`` (Any): Returned by :meth:`visualize`.
|
484 |
+
- ``predictions`` (dict or DataSample): Returned by
|
485 |
+
:meth:`forward` and processed in :meth:`postprocess`.
|
486 |
+
If ``return_datasample=False``, it usually should be a
|
487 |
+
json-serializable dict containing only basic data elements such
|
488 |
+
as strings and numbers.
|
489 |
+
"""
|
490 |
+
if no_save_pred is True:
|
491 |
+
pred_out_dir = ''
|
492 |
+
|
493 |
+
result_dict = {}
|
494 |
+
results = preds
|
495 |
+
if not return_datasample:
|
496 |
+
results = []
|
497 |
+
for pred in preds:
|
498 |
+
result = self.pred2dict(pred, pred_out_dir)
|
499 |
+
results.append(result)
|
500 |
+
elif pred_out_dir != '':
|
501 |
+
warnings.warn('Currently does not support saving datasample '
|
502 |
+
'when return_datasample is set to True. '
|
503 |
+
'Prediction results are not saved!')
|
504 |
+
# Add img to the results after printing and dumping
|
505 |
+
result_dict['predictions'] = results
|
506 |
+
if print_result:
|
507 |
+
print(result_dict)
|
508 |
+
result_dict['visualization'] = visualization
|
509 |
+
return result_dict
|
510 |
+
|
511 |
+
# TODO: The data format and fields saved in json need further discussion.
|
512 |
+
# Maybe should include model name, timestamp, filename, image info etc.
|
513 |
+
def pred2dict(self,
|
514 |
+
data_sample: DetDataSample,
|
515 |
+
pred_out_dir: str = '') -> Dict:
|
516 |
+
"""Extract elements necessary to represent a prediction into a
|
517 |
+
dictionary.
|
518 |
+
|
519 |
+
It's better to contain only basic data elements such as strings and
|
520 |
+
numbers in order to guarantee it's json-serializable.
|
521 |
+
|
522 |
+
Args:
|
523 |
+
data_sample (:obj:`DetDataSample`): Predictions of the model.
|
524 |
+
pred_out_dir: Dir to save the inference results w/o
|
525 |
+
visualization. If left as empty, no file will be saved.
|
526 |
+
Defaults to ''.
|
527 |
+
|
528 |
+
Returns:
|
529 |
+
dict: Prediction results.
|
530 |
+
"""
|
531 |
+
is_save_pred = True
|
532 |
+
if pred_out_dir == '':
|
533 |
+
is_save_pred = False
|
534 |
+
|
535 |
+
if is_save_pred and 'img_path' in data_sample:
|
536 |
+
img_path = osp.basename(data_sample.img_path)
|
537 |
+
img_path = osp.splitext(img_path)[0]
|
538 |
+
out_img_path = osp.join(pred_out_dir, 'preds',
|
539 |
+
img_path + '_panoptic_seg.png')
|
540 |
+
out_json_path = osp.join(pred_out_dir, 'preds', img_path + '.json')
|
541 |
+
elif is_save_pred:
|
542 |
+
out_img_path = osp.join(
|
543 |
+
pred_out_dir, 'preds',
|
544 |
+
f'{self.num_predicted_imgs}_panoptic_seg.png')
|
545 |
+
out_json_path = osp.join(pred_out_dir, 'preds',
|
546 |
+
f'{self.num_predicted_imgs}.json')
|
547 |
+
self.num_predicted_imgs += 1
|
548 |
+
|
549 |
+
result = {}
|
550 |
+
if 'pred_instances' in data_sample:
|
551 |
+
masks = data_sample.pred_instances.get('masks')
|
552 |
+
pred_instances = data_sample.pred_instances.numpy()
|
553 |
+
result = {
|
554 |
+
'bboxes': pred_instances.bboxes.tolist(),
|
555 |
+
'labels': pred_instances.labels.tolist(),
|
556 |
+
'scores': pred_instances.scores.tolist()
|
557 |
+
}
|
558 |
+
if masks is not None:
|
559 |
+
if pred_instances.bboxes.sum() == 0:
|
560 |
+
# Fake bbox, such as the SOLO.
|
561 |
+
bboxes = mask2bbox(masks.cpu()).numpy().tolist()
|
562 |
+
result['bboxes'] = bboxes
|
563 |
+
encode_masks = encode_mask_results(pred_instances.masks)
|
564 |
+
for encode_mask in encode_masks:
|
565 |
+
if isinstance(encode_mask['counts'], bytes):
|
566 |
+
encode_mask['counts'] = encode_mask['counts'].decode()
|
567 |
+
result['masks'] = encode_masks
|
568 |
+
|
569 |
+
if 'pred_panoptic_seg' in data_sample:
|
570 |
+
if VOID is None:
|
571 |
+
raise RuntimeError(
|
572 |
+
'panopticapi is not installed, please install it by: '
|
573 |
+
'pip install git+https://github.com/cocodataset/'
|
574 |
+
'panopticapi.git.')
|
575 |
+
|
576 |
+
pan = data_sample.pred_panoptic_seg.sem_seg.cpu().numpy()[0]
|
577 |
+
pan[pan % INSTANCE_OFFSET == len(
|
578 |
+
self.model.dataset_meta['classes'])] = VOID
|
579 |
+
pan = id2rgb(pan).astype(np.uint8)
|
580 |
+
|
581 |
+
if is_save_pred:
|
582 |
+
mmcv.imwrite(pan[:, :, ::-1], out_img_path)
|
583 |
+
result['panoptic_seg_path'] = out_img_path
|
584 |
+
else:
|
585 |
+
result['panoptic_seg'] = pan
|
586 |
+
|
587 |
+
if is_save_pred:
|
588 |
+
mmengine.dump(result, out_json_path)
|
589 |
+
|
590 |
+
return result
|
mmdet/apis/inference.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
import warnings
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional, Sequence, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from mmcv.ops import RoIPool
|
11 |
+
from mmcv.transforms import Compose
|
12 |
+
from mmengine.config import Config
|
13 |
+
from mmengine.model.utils import revert_sync_batchnorm
|
14 |
+
from mmengine.registry import init_default_scope
|
15 |
+
from mmengine.runner import load_checkpoint
|
16 |
+
|
17 |
+
from mmdet.registry import DATASETS
|
18 |
+
from ..evaluation import get_classes
|
19 |
+
from ..registry import MODELS
|
20 |
+
from ..structures import DetDataSample, SampleList
|
21 |
+
from ..utils import get_test_pipeline_cfg
|
22 |
+
|
23 |
+
|
24 |
+
def init_detector(
|
25 |
+
config: Union[str, Path, Config],
|
26 |
+
checkpoint: Optional[str] = None,
|
27 |
+
palette: str = 'none',
|
28 |
+
device: str = 'cuda:0',
|
29 |
+
cfg_options: Optional[dict] = None,
|
30 |
+
) -> nn.Module:
|
31 |
+
"""Initialize a detector from config file.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
|
35 |
+
:obj:`Path`, or the config object.
|
36 |
+
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
37 |
+
will not load any weights.
|
38 |
+
palette (str): Color palette used for visualization. If palette
|
39 |
+
is stored in checkpoint, use checkpoint's palette first, otherwise
|
40 |
+
use externally passed palette. Currently, supports 'coco', 'voc',
|
41 |
+
'citys' and 'random'. Defaults to none.
|
42 |
+
device (str): The device where the anchors will be put on.
|
43 |
+
Defaults to cuda:0.
|
44 |
+
cfg_options (dict, optional): Options to override some settings in
|
45 |
+
the used config.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
nn.Module: The constructed detector.
|
49 |
+
"""
|
50 |
+
if isinstance(config, (str, Path)):
|
51 |
+
config = Config.fromfile(config)
|
52 |
+
elif not isinstance(config, Config):
|
53 |
+
raise TypeError('config must be a filename or Config object, '
|
54 |
+
f'but got {type(config)}')
|
55 |
+
if cfg_options is not None:
|
56 |
+
config.merge_from_dict(cfg_options)
|
57 |
+
elif 'init_cfg' in config.model.backbone:
|
58 |
+
config.model.backbone.init_cfg = None
|
59 |
+
init_default_scope(config.get('default_scope', 'mmdet'))
|
60 |
+
|
61 |
+
model = MODELS.build(config.model)
|
62 |
+
model = revert_sync_batchnorm(model)
|
63 |
+
if checkpoint is None:
|
64 |
+
warnings.simplefilter('once')
|
65 |
+
warnings.warn('checkpoint is None, use COCO classes by default.')
|
66 |
+
model.dataset_meta = {'classes': get_classes('coco')}
|
67 |
+
else:
|
68 |
+
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
69 |
+
# Weights converted from elsewhere may not have meta fields.
|
70 |
+
checkpoint_meta = checkpoint.get('meta', {})
|
71 |
+
|
72 |
+
# save the dataset_meta in the model for convenience
|
73 |
+
if 'dataset_meta' in checkpoint_meta:
|
74 |
+
# mmdet 3.x, all keys should be lowercase
|
75 |
+
model.dataset_meta = {
|
76 |
+
k.lower(): v
|
77 |
+
for k, v in checkpoint_meta['dataset_meta'].items()
|
78 |
+
}
|
79 |
+
elif 'CLASSES' in checkpoint_meta:
|
80 |
+
# < mmdet 3.x
|
81 |
+
classes = checkpoint_meta['CLASSES']
|
82 |
+
model.dataset_meta = {'classes': classes}
|
83 |
+
else:
|
84 |
+
warnings.simplefilter('once')
|
85 |
+
warnings.warn(
|
86 |
+
'dataset_meta or class names are not saved in the '
|
87 |
+
'checkpoint\'s meta data, use COCO classes by default.')
|
88 |
+
model.dataset_meta = {'classes': get_classes('coco')}
|
89 |
+
|
90 |
+
# Priority: args.palette -> config -> checkpoint
|
91 |
+
if palette != 'none':
|
92 |
+
model.dataset_meta['palette'] = palette
|
93 |
+
else:
|
94 |
+
test_dataset_cfg = copy.deepcopy(config.test_dataloader.dataset)
|
95 |
+
# lazy init. We only need the metainfo.
|
96 |
+
test_dataset_cfg['lazy_init'] = True
|
97 |
+
metainfo = DATASETS.build(test_dataset_cfg).metainfo
|
98 |
+
cfg_palette = metainfo.get('palette', None)
|
99 |
+
if cfg_palette is not None:
|
100 |
+
model.dataset_meta['palette'] = cfg_palette
|
101 |
+
else:
|
102 |
+
if 'palette' not in model.dataset_meta:
|
103 |
+
warnings.warn(
|
104 |
+
'palette does not exist, random is used by default. '
|
105 |
+
'You can also set the palette to customize.')
|
106 |
+
model.dataset_meta['palette'] = 'random'
|
107 |
+
|
108 |
+
model.cfg = config # save the config in the model for convenience
|
109 |
+
model.to(device)
|
110 |
+
model.eval()
|
111 |
+
return model
|
112 |
+
|
113 |
+
|
114 |
+
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
115 |
+
|
116 |
+
|
117 |
+
def inference_detector(
|
118 |
+
model: nn.Module,
|
119 |
+
imgs: ImagesType,
|
120 |
+
test_pipeline: Optional[Compose] = None
|
121 |
+
) -> Union[DetDataSample, SampleList]:
|
122 |
+
"""Inference image(s) with the detector.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
model (nn.Module): The loaded detector.
|
126 |
+
imgs (str, ndarray, Sequence[str/ndarray]):
|
127 |
+
Either image files or loaded images.
|
128 |
+
test_pipeline (:obj:`Compose`): Test pipeline.
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
:obj:`DetDataSample` or list[:obj:`DetDataSample`]:
|
132 |
+
If imgs is a list or tuple, the same length list type results
|
133 |
+
will be returned, otherwise return the detection results directly.
|
134 |
+
"""
|
135 |
+
|
136 |
+
if isinstance(imgs, (list, tuple)):
|
137 |
+
is_batch = True
|
138 |
+
else:
|
139 |
+
imgs = [imgs]
|
140 |
+
is_batch = False
|
141 |
+
|
142 |
+
cfg = model.cfg
|
143 |
+
|
144 |
+
if test_pipeline is None:
|
145 |
+
cfg = cfg.copy()
|
146 |
+
test_pipeline = get_test_pipeline_cfg(cfg)
|
147 |
+
if isinstance(imgs[0], np.ndarray):
|
148 |
+
# Calling this method across libraries will result
|
149 |
+
# in module unregistered error if not prefixed with mmdet.
|
150 |
+
test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
|
151 |
+
|
152 |
+
test_pipeline = Compose(test_pipeline)
|
153 |
+
|
154 |
+
if model.data_preprocessor.device.type == 'cpu':
|
155 |
+
for m in model.modules():
|
156 |
+
assert not isinstance(
|
157 |
+
m, RoIPool
|
158 |
+
), 'CPU inference with RoIPool is not supported currently.'
|
159 |
+
|
160 |
+
result_list = []
|
161 |
+
for img in imgs:
|
162 |
+
# prepare data
|
163 |
+
if isinstance(img, np.ndarray):
|
164 |
+
# TODO: remove img_id.
|
165 |
+
data_ = dict(img=img, img_id=0)
|
166 |
+
else:
|
167 |
+
# TODO: remove img_id.
|
168 |
+
data_ = dict(img_path=img, img_id=0)
|
169 |
+
# build the data pipeline
|
170 |
+
data_ = test_pipeline(data_)
|
171 |
+
|
172 |
+
data_['inputs'] = [data_['inputs']]
|
173 |
+
data_['data_samples'] = [data_['data_samples']]
|
174 |
+
|
175 |
+
# forward the model
|
176 |
+
with torch.no_grad():
|
177 |
+
results = model.test_step(data_)[0]
|
178 |
+
|
179 |
+
result_list.append(results)
|
180 |
+
|
181 |
+
if not is_batch:
|
182 |
+
return result_list[0]
|
183 |
+
else:
|
184 |
+
return result_list
|
185 |
+
|
186 |
+
|
187 |
+
# TODO: Awaiting refactoring
|
188 |
+
async def async_inference_detector(model, imgs):
|
189 |
+
"""Async inference image(s) with the detector.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
model (nn.Module): The loaded detector.
|
193 |
+
img (str | ndarray): Either image files or loaded images.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
Awaitable detection results.
|
197 |
+
"""
|
198 |
+
if not isinstance(imgs, (list, tuple)):
|
199 |
+
imgs = [imgs]
|
200 |
+
|
201 |
+
cfg = model.cfg
|
202 |
+
|
203 |
+
if isinstance(imgs[0], np.ndarray):
|
204 |
+
cfg = cfg.copy()
|
205 |
+
# set loading pipeline type
|
206 |
+
cfg.data.test.pipeline[0].type = 'LoadImageFromNDArray'
|
207 |
+
|
208 |
+
# cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
|
209 |
+
test_pipeline = Compose(cfg.data.test.pipeline)
|
210 |
+
|
211 |
+
datas = []
|
212 |
+
for img in imgs:
|
213 |
+
# prepare data
|
214 |
+
if isinstance(img, np.ndarray):
|
215 |
+
# directly add img
|
216 |
+
data = dict(img=img)
|
217 |
+
else:
|
218 |
+
# add information into dict
|
219 |
+
data = dict(img_info=dict(filename=img), img_prefix=None)
|
220 |
+
# build the data pipeline
|
221 |
+
data = test_pipeline(data)
|
222 |
+
datas.append(data)
|
223 |
+
|
224 |
+
for m in model.modules():
|
225 |
+
assert not isinstance(
|
226 |
+
m,
|
227 |
+
RoIPool), 'CPU inference with RoIPool is not supported currently.'
|
228 |
+
|
229 |
+
# We don't restore `torch.is_grad_enabled()` value during concurrent
|
230 |
+
# inference since execution can overlap
|
231 |
+
torch.set_grad_enabled(False)
|
232 |
+
results = await model.aforward_test(data, rescale=True)
|
233 |
+
return results
|
mmdet/datasets/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .base_det_dataset import BaseDetDataset
|
3 |
+
from .cityscapes import CityscapesDataset
|
4 |
+
from .coco import CocoDataset
|
5 |
+
from .coco_panoptic import CocoPanopticDataset
|
6 |
+
from .crowdhuman import CrowdHumanDataset
|
7 |
+
from .dataset_wrappers import MultiImageMixDataset
|
8 |
+
from .deepfashion import DeepFashionDataset
|
9 |
+
from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
|
10 |
+
from .objects365 import Objects365V1Dataset, Objects365V2Dataset
|
11 |
+
from .openimages import OpenImagesChallengeDataset, OpenImagesDataset
|
12 |
+
from .samplers import (AspectRatioBatchSampler, ClassAwareSampler,
|
13 |
+
GroupMultiSourceSampler, MultiSourceSampler)
|
14 |
+
from .utils import get_loading_pipeline
|
15 |
+
from .voc import VOCDataset
|
16 |
+
from .wider_face import WIDERFaceDataset
|
17 |
+
from .xml_style import XMLDataset
|
18 |
+
|
19 |
+
__all__ = [
|
20 |
+
'XMLDataset', 'CocoDataset', 'DeepFashionDataset', 'VOCDataset',
|
21 |
+
'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset', 'LVISV1Dataset',
|
22 |
+
'WIDERFaceDataset', 'get_loading_pipeline', 'CocoPanopticDataset',
|
23 |
+
'MultiImageMixDataset', 'OpenImagesDataset', 'OpenImagesChallengeDataset',
|
24 |
+
'AspectRatioBatchSampler', 'ClassAwareSampler', 'MultiSourceSampler',
|
25 |
+
'GroupMultiSourceSampler', 'BaseDetDataset', 'CrowdHumanDataset',
|
26 |
+
'Objects365V1Dataset', 'Objects365V2Dataset'
|
27 |
+
]
|
mmdet/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.25 kB). View file
|
|
mmdet/datasets/__pycache__/base_det_dataset.cpython-310.pyc
ADDED
Binary file (4.52 kB). View file
|
|
mmdet/datasets/__pycache__/cityscapes.cpython-310.pyc
ADDED
Binary file (1.95 kB). View file
|
|
mmdet/datasets/__pycache__/coco.cpython-310.pyc
ADDED
Binary file (6.23 kB). View file
|
|
mmdet/datasets/__pycache__/coco_panoptic.cpython-310.pyc
ADDED
Binary file (10.6 kB). View file
|
|
mmdet/datasets/__pycache__/crowdhuman.cpython-310.pyc
ADDED
Binary file (4.45 kB). View file
|
|
mmdet/datasets/__pycache__/dataset_wrappers.cpython-310.pyc
ADDED
Binary file (5.72 kB). View file
|
|
mmdet/datasets/__pycache__/deepfashion.cpython-310.pyc
ADDED
Binary file (903 Bytes). View file
|
|
mmdet/datasets/__pycache__/lvis.cpython-310.pyc
ADDED
Binary file (23.5 kB). View file
|
|
mmdet/datasets/__pycache__/objects365.cpython-310.pyc
ADDED
Binary file (10.9 kB). View file
|
|
mmdet/datasets/__pycache__/openimages.cpython-310.pyc
ADDED
Binary file (14.2 kB). View file
|
|
mmdet/datasets/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.81 kB). View file
|
|
mmdet/datasets/__pycache__/voc.cpython-310.pyc
ADDED
Binary file (1.34 kB). View file
|
|
mmdet/datasets/__pycache__/wider_face.cpython-310.pyc
ADDED
Binary file (2.98 kB). View file
|
|
mmdet/datasets/__pycache__/xml_style.cpython-310.pyc
ADDED
Binary file (5.68 kB). View file
|
|
mmdet/datasets/api_wrappers/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .coco_api import COCO, COCOeval, COCOPanoptic
|
3 |
+
|
4 |
+
__all__ = ['COCO', 'COCOeval', 'COCOPanoptic']
|
mmdet/datasets/api_wrappers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (263 Bytes). View file
|
|
mmdet/datasets/api_wrappers/__pycache__/coco_api.cpython-310.pyc
ADDED
Binary file (4.5 kB). View file
|
|
mmdet/datasets/api_wrappers/coco_api.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
# This file add snake case alias for coco api
|
3 |
+
|
4 |
+
import warnings
|
5 |
+
from collections import defaultdict
|
6 |
+
from typing import List, Optional, Union
|
7 |
+
|
8 |
+
import pycocotools
|
9 |
+
from pycocotools.coco import COCO as _COCO
|
10 |
+
from pycocotools.cocoeval import COCOeval as _COCOeval
|
11 |
+
|
12 |
+
|
13 |
+
class COCO(_COCO):
|
14 |
+
"""This class is almost the same as official pycocotools package.
|
15 |
+
|
16 |
+
It implements some snake case function aliases. So that the COCO class has
|
17 |
+
the same interface as LVIS class.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, annotation_file=None):
|
21 |
+
if getattr(pycocotools, '__version__', '0') >= '12.0.2':
|
22 |
+
warnings.warn(
|
23 |
+
'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501
|
24 |
+
UserWarning)
|
25 |
+
super().__init__(annotation_file=annotation_file)
|
26 |
+
self.img_ann_map = self.imgToAnns
|
27 |
+
self.cat_img_map = self.catToImgs
|
28 |
+
|
29 |
+
def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
|
30 |
+
return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)
|
31 |
+
|
32 |
+
def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
|
33 |
+
return self.getCatIds(cat_names, sup_names, cat_ids)
|
34 |
+
|
35 |
+
def get_img_ids(self, img_ids=[], cat_ids=[]):
|
36 |
+
return self.getImgIds(img_ids, cat_ids)
|
37 |
+
|
38 |
+
def load_anns(self, ids):
|
39 |
+
return self.loadAnns(ids)
|
40 |
+
|
41 |
+
def load_cats(self, ids):
|
42 |
+
return self.loadCats(ids)
|
43 |
+
|
44 |
+
def load_imgs(self, ids):
|
45 |
+
return self.loadImgs(ids)
|
46 |
+
|
47 |
+
|
48 |
+
# just for the ease of import
|
49 |
+
COCOeval = _COCOeval
|
50 |
+
|
51 |
+
|
52 |
+
class COCOPanoptic(COCO):
|
53 |
+
"""This wrapper is for loading the panoptic style annotation file.
|
54 |
+
|
55 |
+
The format is shown in the CocoPanopticDataset class.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
annotation_file (str, optional): Path of annotation file.
|
59 |
+
Defaults to None.
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, annotation_file: Optional[str] = None) -> None:
|
63 |
+
super(COCOPanoptic, self).__init__(annotation_file)
|
64 |
+
|
65 |
+
def createIndex(self) -> None:
|
66 |
+
"""Create index."""
|
67 |
+
# create index
|
68 |
+
print('creating index...')
|
69 |
+
# anns stores 'segment_id -> annotation'
|
70 |
+
anns, cats, imgs = {}, {}, {}
|
71 |
+
img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
|
72 |
+
if 'annotations' in self.dataset:
|
73 |
+
for ann in self.dataset['annotations']:
|
74 |
+
for seg_ann in ann['segments_info']:
|
75 |
+
# to match with instance.json
|
76 |
+
seg_ann['image_id'] = ann['image_id']
|
77 |
+
img_to_anns[ann['image_id']].append(seg_ann)
|
78 |
+
# segment_id is not unique in coco dataset orz...
|
79 |
+
# annotations from different images but
|
80 |
+
# may have same segment_id
|
81 |
+
if seg_ann['id'] in anns.keys():
|
82 |
+
anns[seg_ann['id']].append(seg_ann)
|
83 |
+
else:
|
84 |
+
anns[seg_ann['id']] = [seg_ann]
|
85 |
+
|
86 |
+
# filter out annotations from other images
|
87 |
+
img_to_anns_ = defaultdict(list)
|
88 |
+
for k, v in img_to_anns.items():
|
89 |
+
img_to_anns_[k] = [x for x in v if x['image_id'] == k]
|
90 |
+
img_to_anns = img_to_anns_
|
91 |
+
|
92 |
+
if 'images' in self.dataset:
|
93 |
+
for img_info in self.dataset['images']:
|
94 |
+
img_info['segm_file'] = img_info['file_name'].replace(
|
95 |
+
'jpg', 'png')
|
96 |
+
imgs[img_info['id']] = img_info
|
97 |
+
|
98 |
+
if 'categories' in self.dataset:
|
99 |
+
for cat in self.dataset['categories']:
|
100 |
+
cats[cat['id']] = cat
|
101 |
+
|
102 |
+
if 'annotations' in self.dataset and 'categories' in self.dataset:
|
103 |
+
for ann in self.dataset['annotations']:
|
104 |
+
for seg_ann in ann['segments_info']:
|
105 |
+
cat_to_imgs[seg_ann['category_id']].append(ann['image_id'])
|
106 |
+
|
107 |
+
print('index created!')
|
108 |
+
|
109 |
+
self.anns = anns
|
110 |
+
self.imgToAnns = img_to_anns
|
111 |
+
self.catToImgs = cat_to_imgs
|
112 |
+
self.imgs = imgs
|
113 |
+
self.cats = cats
|
114 |
+
|
115 |
+
def load_anns(self,
|
116 |
+
ids: Union[List[int], int] = []) -> Optional[List[dict]]:
|
117 |
+
"""Load anns with the specified ids.
|
118 |
+
|
119 |
+
``self.anns`` is a list of annotation lists instead of a
|
120 |
+
list of annotations.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
ids (Union[List[int], int]): Integer ids specifying anns.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
anns (List[dict], optional): Loaded ann objects.
|
127 |
+
"""
|
128 |
+
anns = []
|
129 |
+
|
130 |
+
if hasattr(ids, '__iter__') and hasattr(ids, '__len__'):
|
131 |
+
# self.anns is a list of annotation lists instead of
|
132 |
+
# a list of annotations
|
133 |
+
for id in ids:
|
134 |
+
anns += self.anns[id]
|
135 |
+
return anns
|
136 |
+
elif type(ids) == int:
|
137 |
+
return self.anns[ids]
|
mmdet/datasets/base_det_dataset.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
from mmengine.dataset import BaseDataset
|
6 |
+
from mmengine.fileio import load
|
7 |
+
from mmengine.utils import is_abs
|
8 |
+
|
9 |
+
from ..registry import DATASETS
|
10 |
+
|
11 |
+
|
12 |
+
@DATASETS.register_module()
|
13 |
+
class BaseDetDataset(BaseDataset):
|
14 |
+
"""Base dataset for detection.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
proposal_file (str, optional): Proposals file path. Defaults to None.
|
18 |
+
file_client_args (dict): Arguments to instantiate the
|
19 |
+
corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
|
20 |
+
backend_args (dict, optional): Arguments to instantiate the
|
21 |
+
corresponding backend. Defaults to None.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self,
|
25 |
+
*args,
|
26 |
+
seg_map_suffix: str = '.png',
|
27 |
+
proposal_file: Optional[str] = None,
|
28 |
+
file_client_args: dict = None,
|
29 |
+
backend_args: dict = None,
|
30 |
+
**kwargs) -> None:
|
31 |
+
self.seg_map_suffix = seg_map_suffix
|
32 |
+
self.proposal_file = proposal_file
|
33 |
+
self.backend_args = backend_args
|
34 |
+
if file_client_args is not None:
|
35 |
+
raise RuntimeError(
|
36 |
+
'The `file_client_args` is deprecated, '
|
37 |
+
'please use `backend_args` instead, please refer to'
|
38 |
+
'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
|
39 |
+
)
|
40 |
+
super().__init__(*args, **kwargs)
|
41 |
+
|
42 |
+
def full_init(self) -> None:
|
43 |
+
"""Load annotation file and set ``BaseDataset._fully_initialized`` to
|
44 |
+
True.
|
45 |
+
|
46 |
+
If ``lazy_init=False``, ``full_init`` will be called during the
|
47 |
+
instantiation and ``self._fully_initialized`` will be set to True. If
|
48 |
+
``obj._fully_initialized=False``, the class method decorated by
|
49 |
+
``force_full_init`` will call ``full_init`` automatically.
|
50 |
+
|
51 |
+
Several steps to initialize annotation:
|
52 |
+
|
53 |
+
- load_data_list: Load annotations from annotation file.
|
54 |
+
- load_proposals: Load proposals from proposal file, if
|
55 |
+
`self.proposal_file` is not None.
|
56 |
+
- filter data information: Filter annotations according to
|
57 |
+
filter_cfg.
|
58 |
+
- slice_data: Slice dataset according to ``self._indices``
|
59 |
+
- serialize_data: Serialize ``self.data_list`` if
|
60 |
+
``self.serialize_data`` is True.
|
61 |
+
"""
|
62 |
+
if self._fully_initialized:
|
63 |
+
return
|
64 |
+
# load data information
|
65 |
+
self.data_list = self.load_data_list()
|
66 |
+
# get proposals from file
|
67 |
+
if self.proposal_file is not None:
|
68 |
+
self.load_proposals()
|
69 |
+
# filter illegal data, such as data that has no annotations.
|
70 |
+
self.data_list = self.filter_data()
|
71 |
+
|
72 |
+
# Get subset data according to indices.
|
73 |
+
if self._indices is not None:
|
74 |
+
self.data_list = self._get_unserialized_subset(self._indices)
|
75 |
+
|
76 |
+
# serialize data_list
|
77 |
+
if self.serialize_data:
|
78 |
+
self.data_bytes, self.data_address = self._serialize_data()
|
79 |
+
|
80 |
+
self._fully_initialized = True
|
81 |
+
|
82 |
+
def load_proposals(self) -> None:
|
83 |
+
"""Load proposals from proposals file.
|
84 |
+
|
85 |
+
The `proposals_list` should be a dict[img_path: proposals]
|
86 |
+
with the same length as `data_list`. And the `proposals` should be
|
87 |
+
a `dict` or :obj:`InstanceData` usually contains following keys.
|
88 |
+
|
89 |
+
- bboxes (np.ndarry): Has a shape (num_instances, 4),
|
90 |
+
the last dimension 4 arrange as (x1, y1, x2, y2).
|
91 |
+
- scores (np.ndarry): Classification scores, has a shape
|
92 |
+
(num_instance, ).
|
93 |
+
"""
|
94 |
+
# TODO: Add Unit Test after fully support Dump-Proposal Metric
|
95 |
+
if not is_abs(self.proposal_file):
|
96 |
+
self.proposal_file = osp.join(self.data_root, self.proposal_file)
|
97 |
+
proposals_list = load(
|
98 |
+
self.proposal_file, backend_args=self.backend_args)
|
99 |
+
assert len(self.data_list) == len(proposals_list)
|
100 |
+
for data_info in self.data_list:
|
101 |
+
img_path = data_info['img_path']
|
102 |
+
# `file_name` is the key to obtain the proposals from the
|
103 |
+
# `proposals_list`.
|
104 |
+
file_name = osp.join(
|
105 |
+
osp.split(osp.split(img_path)[0])[-1],
|
106 |
+
osp.split(img_path)[-1])
|
107 |
+
proposals = proposals_list[file_name]
|
108 |
+
data_info['proposals'] = proposals
|
109 |
+
|
110 |
+
def get_cat_ids(self, idx: int) -> List[int]:
|
111 |
+
"""Get COCO category ids by index.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
idx (int): Index of data.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
List[int]: All categories in the image of specified index.
|
118 |
+
"""
|
119 |
+
instances = self.get_data_info(idx)['instances']
|
120 |
+
return [instance['bbox_label'] for instance in instances]
|
mmdet/datasets/cityscapes.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
# Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/cityscapes.py # noqa
|
3 |
+
# and https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
|
4 |
+
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
from mmdet.registry import DATASETS
|
8 |
+
from .coco import CocoDataset
|
9 |
+
|
10 |
+
|
11 |
+
@DATASETS.register_module()
|
12 |
+
class CityscapesDataset(CocoDataset):
|
13 |
+
"""Dataset for Cityscapes."""
|
14 |
+
|
15 |
+
METAINFO = {
|
16 |
+
'classes': ('person', 'rider', 'car', 'truck', 'bus', 'train',
|
17 |
+
'motorcycle', 'bicycle'),
|
18 |
+
'palette': [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
|
19 |
+
(0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)]
|
20 |
+
}
|
21 |
+
|
22 |
+
def filter_data(self) -> List[dict]:
|
23 |
+
"""Filter annotations according to filter_cfg.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
List[dict]: Filtered results.
|
27 |
+
"""
|
28 |
+
if self.test_mode:
|
29 |
+
return self.data_list
|
30 |
+
|
31 |
+
if self.filter_cfg is None:
|
32 |
+
return self.data_list
|
33 |
+
|
34 |
+
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
|
35 |
+
min_size = self.filter_cfg.get('min_size', 0)
|
36 |
+
|
37 |
+
# obtain images that contain annotation
|
38 |
+
ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
|
39 |
+
# obtain images that contain annotations of the required categories
|
40 |
+
ids_in_cat = set()
|
41 |
+
for i, class_id in enumerate(self.cat_ids):
|
42 |
+
ids_in_cat |= set(self.cat_img_map[class_id])
|
43 |
+
# merge the image id sets of the two conditions and use the merged set
|
44 |
+
# to filter out images if self.filter_empty_gt=True
|
45 |
+
ids_in_cat &= ids_with_ann
|
46 |
+
|
47 |
+
valid_data_infos = []
|
48 |
+
for i, data_info in enumerate(self.data_list):
|
49 |
+
img_id = data_info['img_id']
|
50 |
+
width = data_info['width']
|
51 |
+
height = data_info['height']
|
52 |
+
all_is_crowd = all([
|
53 |
+
instance['ignore_flag'] == 1
|
54 |
+
for instance in data_info['instances']
|
55 |
+
])
|
56 |
+
if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
|
57 |
+
continue
|
58 |
+
if min(width, height) >= min_size:
|
59 |
+
valid_data_infos.append(data_info)
|
60 |
+
|
61 |
+
return valid_data_infos
|
mmdet/datasets/coco.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
import os.path as osp
|
4 |
+
from typing import List, Union
|
5 |
+
|
6 |
+
from mmengine.fileio import get_local_path
|
7 |
+
|
8 |
+
from mmdet.registry import DATASETS
|
9 |
+
from .api_wrappers import COCO
|
10 |
+
from .base_det_dataset import BaseDetDataset
|
11 |
+
|
12 |
+
|
13 |
+
@DATASETS.register_module()
|
14 |
+
class CocoDataset(BaseDetDataset):
|
15 |
+
"""Dataset for COCO."""
|
16 |
+
|
17 |
+
METAINFO = {
|
18 |
+
'classes':
|
19 |
+
('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
|
20 |
+
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
|
21 |
+
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
|
22 |
+
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
|
23 |
+
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
24 |
+
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
25 |
+
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
|
26 |
+
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
|
27 |
+
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
28 |
+
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
|
29 |
+
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
|
30 |
+
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
|
31 |
+
'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
|
32 |
+
# palette is a list of color tuples, which is used for visualization.
|
33 |
+
'palette':
|
34 |
+
[(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
|
35 |
+
(0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
|
36 |
+
(100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
|
37 |
+
(165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
|
38 |
+
(0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
|
39 |
+
(199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
|
40 |
+
(209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
|
41 |
+
(92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
|
42 |
+
(174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
|
43 |
+
(255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
|
44 |
+
(207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
|
45 |
+
(74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
|
46 |
+
(0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
|
47 |
+
(227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
|
48 |
+
(163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
|
49 |
+
(183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
|
50 |
+
(166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
|
51 |
+
(65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
|
52 |
+
(196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
|
53 |
+
(246, 0, 122), (191, 162, 208)]
|
54 |
+
}
|
55 |
+
COCOAPI = COCO
|
56 |
+
# ann_id is unique in coco dataset.
|
57 |
+
ANN_ID_UNIQUE = True
|
58 |
+
|
59 |
+
def load_data_list(self) -> List[dict]:
|
60 |
+
"""Load annotations from an annotation file named as ``self.ann_file``
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
List[dict]: A list of annotation.
|
64 |
+
""" # noqa: E501
|
65 |
+
with get_local_path(
|
66 |
+
self.ann_file, backend_args=self.backend_args) as local_path:
|
67 |
+
self.coco = self.COCOAPI(local_path)
|
68 |
+
# The order of returned `cat_ids` will not
|
69 |
+
# change with the order of the `classes`
|
70 |
+
self.cat_ids = self.coco.get_cat_ids(
|
71 |
+
cat_names=self.metainfo['classes'])
|
72 |
+
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
|
73 |
+
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
|
74 |
+
|
75 |
+
img_ids = self.coco.get_img_ids()
|
76 |
+
data_list = []
|
77 |
+
total_ann_ids = []
|
78 |
+
for img_id in img_ids:
|
79 |
+
raw_img_info = self.coco.load_imgs([img_id])[0]
|
80 |
+
raw_img_info['img_id'] = img_id
|
81 |
+
|
82 |
+
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
|
83 |
+
raw_ann_info = self.coco.load_anns(ann_ids)
|
84 |
+
total_ann_ids.extend(ann_ids)
|
85 |
+
|
86 |
+
parsed_data_info = self.parse_data_info({
|
87 |
+
'raw_ann_info':
|
88 |
+
raw_ann_info,
|
89 |
+
'raw_img_info':
|
90 |
+
raw_img_info
|
91 |
+
})
|
92 |
+
data_list.append(parsed_data_info)
|
93 |
+
if self.ANN_ID_UNIQUE:
|
94 |
+
assert len(set(total_ann_ids)) == len(
|
95 |
+
total_ann_ids
|
96 |
+
), f"Annotation ids in '{self.ann_file}' are not unique!"
|
97 |
+
|
98 |
+
del self.coco
|
99 |
+
|
100 |
+
return data_list
|
101 |
+
|
102 |
+
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
|
103 |
+
"""Parse raw annotation to target format.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
raw_data_info (dict): Raw data information load from ``ann_file``
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
Union[dict, List[dict]]: Parsed annotation.
|
110 |
+
"""
|
111 |
+
img_info = raw_data_info['raw_img_info']
|
112 |
+
ann_info = raw_data_info['raw_ann_info']
|
113 |
+
|
114 |
+
data_info = {}
|
115 |
+
|
116 |
+
# TODO: need to change data_prefix['img'] to data_prefix['img_path']
|
117 |
+
img_path = osp.join(self.data_prefix['img_path'], img_info['file_name'])
|
118 |
+
if self.data_prefix.get('seg_path', None):
|
119 |
+
seg_map_path = osp.join(
|
120 |
+
self.data_prefix['seg_path'],
|
121 |
+
img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
|
122 |
+
else:
|
123 |
+
seg_map_path = None
|
124 |
+
data_info['img_path'] = img_path
|
125 |
+
data_info['img_id'] = img_info['img_id']
|
126 |
+
data_info['seg_map_path'] = seg_map_path
|
127 |
+
data_info['height'] = img_info['height']
|
128 |
+
data_info['width'] = img_info['width']
|
129 |
+
|
130 |
+
instances = []
|
131 |
+
for i, ann in enumerate(ann_info):
|
132 |
+
instance = {}
|
133 |
+
|
134 |
+
if ann.get('ignore', False):
|
135 |
+
continue
|
136 |
+
x1, y1, w, h = ann['bbox']
|
137 |
+
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
|
138 |
+
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
|
139 |
+
if inter_w * inter_h == 0:
|
140 |
+
continue
|
141 |
+
if ann['area'] <= 0 or w < 1 or h < 1:
|
142 |
+
continue
|
143 |
+
if ann['category_id'] not in self.cat_ids:
|
144 |
+
continue
|
145 |
+
bbox = [x1, y1, x1 + w, y1 + h]
|
146 |
+
|
147 |
+
if ann.get('iscrowd', False):
|
148 |
+
instance['ignore_flag'] = 1
|
149 |
+
else:
|
150 |
+
instance['ignore_flag'] = 0
|
151 |
+
instance['bbox'] = bbox
|
152 |
+
instance['bbox_label'] = self.cat2label[ann['category_id']]
|
153 |
+
|
154 |
+
if ann.get('segmentation', None):
|
155 |
+
instance['mask'] = ann['segmentation']
|
156 |
+
|
157 |
+
instances.append(instance)
|
158 |
+
data_info['instances'] = instances
|
159 |
+
return data_info
|
160 |
+
|
161 |
+
def filter_data(self) -> List[dict]:
|
162 |
+
"""Filter annotations according to filter_cfg.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
List[dict]: Filtered results.
|
166 |
+
"""
|
167 |
+
if self.test_mode:
|
168 |
+
return self.data_list
|
169 |
+
|
170 |
+
if self.filter_cfg is None:
|
171 |
+
return self.data_list
|
172 |
+
|
173 |
+
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
|
174 |
+
min_size = self.filter_cfg.get('min_size', 0)
|
175 |
+
|
176 |
+
# obtain images that contain annotation
|
177 |
+
ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
|
178 |
+
# obtain images that contain annotations of the required categories
|
179 |
+
ids_in_cat = set()
|
180 |
+
for i, class_id in enumerate(self.cat_ids):
|
181 |
+
ids_in_cat |= set(self.cat_img_map[class_id])
|
182 |
+
# merge the image id sets of the two conditions and use the merged set
|
183 |
+
# to filter out images if self.filter_empty_gt=True
|
184 |
+
ids_in_cat &= ids_with_ann
|
185 |
+
|
186 |
+
valid_data_infos = []
|
187 |
+
for i, data_info in enumerate(self.data_list):
|
188 |
+
img_id = data_info['img_id']
|
189 |
+
width = data_info['width']
|
190 |
+
height = data_info['height']
|
191 |
+
if filter_empty_gt and img_id not in ids_in_cat:
|
192 |
+
continue
|
193 |
+
if min(width, height) >= min_size:
|
194 |
+
valid_data_infos.append(data_info)
|
195 |
+
|
196 |
+
return valid_data_infos
|
mmdet/datasets/coco_panoptic.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
from typing import Callable, List, Optional, Sequence, Union
|
4 |
+
|
5 |
+
from mmdet.registry import DATASETS
|
6 |
+
from .api_wrappers import COCOPanoptic
|
7 |
+
from .coco import CocoDataset
|
8 |
+
|
9 |
+
|
10 |
+
@DATASETS.register_module()
|
11 |
+
class CocoPanopticDataset(CocoDataset):
|
12 |
+
"""Coco dataset for Panoptic segmentation.
|
13 |
+
|
14 |
+
The annotation format is shown as follows. The `ann` field is optional
|
15 |
+
for testing.
|
16 |
+
|
17 |
+
.. code-block:: none
|
18 |
+
|
19 |
+
[
|
20 |
+
{
|
21 |
+
'filename': f'{image_id:012}.png',
|
22 |
+
'image_id':9
|
23 |
+
'segments_info':
|
24 |
+
[
|
25 |
+
{
|
26 |
+
'id': 8345037, (segment_id in panoptic png,
|
27 |
+
convert from rgb)
|
28 |
+
'category_id': 51,
|
29 |
+
'iscrowd': 0,
|
30 |
+
'bbox': (x1, y1, w, h),
|
31 |
+
'area': 24315
|
32 |
+
},
|
33 |
+
...
|
34 |
+
]
|
35 |
+
},
|
36 |
+
...
|
37 |
+
]
|
38 |
+
|
39 |
+
Args:
|
40 |
+
ann_file (str): Annotation file path. Defaults to ''.
|
41 |
+
metainfo (dict, optional): Meta information for dataset, such as class
|
42 |
+
information. Defaults to None.
|
43 |
+
data_root (str, optional): The root directory for ``data_prefix`` and
|
44 |
+
``ann_file``. Defaults to None.
|
45 |
+
data_prefix (dict, optional): Prefix for training data. Defaults to
|
46 |
+
``dict(img=None, ann=None, seg=None)``. The prefix ``seg`` which is
|
47 |
+
for panoptic segmentation map must be not None.
|
48 |
+
filter_cfg (dict, optional): Config for filter data. Defaults to None.
|
49 |
+
indices (int or Sequence[int], optional): Support using first few
|
50 |
+
data in annotation file to facilitate training/testing on a smaller
|
51 |
+
dataset. Defaults to None which means using all ``data_infos``.
|
52 |
+
serialize_data (bool, optional): Whether to hold memory using
|
53 |
+
serialized objects, when enabled, data loader workers can use
|
54 |
+
shared RAM from master process instead of making a copy. Defaults
|
55 |
+
to True.
|
56 |
+
pipeline (list, optional): Processing pipeline. Defaults to [].
|
57 |
+
test_mode (bool, optional): ``test_mode=True`` means in test phase.
|
58 |
+
Defaults to False.
|
59 |
+
lazy_init (bool, optional): Whether to load annotation during
|
60 |
+
instantiation. In some cases, such as visualization, only the meta
|
61 |
+
information of the dataset is needed, which is not necessary to
|
62 |
+
load annotation file. ``Basedataset`` can skip load annotations to
|
63 |
+
save time by set ``lazy_init=False``. Defaults to False.
|
64 |
+
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
|
65 |
+
None img. The maximum extra number of cycles to get a valid
|
66 |
+
image. Defaults to 1000.
|
67 |
+
"""
|
68 |
+
|
69 |
+
METAINFO = {
|
70 |
+
'classes':
|
71 |
+
('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
|
72 |
+
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
|
73 |
+
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
|
74 |
+
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
|
75 |
+
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
76 |
+
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
77 |
+
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
|
78 |
+
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
|
79 |
+
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
80 |
+
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
|
81 |
+
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
|
82 |
+
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
|
83 |
+
'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
|
84 |
+
'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff',
|
85 |
+
'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light',
|
86 |
+
'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
|
87 |
+
'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
|
88 |
+
'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
|
89 |
+
'wall-wood', 'water-other', 'window-blind', 'window-other',
|
90 |
+
'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
|
91 |
+
'cabinet-merged', 'table-merged', 'floor-other-merged',
|
92 |
+
'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
|
93 |
+
'paper-merged', 'food-other-merged', 'building-other-merged',
|
94 |
+
'rock-merged', 'wall-other-merged', 'rug-merged'),
|
95 |
+
'thing_classes':
|
96 |
+
('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
|
97 |
+
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
|
98 |
+
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
|
99 |
+
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
|
100 |
+
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
101 |
+
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
102 |
+
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
|
103 |
+
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
|
104 |
+
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
105 |
+
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
|
106 |
+
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
|
107 |
+
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
|
108 |
+
'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
|
109 |
+
'stuff_classes':
|
110 |
+
('banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain',
|
111 |
+
'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house',
|
112 |
+
'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
|
113 |
+
'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
|
114 |
+
'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
|
115 |
+
'wall-wood', 'water-other', 'window-blind', 'window-other',
|
116 |
+
'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
|
117 |
+
'cabinet-merged', 'table-merged', 'floor-other-merged',
|
118 |
+
'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
|
119 |
+
'paper-merged', 'food-other-merged', 'building-other-merged',
|
120 |
+
'rock-merged', 'wall-other-merged', 'rug-merged'),
|
121 |
+
'palette':
|
122 |
+
[(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
|
123 |
+
(0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
|
124 |
+
(100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
|
125 |
+
(165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
|
126 |
+
(0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
|
127 |
+
(199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
|
128 |
+
(209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
|
129 |
+
(92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
|
130 |
+
(174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
|
131 |
+
(255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
|
132 |
+
(207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
|
133 |
+
(74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
|
134 |
+
(0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
|
135 |
+
(227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
|
136 |
+
(163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
|
137 |
+
(183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
|
138 |
+
(166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
|
139 |
+
(65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
|
140 |
+
(196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
|
141 |
+
(246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203),
|
142 |
+
(150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100),
|
143 |
+
(92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255),
|
144 |
+
(124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0),
|
145 |
+
(193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176),
|
146 |
+
(230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
|
147 |
+
(254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
|
148 |
+
(104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
|
149 |
+
(135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
|
150 |
+
(183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
|
151 |
+
(146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140),
|
152 |
+
(96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152),
|
153 |
+
(208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0),
|
154 |
+
(0, 114, 143), (102, 102, 156), (250, 141, 255)]
|
155 |
+
}
|
156 |
+
COCOAPI = COCOPanoptic
|
157 |
+
# ann_id is not unique in coco panoptic dataset.
|
158 |
+
ANN_ID_UNIQUE = False
|
159 |
+
|
160 |
+
def __init__(self,
|
161 |
+
ann_file: str = '',
|
162 |
+
metainfo: Optional[dict] = None,
|
163 |
+
data_root: Optional[str] = None,
|
164 |
+
data_prefix: dict = dict(img=None, ann=None, seg=None),
|
165 |
+
filter_cfg: Optional[dict] = None,
|
166 |
+
indices: Optional[Union[int, Sequence[int]]] = None,
|
167 |
+
serialize_data: bool = True,
|
168 |
+
pipeline: List[Union[dict, Callable]] = [],
|
169 |
+
test_mode: bool = False,
|
170 |
+
lazy_init: bool = False,
|
171 |
+
max_refetch: int = 1000,
|
172 |
+
backend_args: dict = None,
|
173 |
+
**kwargs) -> None:
|
174 |
+
super().__init__(
|
175 |
+
ann_file=ann_file,
|
176 |
+
metainfo=metainfo,
|
177 |
+
data_root=data_root,
|
178 |
+
data_prefix=data_prefix,
|
179 |
+
filter_cfg=filter_cfg,
|
180 |
+
indices=indices,
|
181 |
+
serialize_data=serialize_data,
|
182 |
+
pipeline=pipeline,
|
183 |
+
test_mode=test_mode,
|
184 |
+
lazy_init=lazy_init,
|
185 |
+
max_refetch=max_refetch,
|
186 |
+
backend_args=backend_args,
|
187 |
+
**kwargs)
|
188 |
+
|
189 |
+
def parse_data_info(self, raw_data_info: dict) -> dict:
|
190 |
+
"""Parse raw annotation to target format.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
raw_data_info (dict): Raw data information load from ``ann_file``.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
dict: Parsed annotation.
|
197 |
+
"""
|
198 |
+
img_info = raw_data_info['raw_img_info']
|
199 |
+
ann_info = raw_data_info['raw_ann_info']
|
200 |
+
# filter out unmatched annotations which have
|
201 |
+
# same segment_id but belong to other image
|
202 |
+
ann_info = [
|
203 |
+
ann for ann in ann_info if ann['image_id'] == img_info['img_id']
|
204 |
+
]
|
205 |
+
data_info = {}
|
206 |
+
|
207 |
+
img_path = osp.join(self.data_prefix['img'], img_info['file_name'])
|
208 |
+
if self.data_prefix.get('seg', None):
|
209 |
+
seg_map_path = osp.join(
|
210 |
+
self.data_prefix['seg'],
|
211 |
+
img_info['file_name'].replace('jpg', 'png'))
|
212 |
+
else:
|
213 |
+
seg_map_path = None
|
214 |
+
data_info['img_path'] = img_path
|
215 |
+
data_info['img_id'] = img_info['img_id']
|
216 |
+
data_info['seg_map_path'] = seg_map_path
|
217 |
+
data_info['height'] = img_info['height']
|
218 |
+
data_info['width'] = img_info['width']
|
219 |
+
|
220 |
+
instances = []
|
221 |
+
segments_info = []
|
222 |
+
for ann in ann_info:
|
223 |
+
instance = {}
|
224 |
+
x1, y1, w, h = ann['bbox']
|
225 |
+
if ann['area'] <= 0 or w < 1 or h < 1:
|
226 |
+
continue
|
227 |
+
bbox = [x1, y1, x1 + w, y1 + h]
|
228 |
+
category_id = ann['category_id']
|
229 |
+
contiguous_cat_id = self.cat2label[category_id]
|
230 |
+
|
231 |
+
is_thing = self.coco.load_cats(ids=category_id)[0]['isthing']
|
232 |
+
if is_thing:
|
233 |
+
is_crowd = ann.get('iscrowd', False)
|
234 |
+
instance['bbox'] = bbox
|
235 |
+
instance['bbox_label'] = contiguous_cat_id
|
236 |
+
if not is_crowd:
|
237 |
+
instance['ignore_flag'] = 0
|
238 |
+
else:
|
239 |
+
instance['ignore_flag'] = 1
|
240 |
+
is_thing = False
|
241 |
+
|
242 |
+
segment_info = {
|
243 |
+
'id': ann['id'],
|
244 |
+
'category': contiguous_cat_id,
|
245 |
+
'is_thing': is_thing
|
246 |
+
}
|
247 |
+
segments_info.append(segment_info)
|
248 |
+
if len(instance) > 0 and is_thing:
|
249 |
+
instances.append(instance)
|
250 |
+
data_info['instances'] = instances
|
251 |
+
data_info['segments_info'] = segments_info
|
252 |
+
return data_info
|
253 |
+
|
254 |
+
def filter_data(self) -> List[dict]:
|
255 |
+
"""Filter images too small or without ground truth.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
List[dict]: ``self.data_list`` after filtering.
|
259 |
+
"""
|
260 |
+
if self.test_mode:
|
261 |
+
return self.data_list
|
262 |
+
|
263 |
+
if self.filter_cfg is None:
|
264 |
+
return self.data_list
|
265 |
+
|
266 |
+
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
|
267 |
+
min_size = self.filter_cfg.get('min_size', 0)
|
268 |
+
|
269 |
+
ids_with_ann = set()
|
270 |
+
# check whether images have legal thing annotations.
|
271 |
+
for data_info in self.data_list:
|
272 |
+
for segment_info in data_info['segments_info']:
|
273 |
+
if not segment_info['is_thing']:
|
274 |
+
continue
|
275 |
+
ids_with_ann.add(data_info['img_id'])
|
276 |
+
|
277 |
+
valid_data_list = []
|
278 |
+
for data_info in self.data_list:
|
279 |
+
img_id = data_info['img_id']
|
280 |
+
width = data_info['width']
|
281 |
+
height = data_info['height']
|
282 |
+
if filter_empty_gt and img_id not in ids_with_ann:
|
283 |
+
continue
|
284 |
+
if min(width, height) >= min_size:
|
285 |
+
valid_data_list.append(data_info)
|
286 |
+
|
287 |
+
return valid_data_list
|
mmdet/datasets/crowdhuman.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os.path as osp
|
5 |
+
import warnings
|
6 |
+
from typing import List, Union
|
7 |
+
|
8 |
+
import mmcv
|
9 |
+
from mmengine.dist import get_rank
|
10 |
+
from mmengine.fileio import dump, get, get_text, load
|
11 |
+
from mmengine.logging import print_log
|
12 |
+
from mmengine.utils import ProgressBar
|
13 |
+
|
14 |
+
from mmdet.registry import DATASETS
|
15 |
+
from .base_det_dataset import BaseDetDataset
|
16 |
+
|
17 |
+
|
18 |
+
@DATASETS.register_module()
|
19 |
+
class CrowdHumanDataset(BaseDetDataset):
|
20 |
+
r"""Dataset for CrowdHuman.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
data_root (str): The root directory for
|
24 |
+
``data_prefix`` and ``ann_file``.
|
25 |
+
ann_file (str): Annotation file path.
|
26 |
+
extra_ann_file (str | optional):The path of extra image metas
|
27 |
+
for CrowdHuman. It can be created by CrowdHumanDataset
|
28 |
+
automatically or by tools/misc/get_crowdhuman_id_hw.py
|
29 |
+
manually. Defaults to None.
|
30 |
+
"""
|
31 |
+
|
32 |
+
METAINFO = {
|
33 |
+
'classes': ('person', ),
|
34 |
+
# palette is a list of color tuples, which is used for visualization.
|
35 |
+
'palette': [(220, 20, 60)]
|
36 |
+
}
|
37 |
+
|
38 |
+
def __init__(self, data_root, ann_file, extra_ann_file=None, **kwargs):
|
39 |
+
# extra_ann_file record the size of each image. This file is
|
40 |
+
# automatically created when you first load the CrowdHuman
|
41 |
+
# dataset by mmdet.
|
42 |
+
if extra_ann_file is not None:
|
43 |
+
self.extra_ann_exist = True
|
44 |
+
self.extra_anns = load(extra_ann_file)
|
45 |
+
else:
|
46 |
+
ann_file_name = osp.basename(ann_file)
|
47 |
+
if 'train' in ann_file_name:
|
48 |
+
self.extra_ann_file = osp.join(data_root, 'id_hw_train.json')
|
49 |
+
elif 'val' in ann_file_name:
|
50 |
+
self.extra_ann_file = osp.join(data_root, 'id_hw_val.json')
|
51 |
+
self.extra_ann_exist = False
|
52 |
+
if not osp.isfile(self.extra_ann_file):
|
53 |
+
print_log(
|
54 |
+
'extra_ann_file does not exist, prepare to collect '
|
55 |
+
'image height and width...',
|
56 |
+
level=logging.INFO)
|
57 |
+
self.extra_anns = {}
|
58 |
+
else:
|
59 |
+
self.extra_ann_exist = True
|
60 |
+
self.extra_anns = load(self.extra_ann_file)
|
61 |
+
super().__init__(data_root=data_root, ann_file=ann_file, **kwargs)
|
62 |
+
|
63 |
+
def load_data_list(self) -> List[dict]:
|
64 |
+
"""Load annotations from an annotation file named as ``self.ann_file``
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
List[dict]: A list of annotation.
|
68 |
+
""" # noqa: E501
|
69 |
+
anno_strs = get_text(
|
70 |
+
self.ann_file, backend_args=self.backend_args).strip().split('\n')
|
71 |
+
print_log('loading CrowdHuman annotation...', level=logging.INFO)
|
72 |
+
data_list = []
|
73 |
+
prog_bar = ProgressBar(len(anno_strs))
|
74 |
+
for i, anno_str in enumerate(anno_strs):
|
75 |
+
anno_dict = json.loads(anno_str)
|
76 |
+
parsed_data_info = self.parse_data_info(anno_dict)
|
77 |
+
data_list.append(parsed_data_info)
|
78 |
+
prog_bar.update()
|
79 |
+
if not self.extra_ann_exist and get_rank() == 0:
|
80 |
+
# TODO: support file client
|
81 |
+
try:
|
82 |
+
dump(self.extra_anns, self.extra_ann_file, file_format='json')
|
83 |
+
except: # noqa
|
84 |
+
warnings.warn(
|
85 |
+
'Cache files can not be saved automatically! To speed up'
|
86 |
+
'loading the dataset, please manually generate the cache'
|
87 |
+
' file by file tools/misc/get_crowdhuman_id_hw.py')
|
88 |
+
|
89 |
+
print_log(
|
90 |
+
f'\nsave extra_ann_file in {self.data_root}',
|
91 |
+
level=logging.INFO)
|
92 |
+
|
93 |
+
del self.extra_anns
|
94 |
+
print_log('\nDone', level=logging.INFO)
|
95 |
+
return data_list
|
96 |
+
|
97 |
+
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
|
98 |
+
"""Parse raw annotation to target format.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
raw_data_info (dict): Raw data information load from ``ann_file``
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Union[dict, List[dict]]: Parsed annotation.
|
105 |
+
"""
|
106 |
+
data_info = {}
|
107 |
+
img_path = osp.join(self.data_prefix['img'],
|
108 |
+
f"{raw_data_info['ID']}.jpg")
|
109 |
+
data_info['img_path'] = img_path
|
110 |
+
data_info['img_id'] = raw_data_info['ID']
|
111 |
+
|
112 |
+
if not self.extra_ann_exist:
|
113 |
+
img_bytes = get(img_path, backend_args=self.backend_args)
|
114 |
+
img = mmcv.imfrombytes(img_bytes, backend='cv2')
|
115 |
+
data_info['height'], data_info['width'] = img.shape[:2]
|
116 |
+
self.extra_anns[raw_data_info['ID']] = img.shape[:2]
|
117 |
+
del img, img_bytes
|
118 |
+
else:
|
119 |
+
data_info['height'], data_info['width'] = self.extra_anns[
|
120 |
+
raw_data_info['ID']]
|
121 |
+
|
122 |
+
instances = []
|
123 |
+
for i, ann in enumerate(raw_data_info['gtboxes']):
|
124 |
+
instance = {}
|
125 |
+
if ann['tag'] not in self.metainfo['classes']:
|
126 |
+
instance['bbox_label'] = -1
|
127 |
+
instance['ignore_flag'] = 1
|
128 |
+
else:
|
129 |
+
instance['bbox_label'] = self.metainfo['classes'].index(
|
130 |
+
ann['tag'])
|
131 |
+
instance['ignore_flag'] = 0
|
132 |
+
if 'extra' in ann:
|
133 |
+
if 'ignore' in ann['extra']:
|
134 |
+
if ann['extra']['ignore'] != 0:
|
135 |
+
instance['bbox_label'] = -1
|
136 |
+
instance['ignore_flag'] = 1
|
137 |
+
|
138 |
+
x1, y1, w, h = ann['fbox']
|
139 |
+
bbox = [x1, y1, x1 + w, y1 + h]
|
140 |
+
instance['bbox'] = bbox
|
141 |
+
|
142 |
+
# Record the full bbox(fbox), head bbox(hbox) and visible
|
143 |
+
# bbox(vbox) as additional information. If you need to use
|
144 |
+
# this information, you just need to design the pipeline
|
145 |
+
# instead of overriding the CrowdHumanDataset.
|
146 |
+
instance['fbox'] = bbox
|
147 |
+
hbox = ann['hbox']
|
148 |
+
instance['hbox'] = [
|
149 |
+
hbox[0], hbox[1], hbox[0] + hbox[2], hbox[1] + hbox[3]
|
150 |
+
]
|
151 |
+
vbox = ann['vbox']
|
152 |
+
instance['vbox'] = [
|
153 |
+
vbox[0], vbox[1], vbox[0] + vbox[2], vbox[1] + vbox[3]
|
154 |
+
]
|
155 |
+
|
156 |
+
instances.append(instance)
|
157 |
+
|
158 |
+
data_info['instances'] = instances
|
159 |
+
return data_info
|
mmdet/datasets/dataset_wrappers.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import collections
|
3 |
+
import copy
|
4 |
+
from typing import Sequence, Union
|
5 |
+
|
6 |
+
from mmengine.dataset import BaseDataset, force_full_init
|
7 |
+
|
8 |
+
from mmdet.registry import DATASETS, TRANSFORMS
|
9 |
+
|
10 |
+
|
11 |
+
@DATASETS.register_module()
|
12 |
+
class MultiImageMixDataset:
|
13 |
+
"""A wrapper of multiple images mixed dataset.
|
14 |
+
|
15 |
+
Suitable for training on multiple images mixed data augmentation like
|
16 |
+
mosaic and mixup. For the augmentation pipeline of mixed image data,
|
17 |
+
the `get_indexes` method needs to be provided to obtain the image
|
18 |
+
indexes, and you can set `skip_flags` to change the pipeline running
|
19 |
+
process. At the same time, we provide the `dynamic_scale` parameter
|
20 |
+
to dynamically change the output image size.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
dataset (:obj:`CustomDataset`): The dataset to be mixed.
|
24 |
+
pipeline (Sequence[dict]): Sequence of transform object or
|
25 |
+
config dict to be composed.
|
26 |
+
dynamic_scale (tuple[int], optional): The image scale can be changed
|
27 |
+
dynamically. Default to None. It is deprecated.
|
28 |
+
skip_type_keys (list[str], optional): Sequence of type string to
|
29 |
+
be skip pipeline. Default to None.
|
30 |
+
max_refetch (int): The maximum number of retry iterations for getting
|
31 |
+
valid results from the pipeline. If the number of iterations is
|
32 |
+
greater than `max_refetch`, but results is still None, then the
|
33 |
+
iteration is terminated and raise the error. Default: 15.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
dataset: Union[BaseDataset, dict],
|
38 |
+
pipeline: Sequence[str],
|
39 |
+
skip_type_keys: Union[Sequence[str], None] = None,
|
40 |
+
max_refetch: int = 15,
|
41 |
+
lazy_init: bool = False) -> None:
|
42 |
+
assert isinstance(pipeline, collections.abc.Sequence)
|
43 |
+
if skip_type_keys is not None:
|
44 |
+
assert all([
|
45 |
+
isinstance(skip_type_key, str)
|
46 |
+
for skip_type_key in skip_type_keys
|
47 |
+
])
|
48 |
+
self._skip_type_keys = skip_type_keys
|
49 |
+
|
50 |
+
self.pipeline = []
|
51 |
+
self.pipeline_types = []
|
52 |
+
for transform in pipeline:
|
53 |
+
if isinstance(transform, dict):
|
54 |
+
self.pipeline_types.append(transform['type'])
|
55 |
+
transform = TRANSFORMS.build(transform)
|
56 |
+
self.pipeline.append(transform)
|
57 |
+
else:
|
58 |
+
raise TypeError('pipeline must be a dict')
|
59 |
+
|
60 |
+
self.dataset: BaseDataset
|
61 |
+
if isinstance(dataset, dict):
|
62 |
+
self.dataset = DATASETS.build(dataset)
|
63 |
+
elif isinstance(dataset, BaseDataset):
|
64 |
+
self.dataset = dataset
|
65 |
+
else:
|
66 |
+
raise TypeError(
|
67 |
+
'elements in datasets sequence should be config or '
|
68 |
+
f'`BaseDataset` instance, but got {type(dataset)}')
|
69 |
+
|
70 |
+
self._metainfo = self.dataset.metainfo
|
71 |
+
if hasattr(self.dataset, 'flag'):
|
72 |
+
self.flag = self.dataset.flag
|
73 |
+
self.num_samples = len(self.dataset)
|
74 |
+
self.max_refetch = max_refetch
|
75 |
+
|
76 |
+
self._fully_initialized = False
|
77 |
+
if not lazy_init:
|
78 |
+
self.full_init()
|
79 |
+
|
80 |
+
@property
|
81 |
+
def metainfo(self) -> dict:
|
82 |
+
"""Get the meta information of the multi-image-mixed dataset.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
dict: The meta information of multi-image-mixed dataset.
|
86 |
+
"""
|
87 |
+
return copy.deepcopy(self._metainfo)
|
88 |
+
|
89 |
+
def full_init(self):
|
90 |
+
"""Loop to ``full_init`` each dataset."""
|
91 |
+
if self._fully_initialized:
|
92 |
+
return
|
93 |
+
|
94 |
+
self.dataset.full_init()
|
95 |
+
self._ori_len = len(self.dataset)
|
96 |
+
self._fully_initialized = True
|
97 |
+
|
98 |
+
@force_full_init
|
99 |
+
def get_data_info(self, idx: int) -> dict:
|
100 |
+
"""Get annotation by index.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
idx (int): Global index of ``ConcatDataset``.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
dict: The idx-th annotation of the datasets.
|
107 |
+
"""
|
108 |
+
return self.dataset.get_data_info(idx)
|
109 |
+
|
110 |
+
@force_full_init
|
111 |
+
def __len__(self):
|
112 |
+
return self.num_samples
|
113 |
+
|
114 |
+
def __getitem__(self, idx):
|
115 |
+
results = copy.deepcopy(self.dataset[idx])
|
116 |
+
for (transform, transform_type) in zip(self.pipeline,
|
117 |
+
self.pipeline_types):
|
118 |
+
if self._skip_type_keys is not None and \
|
119 |
+
transform_type in self._skip_type_keys:
|
120 |
+
continue
|
121 |
+
|
122 |
+
if hasattr(transform, 'get_indexes'):
|
123 |
+
for i in range(self.max_refetch):
|
124 |
+
# Make sure the results passed the loading pipeline
|
125 |
+
# of the original dataset is not None.
|
126 |
+
indexes = transform.get_indexes(self.dataset)
|
127 |
+
if not isinstance(indexes, collections.abc.Sequence):
|
128 |
+
indexes = [indexes]
|
129 |
+
mix_results = [
|
130 |
+
copy.deepcopy(self.dataset[index]) for index in indexes
|
131 |
+
]
|
132 |
+
if None not in mix_results:
|
133 |
+
results['mix_results'] = mix_results
|
134 |
+
break
|
135 |
+
else:
|
136 |
+
raise RuntimeError(
|
137 |
+
'The loading pipeline of the original dataset'
|
138 |
+
' always return None. Please check the correctness '
|
139 |
+
'of the dataset and its pipeline.')
|
140 |
+
|
141 |
+
for i in range(self.max_refetch):
|
142 |
+
# To confirm the results passed the training pipeline
|
143 |
+
# of the wrapper is not None.
|
144 |
+
updated_results = transform(copy.deepcopy(results))
|
145 |
+
if updated_results is not None:
|
146 |
+
results = updated_results
|
147 |
+
break
|
148 |
+
else:
|
149 |
+
raise RuntimeError(
|
150 |
+
'The training pipeline of the dataset wrapper'
|
151 |
+
' always return None.Please check the correctness '
|
152 |
+
'of the dataset and its pipeline.')
|
153 |
+
|
154 |
+
if 'mix_results' in results:
|
155 |
+
results.pop('mix_results')
|
156 |
+
|
157 |
+
return results
|
158 |
+
|
159 |
+
def update_skip_type_keys(self, skip_type_keys):
|
160 |
+
"""Update skip_type_keys. It is called by an external hook.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
skip_type_keys (list[str], optional): Sequence of type
|
164 |
+
string to be skip pipeline.
|
165 |
+
"""
|
166 |
+
assert all([
|
167 |
+
isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
|
168 |
+
])
|
169 |
+
self._skip_type_keys = skip_type_keys
|
mmdet/datasets/deepfashion.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from mmdet.registry import DATASETS
|
3 |
+
from .coco import CocoDataset
|
4 |
+
|
5 |
+
|
6 |
+
@DATASETS.register_module()
|
7 |
+
class DeepFashionDataset(CocoDataset):
|
8 |
+
"""Dataset for DeepFashion."""
|
9 |
+
|
10 |
+
METAINFO = {
|
11 |
+
'classes': ('top', 'skirt', 'leggings', 'dress', 'outer', 'pants',
|
12 |
+
'bag', 'neckwear', 'headwear', 'eyeglass', 'belt',
|
13 |
+
'footwear', 'hair', 'skin', 'face'),
|
14 |
+
# palette is a list of color tuples, which is used for visualization.
|
15 |
+
'palette': [(0, 192, 64), (0, 64, 96), (128, 192, 192), (0, 64, 64),
|
16 |
+
(0, 192, 224), (0, 192, 192), (128, 192, 64), (0, 192, 96),
|
17 |
+
(128, 32, 192), (0, 0, 224), (0, 0, 64), (0, 160, 192),
|
18 |
+
(128, 0, 96), (128, 0, 192), (0, 32, 192)]
|
19 |
+
}
|
mmdet/datasets/lvis.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
import warnings
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from mmengine.fileio import get_local_path
|
7 |
+
|
8 |
+
from mmdet.registry import DATASETS
|
9 |
+
from .coco import CocoDataset
|
10 |
+
|
11 |
+
|
12 |
+
@DATASETS.register_module()
|
13 |
+
class LVISV05Dataset(CocoDataset):
|
14 |
+
"""LVIS v0.5 dataset for detection."""
|
15 |
+
|
16 |
+
METAINFO = {
|
17 |
+
'classes':
|
18 |
+
('acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
|
19 |
+
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
|
20 |
+
'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
|
21 |
+
'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
|
22 |
+
'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
|
23 |
+
'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
|
24 |
+
'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
|
25 |
+
'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
|
26 |
+
'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
|
27 |
+
'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
|
28 |
+
'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
|
29 |
+
'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
|
30 |
+
'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball',
|
31 |
+
'bead', 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
|
32 |
+
'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle',
|
33 |
+
'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle',
|
34 |
+
'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder',
|
35 |
+
'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage',
|
36 |
+
'birdhouse', 'birthday_cake', 'birthday_card', 'biscuit_(bread)',
|
37 |
+
'pirate_flag', 'black_sheep', 'blackboard', 'blanket', 'blazer',
|
38 |
+
'blender', 'blimp', 'blinker', 'blueberry', 'boar', 'gameboard',
|
39 |
+
'boat', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt',
|
40 |
+
'bolt', 'bonnet', 'book', 'book_bag', 'bookcase', 'booklet',
|
41 |
+
'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener',
|
42 |
+
'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie',
|
43 |
+
'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
|
44 |
+
'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
|
45 |
+
'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
|
46 |
+
'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
|
47 |
+
'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
|
48 |
+
'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
|
49 |
+
'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
|
50 |
+
'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
|
51 |
+
'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
|
52 |
+
'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
|
53 |
+
'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
|
54 |
+
'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
|
55 |
+
'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
|
56 |
+
'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
|
57 |
+
'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
|
58 |
+
'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
|
59 |
+
'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
|
60 |
+
'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
|
61 |
+
'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
|
62 |
+
'celery', 'cellular_telephone', 'chain_mail', 'chair',
|
63 |
+
'chaise_longue', 'champagne', 'chandelier', 'chap', 'checkbook',
|
64 |
+
'checkerboard', 'cherry', 'chessboard',
|
65 |
+
'chest_of_drawers_(furniture)', 'chicken_(animal)', 'chicken_wire',
|
66 |
+
'chickpea', 'Chihuahua', 'chili_(vegetable)', 'chime', 'chinaware',
|
67 |
+
'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
|
68 |
+
'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
|
69 |
+
'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
|
70 |
+
'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
|
71 |
+
'clasp', 'cleansing_agent', 'clementine', 'clip', 'clipboard',
|
72 |
+
'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag',
|
73 |
+
'coaster', 'coat', 'coat_hanger', 'coatrack', 'cock', 'coconut',
|
74 |
+
'coffee_filter', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
|
75 |
+
'coin', 'colander', 'coleslaw', 'coloring_material',
|
76 |
+
'combination_lock', 'pacifier', 'comic_book', 'computer_keyboard',
|
77 |
+
'concrete_mixer', 'cone', 'control', 'convertible_(automobile)',
|
78 |
+
'sofa_bed', 'cookie', 'cookie_jar', 'cooking_utensil',
|
79 |
+
'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew',
|
80 |
+
'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
|
81 |
+
'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
|
82 |
+
'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
|
83 |
+
'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
|
84 |
+
'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
|
85 |
+
'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
|
86 |
+
'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
|
87 |
+
'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
|
88 |
+
'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
|
89 |
+
'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
|
90 |
+
'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
|
91 |
+
'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
|
92 |
+
'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
|
93 |
+
'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
|
94 |
+
'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
|
95 |
+
'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
|
96 |
+
'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
|
97 |
+
'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
|
98 |
+
'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
|
99 |
+
'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
|
100 |
+
'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
|
101 |
+
'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
|
102 |
+
'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
|
103 |
+
'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
|
104 |
+
'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)',
|
105 |
+
'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose',
|
106 |
+
'fireplace', 'fireplug', 'fish', 'fish_(food)', 'fishbowl',
|
107 |
+
'fishing_boat', 'fishing_rod', 'flag', 'flagpole', 'flamingo',
|
108 |
+
'flannel', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
|
109 |
+
'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
|
110 |
+
'folding_chair', 'food_processor', 'football_(American)',
|
111 |
+
'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
|
112 |
+
'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
|
113 |
+
'fruit_salad', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag',
|
114 |
+
'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle',
|
115 |
+
'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
|
116 |
+
'gift_wrap', 'ginger', 'giraffe', 'cincture',
|
117 |
+
'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
|
118 |
+
'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
|
119 |
+
'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
|
120 |
+
'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
|
121 |
+
'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
|
122 |
+
'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
|
123 |
+
'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
|
124 |
+
'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
|
125 |
+
'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
|
126 |
+
'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
|
127 |
+
'headband', 'headboard', 'headlight', 'headscarf', 'headset',
|
128 |
+
'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
|
129 |
+
'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
|
130 |
+
'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
|
131 |
+
'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
|
132 |
+
'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
|
133 |
+
'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
|
134 |
+
'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
|
135 |
+
'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
|
136 |
+
'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
|
137 |
+
'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
|
138 |
+
'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
|
139 |
+
'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
|
140 |
+
'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
|
141 |
+
'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
|
142 |
+
'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
|
143 |
+
'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
|
144 |
+
'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
|
145 |
+
'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
|
146 |
+
'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
|
147 |
+
'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
|
148 |
+
'speaker_(stereo_equipment)', 'loveseat', 'machine_gun', 'magazine',
|
149 |
+
'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
|
150 |
+
'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
|
151 |
+
'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
|
152 |
+
'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
|
153 |
+
'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
|
154 |
+
'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
|
155 |
+
'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
|
156 |
+
'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
|
157 |
+
'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
|
158 |
+
'mound_(baseball)', 'mouse_(animal_rodent)',
|
159 |
+
'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
|
160 |
+
'music_stool', 'musical_instrument', 'nailfile', 'nameplate',
|
161 |
+
'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest',
|
162 |
+
'newsstand', 'nightshirt', 'nosebag_(for_animals)',
|
163 |
+
'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
|
164 |
+
'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
|
165 |
+
'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'oregano',
|
166 |
+
'ostrich', 'ottoman', 'overalls_(clothing)', 'owl', 'packet',
|
167 |
+
'inkpad', 'pad', 'paddle', 'padlock', 'paintbox', 'paintbrush',
|
168 |
+
'painting', 'pajamas', 'palette', 'pan_(for_cooking)',
|
169 |
+
'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya',
|
170 |
+
'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
|
171 |
+
'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
|
172 |
+
'parchment', 'parka', 'parking_meter', 'parrot',
|
173 |
+
'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
|
174 |
+
'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
|
175 |
+
'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
|
176 |
+
'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
|
177 |
+
'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
|
178 |
+
'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
|
179 |
+
'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
|
180 |
+
'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
|
181 |
+
'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
|
182 |
+
'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
|
183 |
+
'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
|
184 |
+
'plate', 'platter', 'playing_card', 'playpen', 'pliers',
|
185 |
+
'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
|
186 |
+
'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
|
187 |
+
'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
|
188 |
+
'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
|
189 |
+
'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
|
190 |
+
'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
|
191 |
+
'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
|
192 |
+
'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
|
193 |
+
'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
|
194 |
+
'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
|
195 |
+
'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
|
196 |
+
'recliner', 'record_player', 'red_cabbage', 'reflector',
|
197 |
+
'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
|
198 |
+
'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
|
199 |
+
'Rollerblade', 'rolling_pin', 'root_beer',
|
200 |
+
'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
|
201 |
+
'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
|
202 |
+
'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
|
203 |
+
'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
|
204 |
+
'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
|
205 |
+
'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
|
206 |
+
'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
|
207 |
+
'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
|
208 |
+
'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
|
209 |
+
'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
|
210 |
+
'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
|
211 |
+
'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
|
212 |
+
'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag',
|
213 |
+
'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag',
|
214 |
+
'shovel', 'shower_head', 'shower_curtain', 'shredder_(for_paper)',
|
215 |
+
'sieve', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski',
|
216 |
+
'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'sled', 'sleeping_bag',
|
217 |
+
'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake',
|
218 |
+
'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock',
|
219 |
+
'soda_fountain', 'carbonated_water', 'sofa', 'softball',
|
220 |
+
'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
|
221 |
+
'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
|
222 |
+
'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'sponge',
|
223 |
+
'spoon', 'sportswear', 'spotlight', 'squirrel',
|
224 |
+
'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)',
|
225 |
+
'steak_(food)', 'steak_knife', 'steamer_(kitchen_appliance)',
|
226 |
+
'steering_wheel', 'stencil', 'stepladder', 'step_stool',
|
227 |
+
'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup',
|
228 |
+
'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light', 'stove',
|
229 |
+
'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
|
230 |
+
'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
|
231 |
+
'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
|
232 |
+
'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
|
233 |
+
'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
|
234 |
+
'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
|
235 |
+
'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
|
236 |
+
'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
|
237 |
+
'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
|
238 |
+
'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
|
239 |
+
'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
|
240 |
+
'telephone_pole', 'telephoto_lens', 'television_camera',
|
241 |
+
'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
|
242 |
+
'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
|
243 |
+
'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
|
244 |
+
'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
|
245 |
+
'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
|
246 |
+
'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
|
247 |
+
'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
|
248 |
+
'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
|
249 |
+
'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
|
250 |
+
'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
|
251 |
+
'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
|
252 |
+
'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
|
253 |
+
'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
|
254 |
+
'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
|
255 |
+
'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
|
256 |
+
'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
|
257 |
+
'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
|
258 |
+
'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
|
259 |
+
'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
|
260 |
+
'water_heater', 'water_jug', 'water_gun', 'water_scooter',
|
261 |
+
'water_ski', 'water_tower', 'watering_can', 'watermelon',
|
262 |
+
'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit',
|
263 |
+
'wheel', 'wheelchair', 'whipped_cream', 'whiskey', 'whistle', 'wick',
|
264 |
+
'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
|
265 |
+
'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
|
266 |
+
'wineglass', 'wing_chair', 'blinder_(for_horses)', 'wok', 'wolf',
|
267 |
+
'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht',
|
268 |
+
'yak', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini'),
|
269 |
+
'palette':
|
270 |
+
None
|
271 |
+
}
|
272 |
+
|
273 |
+
def load_data_list(self) -> List[dict]:
|
274 |
+
"""Load annotations from an annotation file named as ``self.ann_file``
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
List[dict]: A list of annotation.
|
278 |
+
""" # noqa: E501
|
279 |
+
try:
|
280 |
+
import lvis
|
281 |
+
if getattr(lvis, '__version__', '0') >= '10.5.3':
|
282 |
+
warnings.warn(
|
283 |
+
'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
|
284 |
+
UserWarning)
|
285 |
+
from lvis import LVIS
|
286 |
+
except ImportError:
|
287 |
+
raise ImportError(
|
288 |
+
'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
|
289 |
+
)
|
290 |
+
with get_local_path(
|
291 |
+
self.ann_file, backend_args=self.backend_args) as local_path:
|
292 |
+
self.lvis = LVIS(local_path)
|
293 |
+
self.cat_ids = self.lvis.get_cat_ids()
|
294 |
+
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
|
295 |
+
self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
|
296 |
+
|
297 |
+
img_ids = self.lvis.get_img_ids()
|
298 |
+
data_list = []
|
299 |
+
total_ann_ids = []
|
300 |
+
for img_id in img_ids:
|
301 |
+
raw_img_info = self.lvis.load_imgs([img_id])[0]
|
302 |
+
raw_img_info['img_id'] = img_id
|
303 |
+
if raw_img_info['file_name'].startswith('COCO'):
|
304 |
+
# Convert form the COCO 2014 file naming convention of
|
305 |
+
# COCO_[train/val/test]2014_000000000000.jpg to the 2017
|
306 |
+
# naming convention of 000000000000.jpg
|
307 |
+
# (LVIS v1 will fix this naming issue)
|
308 |
+
raw_img_info['file_name'] = raw_img_info['file_name'][-16:]
|
309 |
+
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
|
310 |
+
raw_ann_info = self.lvis.load_anns(ann_ids)
|
311 |
+
total_ann_ids.extend(ann_ids)
|
312 |
+
|
313 |
+
parsed_data_info = self.parse_data_info({
|
314 |
+
'raw_ann_info':
|
315 |
+
raw_ann_info,
|
316 |
+
'raw_img_info':
|
317 |
+
raw_img_info
|
318 |
+
})
|
319 |
+
data_list.append(parsed_data_info)
|
320 |
+
if self.ANN_ID_UNIQUE:
|
321 |
+
assert len(set(total_ann_ids)) == len(
|
322 |
+
total_ann_ids
|
323 |
+
), f"Annotation ids in '{self.ann_file}' are not unique!"
|
324 |
+
|
325 |
+
del self.lvis
|
326 |
+
|
327 |
+
return data_list
|
328 |
+
|
329 |
+
|
330 |
+
LVISDataset = LVISV05Dataset
|
331 |
+
DATASETS.register_module(name='LVISDataset', module=LVISDataset)
|
332 |
+
|
333 |
+
|
334 |
+
@DATASETS.register_module()
|
335 |
+
class LVISV1Dataset(LVISDataset):
|
336 |
+
"""LVIS v1 dataset for detection."""
|
337 |
+
|
338 |
+
METAINFO = {
|
339 |
+
'classes':
|
340 |
+
('aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
|
341 |
+
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
|
342 |
+
'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
|
343 |
+
'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
|
344 |
+
'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
|
345 |
+
'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
|
346 |
+
'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
|
347 |
+
'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
|
348 |
+
'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
|
349 |
+
'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
|
350 |
+
'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
|
351 |
+
'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
|
352 |
+
'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
|
353 |
+
'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
|
354 |
+
'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
|
355 |
+
'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
|
356 |
+
'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
|
357 |
+
'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
|
358 |
+
'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
|
359 |
+
'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
|
360 |
+
'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
|
361 |
+
'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
|
362 |
+
'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
|
363 |
+
'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
|
364 |
+
'bottle_opener', 'bouquet', 'bow_(weapon)',
|
365 |
+
'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl',
|
366 |
+
'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders',
|
367 |
+
'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread',
|
368 |
+
'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach',
|
369 |
+
'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket',
|
370 |
+
'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train',
|
371 |
+
'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed',
|
372 |
+
'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter',
|
373 |
+
'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet',
|
374 |
+
'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder',
|
375 |
+
'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can',
|
376 |
+
'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane',
|
377 |
+
'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen',
|
378 |
+
'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
|
379 |
+
'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
|
380 |
+
'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
|
381 |
+
'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
|
382 |
+
'cash_register', 'casserole', 'cassette', 'cast', 'cat',
|
383 |
+
'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery',
|
384 |
+
'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
|
385 |
+
'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard',
|
386 |
+
'cherry', 'chessboard', 'chicken_(animal)', 'chickpea',
|
387 |
+
'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
|
388 |
+
'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
|
389 |
+
'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
|
390 |
+
'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
|
391 |
+
'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
|
392 |
+
'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard',
|
393 |
+
'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower',
|
394 |
+
'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
|
395 |
+
'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)',
|
396 |
+
'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
|
397 |
+
'coin', 'colander', 'coleslaw', 'coloring_material',
|
398 |
+
'combination_lock', 'pacifier', 'comic_book', 'compass',
|
399 |
+
'computer_keyboard', 'condiment', 'cone', 'control',
|
400 |
+
'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
|
401 |
+
'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
|
402 |
+
'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
|
403 |
+
'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
|
404 |
+
'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
|
405 |
+
'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
|
406 |
+
'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
|
407 |
+
'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
|
408 |
+
'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
|
409 |
+
'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
|
410 |
+
'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
|
411 |
+
'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
|
412 |
+
'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
|
413 |
+
'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
|
414 |
+
'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
|
415 |
+
'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
|
416 |
+
'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove',
|
417 |
+
'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat',
|
418 |
+
'dress_suit', 'dresser', 'drill', 'drone', 'dropper',
|
419 |
+
'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
|
420 |
+
'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle',
|
421 |
+
'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg',
|
422 |
+
'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair',
|
423 |
+
'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot',
|
424 |
+
'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret',
|
425 |
+
'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine',
|
426 |
+
'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine',
|
427 |
+
'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug',
|
428 |
+
'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod',
|
429 |
+
'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash',
|
430 |
+
'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
|
431 |
+
'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
|
432 |
+
'food_processor', 'football_(American)', 'football_helmet',
|
433 |
+
'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
|
434 |
+
'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge',
|
435 |
+
'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose',
|
436 |
+
'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin',
|
437 |
+
'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger',
|
438 |
+
'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove',
|
439 |
+
'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart',
|
440 |
+
'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater',
|
441 |
+
'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
|
442 |
+
'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun',
|
443 |
+
'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger',
|
444 |
+
'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass',
|
445 |
+
'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle',
|
446 |
+
'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil',
|
447 |
+
'headband', 'headboard', 'headlight', 'headscarf', 'headset',
|
448 |
+
'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
|
449 |
+
'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
|
450 |
+
'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
|
451 |
+
'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
|
452 |
+
'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
|
453 |
+
'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
|
454 |
+
'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
|
455 |
+
'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
|
456 |
+
'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
|
457 |
+
'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
|
458 |
+
'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
|
459 |
+
'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
|
460 |
+
'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
|
461 |
+
'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
|
462 |
+
'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
|
463 |
+
'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade',
|
464 |
+
'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
|
465 |
+
'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
|
466 |
+
'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat',
|
467 |
+
'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
|
468 |
+
'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange',
|
469 |
+
'manger', 'manhole', 'map', 'marker', 'martini', 'mascot',
|
470 |
+
'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)',
|
471 |
+
'matchbox', 'mattress', 'measuring_cup', 'measuring_stick',
|
472 |
+
'meatball', 'medicine', 'melon', 'microphone', 'microscope',
|
473 |
+
'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake',
|
474 |
+
'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)',
|
475 |
+
'money', 'monitor_(computer_equipment) computer_monitor', 'monkey',
|
476 |
+
'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle',
|
477 |
+
'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad',
|
478 |
+
'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument',
|
479 |
+
'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle',
|
480 |
+
'nest', 'newspaper', 'newsstand', 'nightshirt',
|
481 |
+
'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook',
|
482 |
+
'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
|
483 |
+
'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
|
484 |
+
'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven',
|
485 |
+
'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
|
486 |
+
'padlock', 'paintbrush', 'painting', 'pajamas', 'palette',
|
487 |
+
'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
|
488 |
+
'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
|
489 |
+
'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
|
490 |
+
'parasol', 'parchment', 'parka', 'parking_meter', 'parrot',
|
491 |
+
'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
|
492 |
+
'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
|
493 |
+
'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
|
494 |
+
'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
|
495 |
+
'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
|
496 |
+
'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
|
497 |
+
'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
|
498 |
+
'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
|
499 |
+
'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
|
500 |
+
'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
|
501 |
+
'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
|
502 |
+
'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
|
503 |
+
'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
|
504 |
+
'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
|
505 |
+
'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
|
506 |
+
'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
|
507 |
+
'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller',
|
508 |
+
'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin',
|
509 |
+
'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt',
|
510 |
+
'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver',
|
511 |
+
'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry',
|
512 |
+
'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
|
513 |
+
'recliner', 'record_player', 'reflector', 'remote_control',
|
514 |
+
'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
|
515 |
+
'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
|
516 |
+
'rolling_pin', 'root_beer', 'router_(computer_equipment)',
|
517 |
+
'rubber_band', 'runner_(carpet)', 'plastic_bag',
|
518 |
+
'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
|
519 |
+
'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
|
520 |
+
'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
|
521 |
+
'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
|
522 |
+
'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
|
523 |
+
'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
|
524 |
+
'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
|
525 |
+
'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
|
526 |
+
'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
|
527 |
+
'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
|
528 |
+
'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
|
529 |
+
'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
|
530 |
+
'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
|
531 |
+
'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
|
532 |
+
'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
|
533 |
+
'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
|
534 |
+
'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
|
535 |
+
'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
|
536 |
+
'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
|
537 |
+
'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
|
538 |
+
'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
|
539 |
+
'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
|
540 |
+
'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
|
541 |
+
'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew',
|
542 |
+
'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove',
|
543 |
+
'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
|
544 |
+
'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
|
545 |
+
'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
|
546 |
+
'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants',
|
547 |
+
'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit',
|
548 |
+
'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
|
549 |
+
'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
|
550 |
+
'tambourine', 'army_tank', 'tank_(storage_vessel)',
|
551 |
+
'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
|
552 |
+
'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
|
553 |
+
'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
|
554 |
+
'telephone_pole', 'telephoto_lens', 'television_camera',
|
555 |
+
'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
|
556 |
+
'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
|
557 |
+
'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
|
558 |
+
'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
|
559 |
+
'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
|
560 |
+
'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
|
561 |
+
'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
|
562 |
+
'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
|
563 |
+
'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
|
564 |
+
'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod',
|
565 |
+
'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban',
|
566 |
+
'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
|
567 |
+
'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
|
568 |
+
'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
|
569 |
+
'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
|
570 |
+
'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
|
571 |
+
'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
|
572 |
+
'washbasin', 'automatic_washer', 'watch', 'water_bottle',
|
573 |
+
'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
|
574 |
+
'water_gun', 'water_scooter', 'water_ski', 'water_tower',
|
575 |
+
'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
|
576 |
+
'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
|
577 |
+
'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
|
578 |
+
'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
|
579 |
+
'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
|
580 |
+
'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
|
581 |
+
'yoke_(animal_equipment)', 'zebra', 'zucchini'),
|
582 |
+
'palette':
|
583 |
+
None
|
584 |
+
}
|
585 |
+
|
586 |
+
def load_data_list(self) -> List[dict]:
|
587 |
+
"""Load annotations from an annotation file named as ``self.ann_file``
|
588 |
+
|
589 |
+
Returns:
|
590 |
+
List[dict]: A list of annotation.
|
591 |
+
""" # noqa: E501
|
592 |
+
try:
|
593 |
+
import lvis
|
594 |
+
if getattr(lvis, '__version__', '0') >= '10.5.3':
|
595 |
+
warnings.warn(
|
596 |
+
'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
|
597 |
+
UserWarning)
|
598 |
+
from lvis import LVIS
|
599 |
+
except ImportError:
|
600 |
+
raise ImportError(
|
601 |
+
'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
|
602 |
+
)
|
603 |
+
with get_local_path(
|
604 |
+
self.ann_file, backend_args=self.backend_args) as local_path:
|
605 |
+
self.lvis = LVIS(local_path)
|
606 |
+
self.cat_ids = self.lvis.get_cat_ids()
|
607 |
+
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
|
608 |
+
self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
|
609 |
+
|
610 |
+
img_ids = self.lvis.get_img_ids()
|
611 |
+
data_list = []
|
612 |
+
total_ann_ids = []
|
613 |
+
for img_id in img_ids:
|
614 |
+
raw_img_info = self.lvis.load_imgs([img_id])[0]
|
615 |
+
raw_img_info['img_id'] = img_id
|
616 |
+
# coco_url is used in LVISv1 instead of file_name
|
617 |
+
# e.g. http://images.cocodataset.org/train2017/000000391895.jpg
|
618 |
+
# train/val split in specified in url
|
619 |
+
raw_img_info['file_name'] = raw_img_info['coco_url'].replace(
|
620 |
+
'http://images.cocodataset.org/', '')
|
621 |
+
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
|
622 |
+
raw_ann_info = self.lvis.load_anns(ann_ids)
|
623 |
+
total_ann_ids.extend(ann_ids)
|
624 |
+
parsed_data_info = self.parse_data_info({
|
625 |
+
'raw_ann_info':
|
626 |
+
raw_ann_info,
|
627 |
+
'raw_img_info':
|
628 |
+
raw_img_info
|
629 |
+
})
|
630 |
+
data_list.append(parsed_data_info)
|
631 |
+
if self.ANN_ID_UNIQUE:
|
632 |
+
assert len(set(total_ann_ids)) == len(
|
633 |
+
total_ann_ids
|
634 |
+
), f"Annotation ids in '{self.ann_file}' are not unique!"
|
635 |
+
|
636 |
+
del self.lvis
|
637 |
+
|
638 |
+
return data_list
|
mmdet/datasets/objects365.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
import os.path as osp
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from mmengine.fileio import get_local_path
|
7 |
+
|
8 |
+
from mmdet.registry import DATASETS
|
9 |
+
from .api_wrappers import COCO
|
10 |
+
from .coco import CocoDataset
|
11 |
+
|
12 |
+
# images exist in annotations but not in image folder.
|
13 |
+
objv2_ignore_list = [
|
14 |
+
osp.join('patch16', 'objects365_v2_00908726.jpg'),
|
15 |
+
osp.join('patch6', 'objects365_v1_00320532.jpg'),
|
16 |
+
osp.join('patch6', 'objects365_v1_00320534.jpg'),
|
17 |
+
]
|
18 |
+
|
19 |
+
|
20 |
+
@DATASETS.register_module()
|
21 |
+
class Objects365V1Dataset(CocoDataset):
|
22 |
+
"""Objects365 v1 dataset for detection."""
|
23 |
+
|
24 |
+
METAINFO = {
|
25 |
+
'classes':
|
26 |
+
('person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle',
|
27 |
+
'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk',
|
28 |
+
'handbag', 'street lights', 'book', 'plate', 'helmet',
|
29 |
+
'leather shoes', 'pillow', 'glove', 'potted plant', 'bracelet',
|
30 |
+
'flower', 'tv', 'storage box', 'vase', 'bench', 'wine glass', 'boots',
|
31 |
+
'bowl', 'dining table', 'umbrella', 'boat', 'flag', 'speaker',
|
32 |
+
'trash bin/can', 'stool', 'backpack', 'couch', 'belt', 'carpet',
|
33 |
+
'basket', 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table',
|
34 |
+
'suv', 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil',
|
35 |
+
'microphone', 'sandals', 'canned', 'necklace', 'mirror', 'faucet',
|
36 |
+
'bicycle', 'bread', 'high heels', 'ring', 'van', 'watch', 'sink',
|
37 |
+
'horse', 'fish', 'apple', 'camera', 'candle', 'teddy bear', 'cake',
|
38 |
+
'motorcycle', 'wild bird', 'laptop', 'knife', 'traffic sign',
|
39 |
+
'cell phone', 'paddle', 'truck', 'cow', 'power outlet', 'clock',
|
40 |
+
'drum', 'fork', 'bus', 'hanger', 'nightstand', 'pot/pan', 'sheep',
|
41 |
+
'guitar', 'traffic cone', 'tea pot', 'keyboard', 'tripod', 'hockey',
|
42 |
+
'fan', 'dog', 'spoon', 'blackboard/whiteboard', 'balloon',
|
43 |
+
'air conditioner', 'cymbal', 'mouse', 'telephone', 'pickup truck',
|
44 |
+
'orange', 'banana', 'airplane', 'luggage', 'skis', 'soccer',
|
45 |
+
'trolley', 'oven', 'remote', 'baseball glove', 'paper towel',
|
46 |
+
'refrigerator', 'train', 'tomato', 'machinery vehicle', 'tent',
|
47 |
+
'shampoo/shower gel', 'head phone', 'lantern', 'donut',
|
48 |
+
'cleaning products', 'sailboat', 'tangerine', 'pizza', 'kite',
|
49 |
+
'computer box', 'elephant', 'toiletries', 'gas stove', 'broccoli',
|
50 |
+
'toilet', 'stroller', 'shovel', 'baseball bat', 'microwave',
|
51 |
+
'skateboard', 'surfboard', 'surveillance camera', 'gun', 'life saver',
|
52 |
+
'cat', 'lemon', 'liquid soap', 'zebra', 'duck', 'sports car',
|
53 |
+
'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator', 'converter',
|
54 |
+
'tissue ', 'carrot', 'washing machine', 'vent', 'cookies',
|
55 |
+
'cutting/chopping board', 'tennis racket', 'candy',
|
56 |
+
'skating and skiing shoes', 'scissors', 'folder', 'baseball',
|
57 |
+
'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine',
|
58 |
+
'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear',
|
59 |
+
'american football', 'basketball', 'potato', 'paint brush', 'printer',
|
60 |
+
'billiards', 'fire hydrant', 'goose', 'projector', 'sausage',
|
61 |
+
'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball',
|
62 |
+
'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee',
|
63 |
+
'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender',
|
64 |
+
'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango',
|
65 |
+
'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion',
|
66 |
+
'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale',
|
67 |
+
'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple',
|
68 |
+
'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle',
|
69 |
+
'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar',
|
70 |
+
'earphone', 'penguin', 'hurdle', 'swing', 'radio', 'CD',
|
71 |
+
'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado',
|
72 |
+
'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear',
|
73 |
+
'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn',
|
74 |
+
'key', 'screwdriver', 'globe', 'broom', 'pliers', 'volleyball',
|
75 |
+
'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice',
|
76 |
+
'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel',
|
77 |
+
'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste',
|
78 |
+
'antelope', 'shrimp', 'rickshaw', 'trombone', 'pomegranate',
|
79 |
+
'coconut', 'jellyfish', 'mushroom', 'calculator', 'treadmill',
|
80 |
+
'butterfly', 'egg tart', 'cheese', 'pig', 'pomelo', 'race car',
|
81 |
+
'rice cooker', 'tuba', 'crosswalk sign', 'papaya', 'hair drier',
|
82 |
+
'green onion', 'chips', 'dolphin', 'sushi', 'urinal', 'donkey',
|
83 |
+
'electric drill', 'spring rolls', 'tortoise/turtle', 'parrot',
|
84 |
+
'flute', 'measuring cup', 'shark', 'steak', 'poker card',
|
85 |
+
'binoculars', 'llama', 'radish', 'noodles', 'yak', 'mop', 'crab',
|
86 |
+
'microscope', 'barbell', 'bread/bun', 'baozi', 'lion', 'red cabbage',
|
87 |
+
'polar bear', 'lighter', 'seal', 'mangosteen', 'comb', 'eraser',
|
88 |
+
'pitaya', 'scallop', 'pencil case', 'saw', 'table tennis paddle',
|
89 |
+
'okra', 'starfish', 'eagle', 'monkey', 'durian', 'game board',
|
90 |
+
'rabbit', 'french horn', 'ambulance', 'asparagus', 'hoverboard',
|
91 |
+
'pasta', 'target', 'hotair balloon', 'chainsaw', 'lobster', 'iron',
|
92 |
+
'flashlight'),
|
93 |
+
'palette':
|
94 |
+
None
|
95 |
+
}
|
96 |
+
|
97 |
+
COCOAPI = COCO
|
98 |
+
# ann_id is unique in coco dataset.
|
99 |
+
ANN_ID_UNIQUE = True
|
100 |
+
|
101 |
+
def load_data_list(self) -> List[dict]:
|
102 |
+
"""Load annotations from an annotation file named as ``self.ann_file``
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
List[dict]: A list of annotation.
|
106 |
+
""" # noqa: E501
|
107 |
+
with get_local_path(
|
108 |
+
self.ann_file, backend_args=self.backend_args) as local_path:
|
109 |
+
self.coco = self.COCOAPI(local_path)
|
110 |
+
|
111 |
+
# 'categories' list in objects365_train.json and objects365_val.json
|
112 |
+
# is inconsistent, need sort list(or dict) before get cat_ids.
|
113 |
+
cats = self.coco.cats
|
114 |
+
sorted_cats = {i: cats[i] for i in sorted(cats)}
|
115 |
+
self.coco.cats = sorted_cats
|
116 |
+
categories = self.coco.dataset['categories']
|
117 |
+
sorted_categories = sorted(categories, key=lambda i: i['id'])
|
118 |
+
self.coco.dataset['categories'] = sorted_categories
|
119 |
+
# The order of returned `cat_ids` will not
|
120 |
+
# change with the order of the `classes`
|
121 |
+
self.cat_ids = self.coco.get_cat_ids(
|
122 |
+
cat_names=self.metainfo['classes'])
|
123 |
+
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
|
124 |
+
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
|
125 |
+
|
126 |
+
img_ids = self.coco.get_img_ids()
|
127 |
+
data_list = []
|
128 |
+
total_ann_ids = []
|
129 |
+
for img_id in img_ids:
|
130 |
+
raw_img_info = self.coco.load_imgs([img_id])[0]
|
131 |
+
raw_img_info['img_id'] = img_id
|
132 |
+
|
133 |
+
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
|
134 |
+
raw_ann_info = self.coco.load_anns(ann_ids)
|
135 |
+
total_ann_ids.extend(ann_ids)
|
136 |
+
|
137 |
+
parsed_data_info = self.parse_data_info({
|
138 |
+
'raw_ann_info':
|
139 |
+
raw_ann_info,
|
140 |
+
'raw_img_info':
|
141 |
+
raw_img_info
|
142 |
+
})
|
143 |
+
data_list.append(parsed_data_info)
|
144 |
+
if self.ANN_ID_UNIQUE:
|
145 |
+
assert len(set(total_ann_ids)) == len(
|
146 |
+
total_ann_ids
|
147 |
+
), f"Annotation ids in '{self.ann_file}' are not unique!"
|
148 |
+
|
149 |
+
del self.coco
|
150 |
+
|
151 |
+
return data_list
|
152 |
+
|
153 |
+
|
154 |
+
@DATASETS.register_module()
|
155 |
+
class Objects365V2Dataset(CocoDataset):
|
156 |
+
"""Objects365 v2 dataset for detection."""
|
157 |
+
METAINFO = {
|
158 |
+
'classes':
|
159 |
+
('Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp',
|
160 |
+
'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf',
|
161 |
+
'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet',
|
162 |
+
'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower',
|
163 |
+
'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots',
|
164 |
+
'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt',
|
165 |
+
'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker',
|
166 |
+
'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool',
|
167 |
+
'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum',
|
168 |
+
'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle',
|
169 |
+
'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned',
|
170 |
+
'Truck', 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel',
|
171 |
+
'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed',
|
172 |
+
'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple',
|
173 |
+
'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck',
|
174 |
+
'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock',
|
175 |
+
'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger',
|
176 |
+
'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine',
|
177 |
+
'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle',
|
178 |
+
'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane',
|
179 |
+
'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage',
|
180 |
+
'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone',
|
181 |
+
'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane',
|
182 |
+
'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat',
|
183 |
+
'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza',
|
184 |
+
'Elephant', 'Skateboard', 'Surfboard', 'Gun',
|
185 |
+
'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot',
|
186 |
+
'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper',
|
187 |
+
'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks',
|
188 |
+
'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board',
|
189 |
+
'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder',
|
190 |
+
'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball',
|
191 |
+
'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle',
|
192 |
+
'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck',
|
193 |
+
'Billards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club',
|
194 |
+
'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear',
|
195 |
+
'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong',
|
196 |
+
'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask',
|
197 |
+
'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide',
|
198 |
+
'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee',
|
199 |
+
'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon',
|
200 |
+
'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon',
|
201 |
+
'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog',
|
202 |
+
'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer',
|
203 |
+
'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple',
|
204 |
+
'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle',
|
205 |
+
'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone',
|
206 |
+
'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion',
|
207 |
+
'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom',
|
208 |
+
'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit',
|
209 |
+
'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese',
|
210 |
+
'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue',
|
211 |
+
'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap',
|
212 |
+
'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut',
|
213 |
+
'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak',
|
214 |
+
'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate',
|
215 |
+
'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker',
|
216 |
+
'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal',
|
217 |
+
'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin',
|
218 |
+
'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill',
|
219 |
+
'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi',
|
220 |
+
'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case',
|
221 |
+
'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop',
|
222 |
+
'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle',
|
223 |
+
'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster',
|
224 |
+
'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling',
|
225 |
+
'Table Tennis '),
|
226 |
+
'palette':
|
227 |
+
None
|
228 |
+
}
|
229 |
+
|
230 |
+
COCOAPI = COCO
|
231 |
+
# ann_id is unique in coco dataset.
|
232 |
+
ANN_ID_UNIQUE = True
|
233 |
+
|
234 |
+
def load_data_list(self) -> List[dict]:
|
235 |
+
"""Load annotations from an annotation file named as ``self.ann_file``
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
List[dict]: A list of annotation.
|
239 |
+
""" # noqa: E501
|
240 |
+
with get_local_path(
|
241 |
+
self.ann_file, backend_args=self.backend_args) as local_path:
|
242 |
+
self.coco = self.COCOAPI(local_path)
|
243 |
+
# The order of returned `cat_ids` will not
|
244 |
+
# change with the order of the `classes`
|
245 |
+
self.cat_ids = self.coco.get_cat_ids(
|
246 |
+
cat_names=self.metainfo['classes'])
|
247 |
+
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
|
248 |
+
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
|
249 |
+
|
250 |
+
img_ids = self.coco.get_img_ids()
|
251 |
+
data_list = []
|
252 |
+
total_ann_ids = []
|
253 |
+
for img_id in img_ids:
|
254 |
+
raw_img_info = self.coco.load_imgs([img_id])[0]
|
255 |
+
raw_img_info['img_id'] = img_id
|
256 |
+
|
257 |
+
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
|
258 |
+
raw_ann_info = self.coco.load_anns(ann_ids)
|
259 |
+
total_ann_ids.extend(ann_ids)
|
260 |
+
|
261 |
+
# file_name should be `patchX/xxx.jpg`
|
262 |
+
file_name = osp.join(
|
263 |
+
osp.split(osp.split(raw_img_info['file_name'])[0])[-1],
|
264 |
+
osp.split(raw_img_info['file_name'])[-1])
|
265 |
+
|
266 |
+
if file_name in objv2_ignore_list:
|
267 |
+
continue
|
268 |
+
|
269 |
+
raw_img_info['file_name'] = file_name
|
270 |
+
parsed_data_info = self.parse_data_info({
|
271 |
+
'raw_ann_info':
|
272 |
+
raw_ann_info,
|
273 |
+
'raw_img_info':
|
274 |
+
raw_img_info
|
275 |
+
})
|
276 |
+
data_list.append(parsed_data_info)
|
277 |
+
if self.ANN_ID_UNIQUE:
|
278 |
+
assert len(set(total_ann_ids)) == len(
|
279 |
+
total_ann_ids
|
280 |
+
), f"Annotation ids in '{self.ann_file}' are not unique!"
|
281 |
+
|
282 |
+
del self.coco
|
283 |
+
|
284 |
+
return data_list
|
mmdet/datasets/openimages.py
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import csv
|
3 |
+
import os.path as osp
|
4 |
+
from collections import defaultdict
|
5 |
+
from typing import Dict, List, Optional
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from mmengine.fileio import get_local_path, load
|
9 |
+
from mmengine.utils import is_abs
|
10 |
+
|
11 |
+
from mmdet.registry import DATASETS
|
12 |
+
from .base_det_dataset import BaseDetDataset
|
13 |
+
|
14 |
+
|
15 |
+
@DATASETS.register_module()
|
16 |
+
class OpenImagesDataset(BaseDetDataset):
|
17 |
+
"""Open Images dataset for detection.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
ann_file (str): Annotation file path.
|
21 |
+
label_file (str): File path of the label description file that
|
22 |
+
maps the classes names in MID format to their short
|
23 |
+
descriptions.
|
24 |
+
meta_file (str): File path to get image metas.
|
25 |
+
hierarchy_file (str): The file path of the class hierarchy.
|
26 |
+
image_level_ann_file (str): Human-verified image level annotation,
|
27 |
+
which is used in evaluation.
|
28 |
+
backend_args (dict, optional): Arguments to instantiate the
|
29 |
+
corresponding backend. Defaults to None.
|
30 |
+
"""
|
31 |
+
|
32 |
+
METAINFO: dict = dict(dataset_type='oid_v6')
|
33 |
+
|
34 |
+
def __init__(self,
|
35 |
+
label_file: str,
|
36 |
+
meta_file: str,
|
37 |
+
hierarchy_file: str,
|
38 |
+
image_level_ann_file: Optional[str] = None,
|
39 |
+
**kwargs) -> None:
|
40 |
+
self.label_file = label_file
|
41 |
+
self.meta_file = meta_file
|
42 |
+
self.hierarchy_file = hierarchy_file
|
43 |
+
self.image_level_ann_file = image_level_ann_file
|
44 |
+
super().__init__(**kwargs)
|
45 |
+
|
46 |
+
def load_data_list(self) -> List[dict]:
|
47 |
+
"""Load annotations from an annotation file named as ``self.ann_file``
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
List[dict]: A list of annotation.
|
51 |
+
"""
|
52 |
+
classes_names, label_id_mapping = self._parse_label_file(
|
53 |
+
self.label_file)
|
54 |
+
self._metainfo['classes'] = classes_names
|
55 |
+
self.label_id_mapping = label_id_mapping
|
56 |
+
|
57 |
+
if self.image_level_ann_file is not None:
|
58 |
+
img_level_anns = self._parse_img_level_ann(
|
59 |
+
self.image_level_ann_file)
|
60 |
+
else:
|
61 |
+
img_level_anns = None
|
62 |
+
|
63 |
+
# OpenImagesMetric can get the relation matrix from the dataset meta
|
64 |
+
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
|
65 |
+
self._metainfo['RELATION_MATRIX'] = relation_matrix
|
66 |
+
|
67 |
+
data_list = []
|
68 |
+
with get_local_path(
|
69 |
+
self.ann_file, backend_args=self.backend_args) as local_path:
|
70 |
+
with open(local_path, 'r') as f:
|
71 |
+
reader = csv.reader(f)
|
72 |
+
last_img_id = None
|
73 |
+
instances = []
|
74 |
+
for i, line in enumerate(reader):
|
75 |
+
if i == 0:
|
76 |
+
continue
|
77 |
+
img_id = line[0]
|
78 |
+
if last_img_id is None:
|
79 |
+
last_img_id = img_id
|
80 |
+
label_id = line[2]
|
81 |
+
assert label_id in self.label_id_mapping
|
82 |
+
label = int(self.label_id_mapping[label_id])
|
83 |
+
bbox = [
|
84 |
+
float(line[4]), # xmin
|
85 |
+
float(line[6]), # ymin
|
86 |
+
float(line[5]), # xmax
|
87 |
+
float(line[7]) # ymax
|
88 |
+
]
|
89 |
+
is_occluded = True if int(line[8]) == 1 else False
|
90 |
+
is_truncated = True if int(line[9]) == 1 else False
|
91 |
+
is_group_of = True if int(line[10]) == 1 else False
|
92 |
+
is_depiction = True if int(line[11]) == 1 else False
|
93 |
+
is_inside = True if int(line[12]) == 1 else False
|
94 |
+
|
95 |
+
instance = dict(
|
96 |
+
bbox=bbox,
|
97 |
+
bbox_label=label,
|
98 |
+
ignore_flag=0,
|
99 |
+
is_occluded=is_occluded,
|
100 |
+
is_truncated=is_truncated,
|
101 |
+
is_group_of=is_group_of,
|
102 |
+
is_depiction=is_depiction,
|
103 |
+
is_inside=is_inside)
|
104 |
+
last_img_path = osp.join(self.data_prefix['img'],
|
105 |
+
f'{last_img_id}.jpg')
|
106 |
+
if img_id != last_img_id:
|
107 |
+
# switch to a new image, record previous image's data.
|
108 |
+
data_info = dict(
|
109 |
+
img_path=last_img_path,
|
110 |
+
img_id=last_img_id,
|
111 |
+
instances=instances,
|
112 |
+
)
|
113 |
+
data_list.append(data_info)
|
114 |
+
instances = []
|
115 |
+
instances.append(instance)
|
116 |
+
last_img_id = img_id
|
117 |
+
data_list.append(
|
118 |
+
dict(
|
119 |
+
img_path=last_img_path,
|
120 |
+
img_id=last_img_id,
|
121 |
+
instances=instances,
|
122 |
+
))
|
123 |
+
|
124 |
+
# add image metas to data list
|
125 |
+
img_metas = load(
|
126 |
+
self.meta_file, file_format='pkl', backend_args=self.backend_args)
|
127 |
+
assert len(img_metas) == len(data_list)
|
128 |
+
for i, meta in enumerate(img_metas):
|
129 |
+
img_id = data_list[i]['img_id']
|
130 |
+
assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1]
|
131 |
+
h, w = meta['ori_shape'][:2]
|
132 |
+
data_list[i]['height'] = h
|
133 |
+
data_list[i]['width'] = w
|
134 |
+
# denormalize bboxes
|
135 |
+
for j in range(len(data_list[i]['instances'])):
|
136 |
+
data_list[i]['instances'][j]['bbox'][0] *= w
|
137 |
+
data_list[i]['instances'][j]['bbox'][2] *= w
|
138 |
+
data_list[i]['instances'][j]['bbox'][1] *= h
|
139 |
+
data_list[i]['instances'][j]['bbox'][3] *= h
|
140 |
+
# add image-level annotation
|
141 |
+
if img_level_anns is not None:
|
142 |
+
img_labels = []
|
143 |
+
confidences = []
|
144 |
+
img_ann_list = img_level_anns.get(img_id, [])
|
145 |
+
for ann in img_ann_list:
|
146 |
+
img_labels.append(int(ann['image_level_label']))
|
147 |
+
confidences.append(float(ann['confidence']))
|
148 |
+
data_list[i]['image_level_labels'] = np.array(
|
149 |
+
img_labels, dtype=np.int64)
|
150 |
+
data_list[i]['confidences'] = np.array(
|
151 |
+
confidences, dtype=np.float32)
|
152 |
+
return data_list
|
153 |
+
|
154 |
+
def _parse_label_file(self, label_file: str) -> tuple:
|
155 |
+
"""Get classes name and index mapping from cls-label-description file.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
label_file (str): File path of the label description file that
|
159 |
+
maps the classes names in MID format to their short
|
160 |
+
descriptions.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
tuple: Class name of OpenImages.
|
164 |
+
"""
|
165 |
+
|
166 |
+
index_list = []
|
167 |
+
classes_names = []
|
168 |
+
with get_local_path(
|
169 |
+
label_file, backend_args=self.backend_args) as local_path:
|
170 |
+
with open(local_path, 'r') as f:
|
171 |
+
reader = csv.reader(f)
|
172 |
+
for line in reader:
|
173 |
+
# self.cat2label[line[0]] = line[1]
|
174 |
+
classes_names.append(line[1])
|
175 |
+
index_list.append(line[0])
|
176 |
+
index_mapping = {index: i for i, index in enumerate(index_list)}
|
177 |
+
return classes_names, index_mapping
|
178 |
+
|
179 |
+
def _parse_img_level_ann(self,
|
180 |
+
img_level_ann_file: str) -> Dict[str, List[dict]]:
|
181 |
+
"""Parse image level annotations from csv style ann_file.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
img_level_ann_file (str): CSV style image level annotation
|
185 |
+
file path.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
Dict[str, List[dict]]: Annotations where item of the defaultdict
|
189 |
+
indicates an image, each of which has (n) dicts.
|
190 |
+
Keys of dicts are:
|
191 |
+
|
192 |
+
- `image_level_label` (int): Label id.
|
193 |
+
- `confidence` (float): Labels that are human-verified to be
|
194 |
+
present in an image have confidence = 1 (positive labels).
|
195 |
+
Labels that are human-verified to be absent from an image
|
196 |
+
have confidence = 0 (negative labels). Machine-generated
|
197 |
+
labels have fractional confidences, generally >= 0.5.
|
198 |
+
The higher the confidence, the smaller the chance for
|
199 |
+
the label to be a false positive.
|
200 |
+
"""
|
201 |
+
|
202 |
+
item_lists = defaultdict(list)
|
203 |
+
with get_local_path(
|
204 |
+
img_level_ann_file,
|
205 |
+
backend_args=self.backend_args) as local_path:
|
206 |
+
with open(local_path, 'r') as f:
|
207 |
+
reader = csv.reader(f)
|
208 |
+
for i, line in enumerate(reader):
|
209 |
+
if i == 0:
|
210 |
+
continue
|
211 |
+
img_id = line[0]
|
212 |
+
item_lists[img_id].append(
|
213 |
+
dict(
|
214 |
+
image_level_label=int(
|
215 |
+
self.label_id_mapping[line[2]]),
|
216 |
+
confidence=float(line[3])))
|
217 |
+
return item_lists
|
218 |
+
|
219 |
+
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
|
220 |
+
"""Get the matrix of class hierarchy from the hierarchy file. Hierarchy
|
221 |
+
for 600 classes can be found at https://storage.googleapis.com/openimag
|
222 |
+
es/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
hierarchy_file (str): File path to the hierarchy for classes.
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
np.ndarray: The matrix of the corresponding relationship between
|
229 |
+
the parent class and the child class, of shape
|
230 |
+
(class_num, class_num).
|
231 |
+
""" # noqa
|
232 |
+
|
233 |
+
hierarchy = load(
|
234 |
+
hierarchy_file, file_format='json', backend_args=self.backend_args)
|
235 |
+
class_num = len(self._metainfo['classes'])
|
236 |
+
relation_matrix = np.eye(class_num, class_num)
|
237 |
+
relation_matrix = self._convert_hierarchy_tree(hierarchy,
|
238 |
+
relation_matrix)
|
239 |
+
return relation_matrix
|
240 |
+
|
241 |
+
def _convert_hierarchy_tree(self,
|
242 |
+
hierarchy_map: dict,
|
243 |
+
relation_matrix: np.ndarray,
|
244 |
+
parents: list = [],
|
245 |
+
get_all_parents: bool = True) -> np.ndarray:
|
246 |
+
"""Get matrix of the corresponding relationship between the parent
|
247 |
+
class and the child class.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
hierarchy_map (dict): Including label name and corresponding
|
251 |
+
subcategory. Keys of dicts are:
|
252 |
+
|
253 |
+
- `LabeName` (str): Name of the label.
|
254 |
+
- `Subcategory` (dict | list): Corresponding subcategory(ies).
|
255 |
+
relation_matrix (ndarray): The matrix of the corresponding
|
256 |
+
relationship between the parent class and the child class,
|
257 |
+
of shape (class_num, class_num).
|
258 |
+
parents (list): Corresponding parent class.
|
259 |
+
get_all_parents (bool): Whether get all parent names.
|
260 |
+
Default: True
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
ndarray: The matrix of the corresponding relationship between
|
264 |
+
the parent class and the child class, of shape
|
265 |
+
(class_num, class_num).
|
266 |
+
"""
|
267 |
+
|
268 |
+
if 'Subcategory' in hierarchy_map:
|
269 |
+
for node in hierarchy_map['Subcategory']:
|
270 |
+
if 'LabelName' in node:
|
271 |
+
children_name = node['LabelName']
|
272 |
+
children_index = self.label_id_mapping[children_name]
|
273 |
+
children = [children_index]
|
274 |
+
else:
|
275 |
+
continue
|
276 |
+
if len(parents) > 0:
|
277 |
+
for parent_index in parents:
|
278 |
+
if get_all_parents:
|
279 |
+
children.append(parent_index)
|
280 |
+
relation_matrix[children_index, parent_index] = 1
|
281 |
+
relation_matrix = self._convert_hierarchy_tree(
|
282 |
+
node, relation_matrix, parents=children)
|
283 |
+
return relation_matrix
|
284 |
+
|
285 |
+
def _join_prefix(self):
|
286 |
+
"""Join ``self.data_root`` with annotation path."""
|
287 |
+
super()._join_prefix()
|
288 |
+
if not is_abs(self.label_file) and self.label_file:
|
289 |
+
self.label_file = osp.join(self.data_root, self.label_file)
|
290 |
+
if not is_abs(self.meta_file) and self.meta_file:
|
291 |
+
self.meta_file = osp.join(self.data_root, self.meta_file)
|
292 |
+
if not is_abs(self.hierarchy_file) and self.hierarchy_file:
|
293 |
+
self.hierarchy_file = osp.join(self.data_root, self.hierarchy_file)
|
294 |
+
if self.image_level_ann_file and not is_abs(self.image_level_ann_file):
|
295 |
+
self.image_level_ann_file = osp.join(self.data_root,
|
296 |
+
self.image_level_ann_file)
|
297 |
+
|
298 |
+
|
299 |
+
@DATASETS.register_module()
|
300 |
+
class OpenImagesChallengeDataset(OpenImagesDataset):
|
301 |
+
"""Open Images Challenge dataset for detection.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
ann_file (str): Open Images Challenge box annotation in txt format.
|
305 |
+
"""
|
306 |
+
|
307 |
+
METAINFO: dict = dict(dataset_type='oid_challenge')
|
308 |
+
|
309 |
+
def __init__(self, ann_file: str, **kwargs) -> None:
|
310 |
+
if not ann_file.endswith('txt'):
|
311 |
+
raise TypeError('The annotation file of Open Images Challenge '
|
312 |
+
'should be a txt file.')
|
313 |
+
|
314 |
+
super().__init__(ann_file=ann_file, **kwargs)
|
315 |
+
|
316 |
+
def load_data_list(self) -> List[dict]:
|
317 |
+
"""Load annotations from an annotation file named as ``self.ann_file``
|
318 |
+
|
319 |
+
Returns:
|
320 |
+
List[dict]: A list of annotation.
|
321 |
+
"""
|
322 |
+
classes_names, label_id_mapping = self._parse_label_file(
|
323 |
+
self.label_file)
|
324 |
+
self._metainfo['classes'] = classes_names
|
325 |
+
self.label_id_mapping = label_id_mapping
|
326 |
+
|
327 |
+
if self.image_level_ann_file is not None:
|
328 |
+
img_level_anns = self._parse_img_level_ann(
|
329 |
+
self.image_level_ann_file)
|
330 |
+
else:
|
331 |
+
img_level_anns = None
|
332 |
+
|
333 |
+
# OpenImagesMetric can get the relation matrix from the dataset meta
|
334 |
+
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
|
335 |
+
self._metainfo['RELATION_MATRIX'] = relation_matrix
|
336 |
+
|
337 |
+
data_list = []
|
338 |
+
with get_local_path(
|
339 |
+
self.ann_file, backend_args=self.backend_args) as local_path:
|
340 |
+
with open(local_path, 'r') as f:
|
341 |
+
lines = f.readlines()
|
342 |
+
i = 0
|
343 |
+
while i < len(lines):
|
344 |
+
instances = []
|
345 |
+
filename = lines[i].rstrip()
|
346 |
+
i += 2
|
347 |
+
img_gt_size = int(lines[i])
|
348 |
+
i += 1
|
349 |
+
for j in range(img_gt_size):
|
350 |
+
sp = lines[i + j].split()
|
351 |
+
instances.append(
|
352 |
+
dict(
|
353 |
+
bbox=[
|
354 |
+
float(sp[1]),
|
355 |
+
float(sp[2]),
|
356 |
+
float(sp[3]),
|
357 |
+
float(sp[4])
|
358 |
+
],
|
359 |
+
bbox_label=int(sp[0]) - 1, # labels begin from 1
|
360 |
+
ignore_flag=0,
|
361 |
+
is_group_ofs=True if int(sp[5]) == 1 else False))
|
362 |
+
i += img_gt_size
|
363 |
+
data_list.append(
|
364 |
+
dict(
|
365 |
+
img_path=osp.join(self.data_prefix['img'], filename),
|
366 |
+
instances=instances,
|
367 |
+
))
|
368 |
+
|
369 |
+
# add image metas to data list
|
370 |
+
img_metas = load(
|
371 |
+
self.meta_file, file_format='pkl', backend_args=self.backend_args)
|
372 |
+
assert len(img_metas) == len(data_list)
|
373 |
+
for i, meta in enumerate(img_metas):
|
374 |
+
img_id = osp.split(data_list[i]['img_path'])[-1][:-4]
|
375 |
+
assert img_id == osp.split(meta['filename'])[-1][:-4]
|
376 |
+
h, w = meta['ori_shape'][:2]
|
377 |
+
data_list[i]['height'] = h
|
378 |
+
data_list[i]['width'] = w
|
379 |
+
data_list[i]['img_id'] = img_id
|
380 |
+
# denormalize bboxes
|
381 |
+
for j in range(len(data_list[i]['instances'])):
|
382 |
+
data_list[i]['instances'][j]['bbox'][0] *= w
|
383 |
+
data_list[i]['instances'][j]['bbox'][2] *= w
|
384 |
+
data_list[i]['instances'][j]['bbox'][1] *= h
|
385 |
+
data_list[i]['instances'][j]['bbox'][3] *= h
|
386 |
+
# add image-level annotation
|
387 |
+
if img_level_anns is not None:
|
388 |
+
img_labels = []
|
389 |
+
confidences = []
|
390 |
+
img_ann_list = img_level_anns.get(img_id, [])
|
391 |
+
for ann in img_ann_list:
|
392 |
+
img_labels.append(int(ann['image_level_label']))
|
393 |
+
confidences.append(float(ann['confidence']))
|
394 |
+
data_list[i]['image_level_labels'] = np.array(
|
395 |
+
img_labels, dtype=np.int64)
|
396 |
+
data_list[i]['confidences'] = np.array(
|
397 |
+
confidences, dtype=np.float32)
|
398 |
+
return data_list
|
399 |
+
|
400 |
+
def _parse_label_file(self, label_file: str) -> tuple:
|
401 |
+
"""Get classes name and index mapping from cls-label-description file.
|
402 |
+
|
403 |
+
Args:
|
404 |
+
label_file (str): File path of the label description file that
|
405 |
+
maps the classes names in MID format to their short
|
406 |
+
descriptions.
|
407 |
+
|
408 |
+
Returns:
|
409 |
+
tuple: Class name of OpenImages.
|
410 |
+
"""
|
411 |
+
label_list = []
|
412 |
+
id_list = []
|
413 |
+
index_mapping = {}
|
414 |
+
with get_local_path(
|
415 |
+
label_file, backend_args=self.backend_args) as local_path:
|
416 |
+
with open(local_path, 'r') as f:
|
417 |
+
reader = csv.reader(f)
|
418 |
+
for line in reader:
|
419 |
+
label_name = line[0]
|
420 |
+
label_id = int(line[2])
|
421 |
+
label_list.append(line[1])
|
422 |
+
id_list.append(label_id)
|
423 |
+
index_mapping[label_name] = label_id - 1
|
424 |
+
indexes = np.argsort(id_list)
|
425 |
+
classes_names = []
|
426 |
+
for index in indexes:
|
427 |
+
classes_names.append(label_list[index])
|
428 |
+
return classes_names, index_mapping
|
429 |
+
|
430 |
+
def _parse_img_level_ann(self, image_level_ann_file):
|
431 |
+
"""Parse image level annotations from csv style ann_file.
|
432 |
+
|
433 |
+
Args:
|
434 |
+
image_level_ann_file (str): CSV style image level annotation
|
435 |
+
file path.
|
436 |
+
|
437 |
+
Returns:
|
438 |
+
defaultdict[list[dict]]: Annotations where item of the defaultdict
|
439 |
+
indicates an image, each of which has (n) dicts.
|
440 |
+
Keys of dicts are:
|
441 |
+
|
442 |
+
- `image_level_label` (int): of shape 1.
|
443 |
+
- `confidence` (float): of shape 1.
|
444 |
+
"""
|
445 |
+
|
446 |
+
item_lists = defaultdict(list)
|
447 |
+
with get_local_path(
|
448 |
+
image_level_ann_file,
|
449 |
+
backend_args=self.backend_args) as local_path:
|
450 |
+
with open(local_path, 'r') as f:
|
451 |
+
reader = csv.reader(f)
|
452 |
+
i = -1
|
453 |
+
for line in reader:
|
454 |
+
i += 1
|
455 |
+
if i == 0:
|
456 |
+
continue
|
457 |
+
else:
|
458 |
+
img_id = line[0]
|
459 |
+
label_id = line[1]
|
460 |
+
assert label_id in self.label_id_mapping
|
461 |
+
image_level_label = int(
|
462 |
+
self.label_id_mapping[label_id])
|
463 |
+
confidence = float(line[2])
|
464 |
+
item_lists[img_id].append(
|
465 |
+
dict(
|
466 |
+
image_level_label=image_level_label,
|
467 |
+
confidence=confidence))
|
468 |
+
return item_lists
|
469 |
+
|
470 |
+
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
|
471 |
+
"""Get the matrix of class hierarchy from the hierarchy file.
|
472 |
+
|
473 |
+
Args:
|
474 |
+
hierarchy_file (str): File path to the hierarchy for classes.
|
475 |
+
|
476 |
+
Returns:
|
477 |
+
np.ndarray: The matrix of the corresponding
|
478 |
+
relationship between the parent class and the child class,
|
479 |
+
of shape (class_num, class_num).
|
480 |
+
"""
|
481 |
+
with get_local_path(
|
482 |
+
hierarchy_file, backend_args=self.backend_args) as local_path:
|
483 |
+
class_label_tree = np.load(local_path, allow_pickle=True)
|
484 |
+
return class_label_tree[1:, 1:]
|
mmdet/datasets/samplers/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .batch_sampler import AspectRatioBatchSampler
|
3 |
+
from .class_aware_sampler import ClassAwareSampler
|
4 |
+
from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
|
8 |
+
'GroupMultiSourceSampler'
|
9 |
+
]
|
mmdet/datasets/samplers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (421 Bytes). View file
|
|
mmdet/datasets/samplers/__pycache__/batch_sampler.cpython-310.pyc
ADDED
Binary file (2.56 kB). View file
|
|
mmdet/datasets/samplers/__pycache__/class_aware_sampler.cpython-310.pyc
ADDED
Binary file (6.95 kB). View file
|
|
mmdet/datasets/samplers/__pycache__/multi_source_sampler.cpython-310.pyc
ADDED
Binary file (8.51 kB). View file
|
|
mmdet/datasets/samplers/batch_sampler.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import Sequence
|
3 |
+
|
4 |
+
from torch.utils.data import BatchSampler, Sampler
|
5 |
+
|
6 |
+
from mmdet.registry import DATA_SAMPLERS
|
7 |
+
|
8 |
+
|
9 |
+
# TODO: maybe replace with a data_loader wrapper
|
10 |
+
@DATA_SAMPLERS.register_module()
|
11 |
+
class AspectRatioBatchSampler(BatchSampler):
|
12 |
+
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
|
13 |
+
|
14 |
+
>= 1) into a same batch.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
sampler (Sampler): Base sampler.
|
18 |
+
batch_size (int): Size of mini-batch.
|
19 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
20 |
+
its size would be less than ``batch_size``.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self,
|
24 |
+
sampler: Sampler,
|
25 |
+
batch_size: int,
|
26 |
+
drop_last: bool = False) -> None:
|
27 |
+
if not isinstance(sampler, Sampler):
|
28 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
29 |
+
f'but got {sampler}')
|
30 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
31 |
+
raise ValueError('batch_size should be a positive integer value, '
|
32 |
+
f'but got batch_size={batch_size}')
|
33 |
+
self.sampler = sampler
|
34 |
+
self.batch_size = batch_size
|
35 |
+
self.drop_last = drop_last
|
36 |
+
# two groups for w < h and w >= h
|
37 |
+
self._aspect_ratio_buckets = [[] for _ in range(2)]
|
38 |
+
|
39 |
+
def __iter__(self) -> Sequence[int]:
|
40 |
+
for idx in self.sampler:
|
41 |
+
data_info = self.sampler.dataset.get_data_info(idx)
|
42 |
+
width, height = data_info['width'], data_info['height']
|
43 |
+
bucket_id = 0 if width < height else 1
|
44 |
+
bucket = self._aspect_ratio_buckets[bucket_id]
|
45 |
+
bucket.append(idx)
|
46 |
+
# yield a batch of indices in the same aspect ratio group
|
47 |
+
if len(bucket) == self.batch_size:
|
48 |
+
yield bucket[:]
|
49 |
+
del bucket[:]
|
50 |
+
|
51 |
+
# yield the rest data and reset the bucket
|
52 |
+
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
|
53 |
+
1]
|
54 |
+
self._aspect_ratio_buckets = [[] for _ in range(2)]
|
55 |
+
while len(left_data) > 0:
|
56 |
+
if len(left_data) <= self.batch_size:
|
57 |
+
if not self.drop_last:
|
58 |
+
yield left_data[:]
|
59 |
+
left_data = []
|
60 |
+
else:
|
61 |
+
yield left_data[:self.batch_size]
|
62 |
+
left_data = left_data[self.batch_size:]
|
63 |
+
|
64 |
+
def __len__(self) -> int:
|
65 |
+
if self.drop_last:
|
66 |
+
return len(self.sampler) // self.batch_size
|
67 |
+
else:
|
68 |
+
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
mmdet/datasets/samplers/class_aware_sampler.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import math
|
3 |
+
from typing import Dict, Iterator, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from mmengine.dataset import BaseDataset
|
8 |
+
from mmengine.dist import get_dist_info, sync_random_seed
|
9 |
+
from torch.utils.data import Sampler
|
10 |
+
|
11 |
+
from mmdet.registry import DATA_SAMPLERS
|
12 |
+
|
13 |
+
|
14 |
+
@DATA_SAMPLERS.register_module()
|
15 |
+
class ClassAwareSampler(Sampler):
|
16 |
+
r"""Sampler that restricts data loading to the label of the dataset.
|
17 |
+
|
18 |
+
A class-aware sampling strategy to effectively tackle the
|
19 |
+
non-uniform class distribution. The length of the training data is
|
20 |
+
consistent with source data. Simple improvements based on `Relay
|
21 |
+
Backpropagation for Effective Learning of Deep Convolutional
|
22 |
+
Neural Networks <https://arxiv.org/abs/1512.05830>`_
|
23 |
+
|
24 |
+
The implementation logic is referred to
|
25 |
+
https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
|
26 |
+
|
27 |
+
Args:
|
28 |
+
dataset: Dataset used for sampling.
|
29 |
+
seed (int, optional): random seed used to shuffle the sampler.
|
30 |
+
This number should be identical across all
|
31 |
+
processes in the distributed group. Defaults to None.
|
32 |
+
num_sample_class (int): The number of samples taken from each
|
33 |
+
per-label list. Defaults to 1.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
dataset: BaseDataset,
|
38 |
+
seed: Optional[int] = None,
|
39 |
+
num_sample_class: int = 1) -> None:
|
40 |
+
rank, world_size = get_dist_info()
|
41 |
+
self.rank = rank
|
42 |
+
self.world_size = world_size
|
43 |
+
|
44 |
+
self.dataset = dataset
|
45 |
+
self.epoch = 0
|
46 |
+
# Must be the same across all workers. If None, will use a
|
47 |
+
# random seed shared among workers
|
48 |
+
# (require synchronization among all workers)
|
49 |
+
if seed is None:
|
50 |
+
seed = sync_random_seed()
|
51 |
+
self.seed = seed
|
52 |
+
|
53 |
+
# The number of samples taken from each per-label list
|
54 |
+
assert num_sample_class > 0 and isinstance(num_sample_class, int)
|
55 |
+
self.num_sample_class = num_sample_class
|
56 |
+
# Get per-label image list from dataset
|
57 |
+
self.cat_dict = self.get_cat2imgs()
|
58 |
+
|
59 |
+
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
|
60 |
+
self.total_size = self.num_samples * self.world_size
|
61 |
+
|
62 |
+
# get number of images containing each category
|
63 |
+
self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
|
64 |
+
# filter labels without images
|
65 |
+
self.valid_cat_inds = [
|
66 |
+
i for i, length in enumerate(self.num_cat_imgs) if length != 0
|
67 |
+
]
|
68 |
+
self.num_classes = len(self.valid_cat_inds)
|
69 |
+
|
70 |
+
def get_cat2imgs(self) -> Dict[int, list]:
|
71 |
+
"""Get a dict with class as key and img_ids as values.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
dict[int, list]: A dict of per-label image list,
|
75 |
+
the item of the dict indicates a label index,
|
76 |
+
corresponds to the image index that contains the label.
|
77 |
+
"""
|
78 |
+
classes = self.dataset.metainfo.get('classes', None)
|
79 |
+
if classes is None:
|
80 |
+
raise ValueError('dataset metainfo must contain `classes`')
|
81 |
+
# sort the label index
|
82 |
+
cat2imgs = {i: [] for i in range(len(classes))}
|
83 |
+
for i in range(len(self.dataset)):
|
84 |
+
cat_ids = set(self.dataset.get_cat_ids(i))
|
85 |
+
for cat in cat_ids:
|
86 |
+
cat2imgs[cat].append(i)
|
87 |
+
return cat2imgs
|
88 |
+
|
89 |
+
def __iter__(self) -> Iterator[int]:
|
90 |
+
# deterministically shuffle based on epoch
|
91 |
+
g = torch.Generator()
|
92 |
+
g.manual_seed(self.epoch + self.seed)
|
93 |
+
|
94 |
+
# initialize label list
|
95 |
+
label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
|
96 |
+
# initialize each per-label image list
|
97 |
+
data_iter_dict = dict()
|
98 |
+
for i in self.valid_cat_inds:
|
99 |
+
data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
|
100 |
+
|
101 |
+
def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
|
102 |
+
"""Traverse the categories and extract `num_sample_cls` image
|
103 |
+
indexes of the corresponding categories one by one."""
|
104 |
+
id_indices = []
|
105 |
+
for _ in range(len(cls_list)):
|
106 |
+
cls_idx = next(cls_list)
|
107 |
+
for _ in range(num_sample_cls):
|
108 |
+
id = next(data_dict[cls_idx])
|
109 |
+
id_indices.append(id)
|
110 |
+
return id_indices
|
111 |
+
|
112 |
+
# deterministically shuffle based on epoch
|
113 |
+
num_bins = int(
|
114 |
+
math.ceil(self.total_size * 1.0 / self.num_classes /
|
115 |
+
self.num_sample_class))
|
116 |
+
indices = []
|
117 |
+
for i in range(num_bins):
|
118 |
+
indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
|
119 |
+
self.num_sample_class)
|
120 |
+
|
121 |
+
# fix extra samples to make it evenly divisible
|
122 |
+
if len(indices) >= self.total_size:
|
123 |
+
indices = indices[:self.total_size]
|
124 |
+
else:
|
125 |
+
indices += indices[:(self.total_size - len(indices))]
|
126 |
+
assert len(indices) == self.total_size
|
127 |
+
|
128 |
+
# subsample
|
129 |
+
offset = self.num_samples * self.rank
|
130 |
+
indices = indices[offset:offset + self.num_samples]
|
131 |
+
assert len(indices) == self.num_samples
|
132 |
+
|
133 |
+
return iter(indices)
|
134 |
+
|
135 |
+
def __len__(self) -> int:
|
136 |
+
"""The number of samples in this rank."""
|
137 |
+
return self.num_samples
|
138 |
+
|
139 |
+
def set_epoch(self, epoch: int) -> None:
|
140 |
+
"""Sets the epoch for this sampler.
|
141 |
+
|
142 |
+
When :attr:`shuffle=True`, this ensures all replicas use a different
|
143 |
+
random ordering for each epoch. Otherwise, the next iteration of this
|
144 |
+
sampler will yield the same ordering.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
epoch (int): Epoch number.
|
148 |
+
"""
|
149 |
+
self.epoch = epoch
|
150 |
+
|
151 |
+
|
152 |
+
class RandomCycleIter:
|
153 |
+
"""Shuffle the list and do it again after the list have traversed.
|
154 |
+
|
155 |
+
The implementation logic is referred to
|
156 |
+
https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
|
157 |
+
|
158 |
+
Example:
|
159 |
+
>>> label_list = [0, 1, 2, 4, 5]
|
160 |
+
>>> g = torch.Generator()
|
161 |
+
>>> g.manual_seed(0)
|
162 |
+
>>> label_iter_list = RandomCycleIter(label_list, generator=g)
|
163 |
+
>>> index = next(label_iter_list)
|
164 |
+
Args:
|
165 |
+
data (list or ndarray): The data that needs to be shuffled.
|
166 |
+
generator: An torch.Generator object, which is used in setting the seed
|
167 |
+
for generating random numbers.
|
168 |
+
""" # noqa: W605
|
169 |
+
|
170 |
+
def __init__(self,
|
171 |
+
data: Union[list, np.ndarray],
|
172 |
+
generator: torch.Generator = None) -> None:
|
173 |
+
self.data = data
|
174 |
+
self.length = len(data)
|
175 |
+
self.index = torch.randperm(self.length, generator=generator).numpy()
|
176 |
+
self.i = 0
|
177 |
+
self.generator = generator
|
178 |
+
|
179 |
+
def __iter__(self) -> Iterator:
|
180 |
+
return self
|
181 |
+
|
182 |
+
def __len__(self) -> int:
|
183 |
+
return len(self.data)
|
184 |
+
|
185 |
+
def __next__(self):
|
186 |
+
if self.i == self.length:
|
187 |
+
self.index = torch.randperm(
|
188 |
+
self.length, generator=self.generator).numpy()
|
189 |
+
self.i = 0
|
190 |
+
idx = self.data[self.index[self.i]]
|
191 |
+
self.i += 1
|
192 |
+
return idx
|
mmdet/datasets/samplers/multi_source_sampler.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import itertools
|
3 |
+
from typing import Iterator, List, Optional, Sized, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from mmengine.dataset import BaseDataset
|
8 |
+
from mmengine.dist import get_dist_info, sync_random_seed
|
9 |
+
from torch.utils.data import Sampler
|
10 |
+
|
11 |
+
from mmdet.registry import DATA_SAMPLERS
|
12 |
+
|
13 |
+
|
14 |
+
@DATA_SAMPLERS.register_module()
|
15 |
+
class MultiSourceSampler(Sampler):
|
16 |
+
r"""Multi-Source Infinite Sampler.
|
17 |
+
|
18 |
+
According to the sampling ratio, sample data from different
|
19 |
+
datasets to form batches.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
dataset (Sized): The dataset.
|
23 |
+
batch_size (int): Size of mini-batch.
|
24 |
+
source_ratio (list[int | float]): The sampling ratio of different
|
25 |
+
source datasets in a mini-batch.
|
26 |
+
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
|
27 |
+
seed (int, optional): Random seed. If None, set a random seed.
|
28 |
+
Defaults to None.
|
29 |
+
|
30 |
+
Examples:
|
31 |
+
>>> dataset_type = 'ConcatDataset'
|
32 |
+
>>> sub_dataset_type = 'CocoDataset'
|
33 |
+
>>> data_root = 'data/coco/'
|
34 |
+
>>> sup_ann = '../coco_semi_annos/instances_train2017.1@10.json'
|
35 |
+
>>> unsup_ann = '../coco_semi_annos/' \
|
36 |
+
>>> 'instances_train2017.1@10-unlabeled.json'
|
37 |
+
>>> dataset = dict(type=dataset_type,
|
38 |
+
>>> datasets=[
|
39 |
+
>>> dict(
|
40 |
+
>>> type=sub_dataset_type,
|
41 |
+
>>> data_root=data_root,
|
42 |
+
>>> ann_file=sup_ann,
|
43 |
+
>>> data_prefix=dict(img='train2017/'),
|
44 |
+
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
|
45 |
+
>>> pipeline=sup_pipeline),
|
46 |
+
>>> dict(
|
47 |
+
>>> type=sub_dataset_type,
|
48 |
+
>>> data_root=data_root,
|
49 |
+
>>> ann_file=unsup_ann,
|
50 |
+
>>> data_prefix=dict(img='train2017/'),
|
51 |
+
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
|
52 |
+
>>> pipeline=unsup_pipeline),
|
53 |
+
>>> ])
|
54 |
+
>>> train_dataloader = dict(
|
55 |
+
>>> batch_size=5,
|
56 |
+
>>> num_workers=5,
|
57 |
+
>>> persistent_workers=True,
|
58 |
+
>>> sampler=dict(type='MultiSourceSampler',
|
59 |
+
>>> batch_size=5, source_ratio=[1, 4]),
|
60 |
+
>>> batch_sampler=None,
|
61 |
+
>>> dataset=dataset)
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self,
|
65 |
+
dataset: Sized,
|
66 |
+
batch_size: int,
|
67 |
+
source_ratio: List[Union[int, float]],
|
68 |
+
shuffle: bool = True,
|
69 |
+
seed: Optional[int] = None) -> None:
|
70 |
+
|
71 |
+
assert hasattr(dataset, 'cumulative_sizes'),\
|
72 |
+
f'The dataset must be ConcatDataset, but get {dataset}'
|
73 |
+
assert isinstance(batch_size, int) and batch_size > 0, \
|
74 |
+
'batch_size must be a positive integer value, ' \
|
75 |
+
f'but got batch_size={batch_size}'
|
76 |
+
assert isinstance(source_ratio, list), \
|
77 |
+
f'source_ratio must be a list, but got source_ratio={source_ratio}'
|
78 |
+
assert len(source_ratio) == len(dataset.cumulative_sizes), \
|
79 |
+
'The length of source_ratio must be equal to ' \
|
80 |
+
f'the number of datasets, but got source_ratio={source_ratio}'
|
81 |
+
|
82 |
+
rank, world_size = get_dist_info()
|
83 |
+
self.rank = rank
|
84 |
+
self.world_size = world_size
|
85 |
+
|
86 |
+
self.dataset = dataset
|
87 |
+
self.cumulative_sizes = [0] + dataset.cumulative_sizes
|
88 |
+
self.batch_size = batch_size
|
89 |
+
self.source_ratio = source_ratio
|
90 |
+
|
91 |
+
self.num_per_source = [
|
92 |
+
int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
|
93 |
+
]
|
94 |
+
self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])
|
95 |
+
|
96 |
+
assert sum(self.num_per_source) == batch_size, \
|
97 |
+
'The sum of num_per_source must be equal to ' \
|
98 |
+
f'batch_size, but get {self.num_per_source}'
|
99 |
+
|
100 |
+
self.seed = sync_random_seed() if seed is None else seed
|
101 |
+
self.shuffle = shuffle
|
102 |
+
self.source2inds = {
|
103 |
+
source: self._indices_of_rank(len(ds))
|
104 |
+
for source, ds in enumerate(dataset.datasets)
|
105 |
+
}
|
106 |
+
|
107 |
+
def _infinite_indices(self, sample_size: int) -> Iterator[int]:
|
108 |
+
"""Infinitely yield a sequence of indices."""
|
109 |
+
g = torch.Generator()
|
110 |
+
g.manual_seed(self.seed)
|
111 |
+
while True:
|
112 |
+
if self.shuffle:
|
113 |
+
yield from torch.randperm(sample_size, generator=g).tolist()
|
114 |
+
else:
|
115 |
+
yield from torch.arange(sample_size).tolist()
|
116 |
+
|
117 |
+
def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
|
118 |
+
"""Slice the infinite indices by rank."""
|
119 |
+
yield from itertools.islice(
|
120 |
+
self._infinite_indices(sample_size), self.rank, None,
|
121 |
+
self.world_size)
|
122 |
+
|
123 |
+
def __iter__(self) -> Iterator[int]:
|
124 |
+
batch_buffer = []
|
125 |
+
while True:
|
126 |
+
for source, num in enumerate(self.num_per_source):
|
127 |
+
batch_buffer_per_source = []
|
128 |
+
for idx in self.source2inds[source]:
|
129 |
+
idx += self.cumulative_sizes[source]
|
130 |
+
batch_buffer_per_source.append(idx)
|
131 |
+
if len(batch_buffer_per_source) == num:
|
132 |
+
batch_buffer += batch_buffer_per_source
|
133 |
+
break
|
134 |
+
yield from batch_buffer
|
135 |
+
batch_buffer = []
|
136 |
+
|
137 |
+
def __len__(self) -> int:
|
138 |
+
return len(self.dataset)
|
139 |
+
|
140 |
+
def set_epoch(self, epoch: int) -> None:
|
141 |
+
"""Not supported in `epoch-based runner."""
|
142 |
+
pass
|
143 |
+
|
144 |
+
|
145 |
+
@DATA_SAMPLERS.register_module()
|
146 |
+
class GroupMultiSourceSampler(MultiSourceSampler):
|
147 |
+
r"""Group Multi-Source Infinite Sampler.
|
148 |
+
|
149 |
+
According to the sampling ratio, sample data from different
|
150 |
+
datasets but the same group to form batches.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
dataset (Sized): The dataset.
|
154 |
+
batch_size (int): Size of mini-batch.
|
155 |
+
source_ratio (list[int | float]): The sampling ratio of different
|
156 |
+
source datasets in a mini-batch.
|
157 |
+
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
|
158 |
+
seed (int, optional): Random seed. If None, set a random seed.
|
159 |
+
Defaults to None.
|
160 |
+
"""
|
161 |
+
|
162 |
+
def __init__(self,
|
163 |
+
dataset: BaseDataset,
|
164 |
+
batch_size: int,
|
165 |
+
source_ratio: List[Union[int, float]],
|
166 |
+
shuffle: bool = True,
|
167 |
+
seed: Optional[int] = None) -> None:
|
168 |
+
super().__init__(
|
169 |
+
dataset=dataset,
|
170 |
+
batch_size=batch_size,
|
171 |
+
source_ratio=source_ratio,
|
172 |
+
shuffle=shuffle,
|
173 |
+
seed=seed)
|
174 |
+
|
175 |
+
self._get_source_group_info()
|
176 |
+
self.group_source2inds = [{
|
177 |
+
source:
|
178 |
+
self._indices_of_rank(self.group2size_per_source[source][group])
|
179 |
+
for source in range(len(dataset.datasets))
|
180 |
+
} for group in range(len(self.group_ratio))]
|
181 |
+
|
182 |
+
def _get_source_group_info(self) -> None:
|
183 |
+
self.group2size_per_source = [{0: 0, 1: 0}, {0: 0, 1: 0}]
|
184 |
+
self.group2inds_per_source = [{0: [], 1: []}, {0: [], 1: []}]
|
185 |
+
for source, dataset in enumerate(self.dataset.datasets):
|
186 |
+
for idx in range(len(dataset)):
|
187 |
+
data_info = dataset.get_data_info(idx)
|
188 |
+
width, height = data_info['width'], data_info['height']
|
189 |
+
group = 0 if width < height else 1
|
190 |
+
self.group2size_per_source[source][group] += 1
|
191 |
+
self.group2inds_per_source[source][group].append(idx)
|
192 |
+
|
193 |
+
self.group_sizes = np.zeros(2, dtype=np.int64)
|
194 |
+
for group2size in self.group2size_per_source:
|
195 |
+
for group, size in group2size.items():
|
196 |
+
self.group_sizes[group] += size
|
197 |
+
self.group_ratio = self.group_sizes / sum(self.group_sizes)
|
198 |
+
|
199 |
+
def __iter__(self) -> Iterator[int]:
|
200 |
+
batch_buffer = []
|
201 |
+
while True:
|
202 |
+
group = np.random.choice(
|
203 |
+
list(range(len(self.group_ratio))), p=self.group_ratio)
|
204 |
+
for source, num in enumerate(self.num_per_source):
|
205 |
+
batch_buffer_per_source = []
|
206 |
+
for idx in self.group_source2inds[group][source]:
|
207 |
+
idx = self.group2inds_per_source[source][group][
|
208 |
+
idx] + self.cumulative_sizes[source]
|
209 |
+
batch_buffer_per_source.append(idx)
|
210 |
+
if len(batch_buffer_per_source) == num:
|
211 |
+
batch_buffer += batch_buffer_per_source
|
212 |
+
break
|
213 |
+
yield from batch_buffer
|
214 |
+
batch_buffer = []
|
mmdet/datasets/transforms/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .augment_wrappers import AutoAugment, RandAugment
|
3 |
+
from .colorspace import (AutoContrast, Brightness, Color, ColorTransform,
|
4 |
+
Contrast, Equalize, Invert, Posterize, Sharpness,
|
5 |
+
Solarize, SolarizeAdd)
|
6 |
+
from .formatting import ImageToTensor, PackDetInputs, ToTensor, Transpose
|
7 |
+
from .geometric import (GeomTransform, Rotate, ShearX, ShearY, TranslateX,
|
8 |
+
TranslateY)
|
9 |
+
from .instaboost import InstaBoost
|
10 |
+
from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations,
|
11 |
+
LoadEmptyAnnotations, LoadImageFromNDArray,
|
12 |
+
LoadMultiChannelImageFromFiles, LoadPanopticAnnotations,
|
13 |
+
LoadProposals)
|
14 |
+
from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut,
|
15 |
+
Expand, FixShapeResize, MinIoURandomCrop, MixUp,
|
16 |
+
Mosaic, Pad, PhotoMetricDistortion, RandomAffine,
|
17 |
+
RandomCenterCropPad, RandomCrop, RandomErasing,
|
18 |
+
RandomFlip, RandomShift, Resize, SegRescale,
|
19 |
+
YOLOXHSVRandomAug)
|
20 |
+
from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder
|
21 |
+
|
22 |
+
__all__ = [
|
23 |
+
'PackDetInputs', 'ToTensor', 'ImageToTensor', 'Transpose',
|
24 |
+
'LoadImageFromNDArray', 'LoadAnnotations', 'LoadPanopticAnnotations',
|
25 |
+
'LoadMultiChannelImageFromFiles', 'LoadProposals', 'Resize', 'RandomFlip',
|
26 |
+
'RandomCrop', 'SegRescale', 'MinIoURandomCrop', 'Expand',
|
27 |
+
'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad',
|
28 |
+
'AutoAugment', 'CutOut', 'ShearX', 'ShearY', 'Rotate', 'Color', 'Equalize',
|
29 |
+
'Brightness', 'Contrast', 'TranslateX', 'TranslateY', 'RandomShift',
|
30 |
+
'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste',
|
31 |
+
'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform',
|
32 |
+
'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize',
|
33 |
+
'AutoContrast', 'Invert', 'MultiBranch', 'RandomErasing',
|
34 |
+
'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp',
|
35 |
+
'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader'
|
36 |
+
]
|
mmdet/datasets/transforms/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.91 kB). View file
|
|
mmdet/datasets/transforms/__pycache__/augment_wrappers.cpython-310.pyc
ADDED
Binary file (9.14 kB). View file
|
|
mmdet/datasets/transforms/__pycache__/colorspace.cpython-310.pyc
ADDED
Binary file (16.2 kB). View file
|
|
mmdet/datasets/transforms/__pycache__/formatting.cpython-310.pyc
ADDED
Binary file (8.89 kB). View file
|
|