Spaces:
Runtime error
Runtime error
File size: 4,796 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Optional
from mmengine.dataset import BaseDataset
from mmengine.fileio import load
from mmengine.utils import is_abs
from ..registry import DATASETS
@DATASETS.register_module()
class BaseDetDataset(BaseDataset):
    """Base dataset for detection.

    Args:
        seg_map_suffix (str): Stored on the instance as
            ``self.seg_map_suffix``; it is not read anywhere else in this
            class (presumably the file suffix of segmentation maps used by
            subclasses -- confirm against callers). Defaults to '.png'.
        proposal_file (str, optional): Proposals file path. Defaults to None.
        file_client_args (dict, optional): Deprecated. Arguments to
            instantiate the corresponding backend in mmdet <= 3.0.0rc6.
            Passing anything other than None raises ``RuntimeError``.
            Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.

    Raises:
        RuntimeError: If ``file_client_args`` is not None.
    """

    def __init__(self,
                 *args,
                 seg_map_suffix: str = '.png',
                 proposal_file: Optional[str] = None,
                 file_client_args: Optional[dict] = None,
                 backend_args: Optional[dict] = None,
                 **kwargs) -> None:
        # These attributes are assigned before ``super().__init__`` because
        # the parent constructor may run ``full_init`` (when
        # ``lazy_init=False``), which reads ``self.proposal_file``.
        self.seg_map_suffix = seg_map_suffix
        self.proposal_file = proposal_file
        self.backend_args = backend_args
        if file_client_args is not None:
            # The trailing space after "refer to" keeps the URL separated
            # from the preceding word in the rendered message.
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )
        super().__init__(*args, **kwargs)

    def full_init(self) -> None:
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True.

        If ``lazy_init=False``, ``full_init`` will be called during the
        instantiation and ``self._fully_initialized`` will be set to True. If
        ``obj._fully_initialized=False``, the class method decorated by
        ``force_full_init`` will call ``full_init`` automatically.

        Several steps to initialize annotation:

            - load_data_list: Load annotations from annotation file.
            - load_proposals: Load proposals from proposal file, if
              `self.proposal_file` is not None.
            - filter data information: Filter annotations according to
              filter_cfg.
            - slice_data: Slice dataset according to ``self._indices``
            - serialize_data: Serialize ``self.data_list`` if
              ``self.serialize_data`` is True.
        """
        # Initialization is idempotent: do nothing on repeated calls.
        if self._fully_initialized:
            return
        # load data information
        self.data_list = self.load_data_list()
        # get proposals from file
        if self.proposal_file is not None:
            self.load_proposals()
        # filter illegal data, such as data that has no annotations.
        self.data_list = self.filter_data()

        # Get subset data according to indices.
        if self._indices is not None:
            self.data_list = self._get_unserialized_subset(self._indices)

        # serialize data_list
        if self.serialize_data:
            self.data_bytes, self.data_address = self._serialize_data()

        self._fully_initialized = True

    def load_proposals(self) -> None:
        """Load proposals from proposals file.

        The `proposals_list` should be a dict[img_path: proposals]
        with the same length as `data_list`. And the `proposals` should be
        a `dict` or :obj:`InstanceData` usually contains following keys.

            - bboxes (np.ndarray): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - scores (np.ndarray): Classification scores, has a shape
              (num_instance, ).

        Side effects:
            - A non-absolute ``self.proposal_file`` is replaced by its path
              joined onto ``self.data_root``.
            - Every entry of ``self.data_list`` gains a ``'proposals'`` key.

        Raises:
            AssertionError: If the number of loaded proposals differs from
                the number of entries in ``self.data_list``.
            KeyError: If an image has no matching entry in the proposals
                file.
        """
        # TODO: Add Unit Test after fully support Dump-Proposal Metric
        if not is_abs(self.proposal_file):
            self.proposal_file = osp.join(self.data_root, self.proposal_file)
        proposals_list = load(
            self.proposal_file, backend_args=self.backend_args)
        assert len(self.data_list) == len(proposals_list), (
            f'Got {len(proposals_list)} proposals for '
            f'{len(self.data_list)} data items.')
        for data_info in self.data_list:
            img_path = data_info['img_path']
            # `file_name` is the key to obtain the proposals from the
            # `proposals_list`. It is the last directory component joined
            # with the image file name, e.g. ``a/b/c.jpg`` -> ``b/c.jpg``.
            parent_dir, base_name = osp.split(img_path)
            file_name = osp.join(osp.split(parent_dir)[-1], base_name)
            data_info['proposals'] = proposals_list[file_name]

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get COCO category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: All categories in the image of specified index, i.e.
            the ``'bbox_label'`` of every entry in that item's
            ``'instances'`` list.
        """
        instances = self.get_data_info(idx)['instances']
        return [instance['bbox_label'] for instance in instances]
|