TTP / mmseg /datasets /dsdl.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List, Optional, Sequence, Union
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
try:
from dsdl.dataset import DSDLDataset
except ImportError:
DSDLDataset = None
@DATASETS.register_module()
class DSDLSegDataset(BaseSegDataset):
    """Dataset for dsdl segmentation.

    Args:
        specific_key_path (dict, optional): Path of specific key which can
            not be loaded by its field name. ``None`` (the default) is
            forwarded to ``DSDLDataset`` as a fresh empty dict.
        pre_transform (dict, optional): Pre-transform functions before
            loading. ``None`` (the default) is forwarded to
            ``DSDLDataset`` as a fresh empty dict.
        used_labels (sequence, optional): Actual used classes in train
            steps; this must be a subset of the class domain. When empty
            or ``None``, every class of the dsdl dataset (plus
            ``'background'``) is used.
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            ``BaseSegDataset.__init__``. ``ann_file`` must be present: it
            is the dsdl yaml file handed to ``DSDLDataset``.

    Raises:
        RuntimeError: If the ``dsdl`` package is not installed.
    """

    METAINFO = {}

    def __init__(self,
                 specific_key_path: Optional[Dict] = None,
                 pre_transform: Optional[Dict] = None,
                 used_labels: Optional[Sequence] = None,
                 **kwargs) -> None:
        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )

        # Build fresh dicts per instance instead of sharing one mutable
        # default object across every dataset that relies on the default.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}

        self.used_labels = used_labels

        loc_config = dict(type='LocalFileReader', working_dir='')
        if kwargs.get('data_root'):
            # The dsdl reader needs the yaml path resolved against
            # ``data_root``; the joined path is also what gets forwarded to
            # the base class below.
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])
        required_fields = ['Image', 'LabelMap']

        # NOTE(review): the dsdl dataset is created before the base
        # initializer runs, presumably because the base initializer may
        # already call ``get_label_map`` / ``load_data_list`` -- confirm
        # against ``BaseSegDataset``.
        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )
        BaseSegDataset.__init__(self, **kwargs)

    def load_data_list(self) -> List[Dict]:
        """Load data info from a dsdl yaml file named as ``self.ann_file``.

        As a side effect, ``self._metainfo['classes']`` is set to the used
        labels (and ``self.label_map`` is recomputed for them) or, when no
        used labels were given, to ``'background'`` followed by the class
        names of the dsdl dataset.

        Returns:
            List[dict]: One dict per sample holding ``img_path``,
            ``seg_map_path``, ``label_map``, ``reduce_zero_label`` and an
            empty ``seg_fields`` list.
        """
        if self.used_labels:
            self._metainfo['classes'] = tuple(self.used_labels)
            self.label_map = self.get_label_map(self.used_labels)
        else:
            self._metainfo['classes'] = tuple(['background'] +
                                              self.dsdldataset.class_names)

        # Loop-invariant lookups, hoisted out of the per-sample loop.
        img_prefix = self.data_prefix['img_path']
        seg_prefix = self.data_prefix['seg_map_path']

        data_list = []
        for data in self.dsdldataset:
            # Only the first Image / LabelMap entry of a sample is used.
            data_list.append(
                dict(
                    img_path=os.path.join(img_prefix,
                                          data['Image'][0].location),
                    seg_map_path=os.path.join(seg_prefix,
                                              data['LabelMap'][0].location),
                    label_map=self.label_map,
                    reduce_zero_label=self.reduce_zero_label,
                    seg_fields=[],
                ))
        return data_list

    def get_label_map(self,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids and
        its values are the new label ids, and is used for changing pixel
        labels in load_annotations. ``label_map`` is not None if and only if
        ``new_classes`` is not None and differs from the old classes
        (``'background'`` followed by the dsdl class names).

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.

        Returns:
            dict, optional: The mapping from old classes to new classes.
            Old classes missing from ``new_classes`` are mapped to 255.

        Raises:
            ValueError: If ``new_classes`` is not a subset of the old
                classes.
        """
        if new_classes is None:
            return None

        old_classes = ['background'] + self.dsdldataset.class_names
        new_names = list(new_classes)
        if new_names == list(old_classes):
            # Identical class lists need no remapping.
            return None

        if not set(new_names).issubset(old_classes):
            raise ValueError(
                f'new classes {new_classes} is not a '
                f'subset of classes {old_classes} in class_dom.')

        # Name -> new id, built once. ``setdefault`` keeps the first
        # occurrence of a duplicated name, matching ``list.index``.
        new_ids = {}
        for new_id, name in enumerate(new_names):
            new_ids.setdefault(name, new_id)

        # Old classes that are not kept are mapped to 255.
        return {
            old_id: new_ids.get(name, 255)
            for old_id, name in enumerate(old_classes)
        }