Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import os | |
from typing import Dict, List, Optional, Sequence, Union | |
from mmseg.registry import DATASETS | |
from .basesegdataset import BaseSegDataset | |
try: | |
from dsdl.dataset import DSDLDataset | |
except ImportError: | |
DSDLDataset = None | |
class DSDLSegDataset(BaseSegDataset):
    """Dataset for dsdl segmentation.

    The dsdl yaml given as ``ann_file`` is parsed with ``DSDLDataset`` and
    every sample's ``Image`` / ``LabelMap`` location is exposed through the
    regular segmentation data list.

    Args:
        specific_key_path (dict, optional): Path of specific key which can
            not be loaded by its field name. ``None`` (default) is treated
            as an empty dict.
        pre_transform (dict, optional): pre-transform functions before
            loading. ``None`` (default) is treated as an empty dict.
        used_labels (sequence, optional): list of actual used classes in
            train steps, this must be subset of class domain.
        **kwargs: Forwarded to ``BaseSegDataset.__init__``. ``ann_file``
            is required; when ``data_root`` is given and truthy, it is
            joined in front of ``ann_file`` before the yaml is opened.

    Raises:
        RuntimeError: If the ``dsdl`` package is not installed.
        KeyError: If ``ann_file`` is not passed in ``kwargs``.
    """

    # NOTE(review): ``DATASETS`` is imported at the top of the file but this
    # class is not decorated with ``@DATASETS.register_module()`` here --
    # confirm whether registration happens elsewhere or was dropped.
    METAINFO = {}

    def __init__(self,
                 specific_key_path: Optional[Dict] = None,
                 pre_transform: Optional[Dict] = None,
                 used_labels: Optional[Sequence] = None,
                 **kwargs) -> None:
        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )

        # Fresh dicts per call: a ``{}`` default in the signature would be
        # one shared object reused by every instance.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}

        self.used_labels = used_labels

        loc_config = dict(type='LocalFileReader', working_dir='')
        if kwargs.get('data_root'):
            # The joined path is written back into kwargs, so the base class
            # receives the same ``ann_file`` that DSDLDataset opens.
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])
        required_fields = ['Image', 'LabelMap']

        # Must exist before the base initializer runs: ``load_data_list``
        # and ``get_label_map`` both read ``self.dsdldataset``.
        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )

        BaseSegDataset.__init__(self, **kwargs)

    def _class_domain(self) -> List:
        """Return the dsdl class names with ``'background'`` prepended."""
        return ['background'] + self.dsdldataset.class_names

    def load_data_list(self) -> List[Dict]:
        """Build the data list from the parsed dsdl dataset.

        Also sets ``self._metainfo['classes']``: to ``used_labels`` when it
        is non-empty (in which case ``self.label_map`` is recomputed as
        well), otherwise to the full class domain.

        Returns:
            List[dict]: One dict per dsdl sample with keys ``img_path``,
            ``seg_map_path``, ``label_map``, ``reduce_zero_label`` and
            ``seg_fields``.
        """
        if self.used_labels:
            self._metainfo['classes'] = tuple(self.used_labels)
            self.label_map = self.get_label_map(self.used_labels)
        else:
            # NOTE(review): ``self.label_map`` is not assigned on this path;
            # presumably the base class has already set it -- confirm.
            self._metainfo['classes'] = tuple(self._class_domain())

        return [
            dict(
                img_path=os.path.join(self.data_prefix['img_path'],
                                      sample['Image'][0].location),
                seg_map_path=os.path.join(self.data_prefix['seg_map_path'],
                                          sample['LabelMap'][0].location),
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label,
                seg_fields=[],
            ) for sample in self.dsdldataset
        ]

    def get_label_map(self,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids
        and its values are the new label ids, and is used for changing
        pixel labels in load_annotations. ``label_map`` is not None if and
        only if ``new_classes`` is not None and differs from the classes
        in class_dom. Old classes that do not appear in ``new_classes``
        are mapped to 255.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.

        Returns:
            dict, optional: The mapping from old classes to new classes.

        Raises:
            ValueError: If ``new_classes`` is not a subset of the classes
                in class_dom.
        """
        old_classes = self._class_domain()
        if new_classes is None or list(new_classes) == list(old_classes):
            return None

        if not set(new_classes).issubset(old_classes):
            raise ValueError(
                f'new classes {new_classes} is not a '
                f'subset of classes {old_classes} in class_dom.')

        # First occurrence wins, matching ``Sequence.index`` semantics, while
        # avoiding a linear scan of ``new_classes`` for every old class.
        new_ids = {}
        for new_id, name in enumerate(new_classes):
            new_ids.setdefault(name, new_id)

        return {
            old_id: new_ids.get(name, 255)
            for old_id, name in enumerate(old_classes)
        }