Spaces:
Runtime error
Runtime error
# Copyright (c) Open-CD. All rights reserved. | |
import copy | |
import os.path as osp | |
from typing import Dict, List, Optional, Sequence, Union | |
import mmengine | |
import mmengine.fileio as fileio | |
import numpy as np | |
from mmseg.registry import DATASETS | |
from .basecddataset import _BaseCDDataset | |
class BaseSCDDataset(_BaseCDDataset):
    """Base dataset for semantic change detection (SCD).

    On top of the binary change map handled by ``_BaseCDDataset``
    (``seg_map_path``), every sample also references one semantic map per
    image of the pair (``seg_map_path_from`` / ``seg_map_path_to``). The
    semantic label space is configured through the ``semantic_classes`` and
    ``semantic_palette`` metainfo keys.

    Args:
        lazy_init (bool): When False (default), ``full_init`` is called at
            the end of ``__init__``; when True it is left to the caller.
        reduce_semantic_zero_label (bool): Stored on the dataset, written
            into the metainfo and copied into every data info dict.
            Presumably consumed by the annotation-loading transform the same
            way as ``reduce_zero_label`` -- confirm against the pipeline.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``_BaseCDDataset``.
    """

    def __init__(self,
                 lazy_init=False,
                 reduce_semantic_zero_label=False,
                 **kwargs):
        # The parent is always initialised lazily: building the data list
        # reads ``semantic_label_map``, which is only computed below.
        super().__init__(lazy_init=True, **kwargs)

        self.reduce_semantic_zero_label = reduce_semantic_zero_label

        # Get label map for semantic custom classes.
        new_classes = self._metainfo.get('semantic_classes', None)
        self.semantic_label_map = self.get_semantic_label_map(new_classes)
        self._metainfo.update(
            dict(
                semantic_label_map=self.semantic_label_map,
                reduce_semantic_zero_label=self.reduce_semantic_zero_label))

        # Update palette based on the label map, or generate one if it is
        # not defined.
        updated_semantic_palette = self._update_semantic_palette()
        self._metainfo.update(dict(semantic_palette=updated_semantic_palette))

        if not lazy_init:
            self.full_init()

        if self.test_mode:
            assert self._metainfo.get('semantic_classes') is not None, \
                'dataset metainfo `semantic_classes` should be specified when testing'

    @classmethod
    def get_semantic_label_map(cls,
                               new_classes: Optional[Sequence] = None
                               ) -> Union[Dict, None]:
        """Build the mapping from METAINFO semantic label ids to custom ids.

        The returned dictionary maps every old label id (an index into
        ``cls.METAINFO['semantic_classes']``) to its new label id (an index
        into ``new_classes``). Old classes absent from ``new_classes`` map to
        255 (presumably the ignore index). A mapping is returned if and only
        if neither the old nor the new semantic classes are None and the two
        lists differ.

        Args:
            new_classes (list, tuple, optional): The new semantic class names
                from metainfo. Defaults to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
            new classes, or None when no remapping is needed.

        Raises:
            ValueError: If ``new_classes`` is not a subset of the semantic
                classes declared in ``cls.METAINFO``.
        """
        old_classes = cls.METAINFO.get('semantic_classes', None)
        if (new_classes is None or old_classes is None
                or list(new_classes) == list(old_classes)):
            return None

        if not set(new_classes).issubset(old_classes):
            raise ValueError(
                f'new semantic_classes {new_classes} is not a '
                f'subset of semantic_classes {old_classes} in METAINFO.')

        # ``list.index`` returns the first occurrence, as before.
        new_classes = list(new_classes)
        label_map = {}
        for i, c in enumerate(old_classes):
            label_map[i] = new_classes.index(c) if c in new_classes else 255
        return label_map

    def _update_semantic_palette(self) -> list:
        """Return the semantic palette matching the current semantic classes.

        Cases, checked in order:

        * The palette has exactly one entry per class: return it unchanged.
        * No palette is defined: generate a random one with a fixed seed
          (42), restoring the global NumPy random state afterwards so the
          caller's randomness is not affected.
        * The classes were customised (``semantic_label_map`` is set) and the
          palette has at least as many entries as classes: return the
          entries of the kept classes, ordered by their new label id.

        Returns:
            Sequence: Palette for current dataset (same container type as
            the metainfo palette in the subset case).

        Raises:
            ValueError: If none of the cases above applies.
        """
        palette = self._metainfo.get('semantic_palette', [])
        classes = self._metainfo.get('semantic_classes', [])

        # Palette already matches classes.
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get the random state before setting the seed, and restore it
            # later. This prevents loss of randomness, since otherwise the
            # palette may differ in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif (len(palette) >= len(classes)
              and self.semantic_label_map is not None):
            new_palette = []
            # Return the subset of the palette for the kept classes.
            for old_id, new_id in sorted(
                    self.semantic_label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')

        return new_palette

    def _scd_list_images(self, img_dir: Optional[str]) -> List[str]:
        """List files under ``img_dir`` ending with ``img_suffix``, sorted.

        The listing is recursive; the returned names are joined onto
        ``img_dir`` by the caller.
        """
        return sorted(
            fileio.list_dir_or_file(
                dir_path=img_dir,
                list_dir=False,
                suffix=self.img_suffix,
                recursive=True,
                backend_args=self.backend_args))

    def _scd_seg_map_name(self, img: str) -> str:
        """Derive an annotation file name from an image file name."""
        suffix = self.img_suffix
        # Swap only the trailing suffix: ``str.replace`` would also rewrite
        # matching substrings elsewhere in the path (e.g. a dir "a.png/").
        if isinstance(suffix, str) and suffix and img.endswith(suffix):
            return img[:-len(suffix)] + self.seg_map_suffix
        return img.replace(suffix, self.seg_map_suffix)

    def _scd_make_data_info(self, img_paths: List[str],
                            ann_dirs: Sequence[Optional[str]],
                            seg_maps: Sequence[str]) -> dict:
        """Build the data info dict for one image pair.

        Args:
            img_paths (list[str]): Paths of the ``from`` and ``to`` images.
            ann_dirs (sequence): The ``seg_map_path``, ``seg_map_path_from``
                and ``seg_map_path_to`` directories from ``data_prefix``.
            seg_maps (sequence[str]): File names of the binary map and of
                the ``from`` / ``to`` semantic maps, joined onto the matching
                entry of ``ann_dirs``.

        Returns:
            dict: The data info for one sample.
        """
        ann_dir, ann_dir_from, ann_dir_to = ann_dirs
        seg_map, seg_map_from, seg_map_to = seg_maps
        data_info = dict(img_path=img_paths)
        # Annotation paths are attached only when the binary-map directory
        # is configured; the semantic directories are then required too.
        if ann_dir is not None:
            data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
            data_info['seg_map_path_from'] = osp.join(ann_dir_from,
                                                      seg_map_from)
            data_info['seg_map_path_to'] = osp.join(ann_dir_to, seg_map_to)
        data_info['label_map'] = self.label_map
        data_info['format_seg_map'] = self.format_seg_map
        data_info['reduce_zero_label'] = self.reduce_zero_label
        data_info['semantic_label_map'] = self.semantic_label_map
        data_info['reduce_semantic_zero_label'] = \
            self.reduce_semantic_zero_label
        data_info['seg_fields'] = []
        return data_info

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        If ``ann_file`` is an existing file, each non-blank line must hold
        five whitespace-separated names, without suffixes::

            img_from img_to ann ann_from ann_to

        Otherwise ``img_path_from`` is scanned recursively for files ending
        with ``img_suffix``. ``img_path_to`` must contain exactly the same
        relative file names. Annotation names are derived from the image
        names by swapping ``img_suffix`` for ``seg_map_suffix``.

        Returns:
            list[dict]: All data info of dataset, sorted by ``img_path``.

        Raises:
            ValueError: If a line of ``ann_file`` does not hold five names.
        """
        data_list = []
        img_dir_from = self.data_prefix.get('img_path_from', None)
        img_dir_to = self.data_prefix.get('img_path_to', None)
        ann_dirs = (self.data_prefix.get('seg_map_path', None),
                    self.data_prefix.get('seg_map_path_from', None),
                    self.data_prefix.get('seg_map_path_to', None))

        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                # ``split()`` tolerates repeated spaces/tabs and surrounding
                # whitespace, unlike ``strip().split(' ')``.
                data_names = line.split()
                if not data_names:
                    # Skip blank lines (e.g. an extra newline at EOF).
                    continue
                if len(data_names) != 5:
                    raise ValueError(
                        f'Each line of {self.ann_file} must contain 5 names '
                        '(img_from img_to ann ann_from ann_to), got '
                        f'{len(data_names)}: {line!r}')
                # img1, img2, binary label, semantic_from label,
                # semantic_to label
                img_name_from, img_name_to, *ann_names = data_names
                img_paths = [
                    osp.join(img_dir_from, img_name_from + self.img_suffix),
                    osp.join(img_dir_to, img_name_to + self.img_suffix),
                ]
                seg_maps = [name + self.seg_map_suffix for name in ann_names]
                data_list.append(
                    self._scd_make_data_info(img_paths, ann_dirs, seg_maps))
        else:
            # List each directory once and reuse the listing for both the
            # consistency check and the iteration.
            imgs_from = self._scd_list_images(img_dir_from)
            imgs_to = self._scd_list_images(img_dir_to)
            assert imgs_from == imgs_to, \
                'The images in `img_path_from` and `img_path_to` are not ' \
                'one-to-one correspondence'
            for img in imgs_from:
                img_paths = [
                    osp.join(img_dir_from, img),
                    osp.join(img_dir_to, img),
                ]
                seg_map = self._scd_seg_map_name(img)
                data_list.append(
                    self._scd_make_data_info(img_paths, ann_dirs,
                                             (seg_map, seg_map, seg_map)))

        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list