# Source: TTP/opencd/datasets/basescddataset.py
# Uploaded by KyanChen ("Upload 1861 files", commit 3b96cb1).
# Copyright (c) Open-CD. All rights reserved.
import copy
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union
import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmseg.registry import DATASETS
from .basecddataset import _BaseCDDataset
@DATASETS.register_module()
class BaseSCDDataset(_BaseCDDataset):
    """Base class for semantic change detection (SCD) datasets.

    On top of what ``_BaseCDDataset`` provides (a binary change label in
    ``seg_map_path``), every sample also references two semantic label maps,
    ``seg_map_path_from`` and ``seg_map_path_to``. The semantic label space
    is described by the ``semantic_classes`` / ``semantic_palette`` metainfo
    keys and may be remapped to a subset of ``METAINFO['semantic_classes']``.

    Args:
        lazy_init (bool): Whether to skip ``full_init()`` at the end of
            ``__init__``. Defaults to False.
        reduce_semantic_zero_label (bool): Flag stored in metainfo and copied
            into every data info as ``reduce_semantic_zero_label``; it is
            presumably consumed by the annotation-loading transform
            (TODO confirm). Defaults to False.
        **kwargs: Forwarded to ``_BaseCDDataset``.
    """

    def __init__(self,
                 lazy_init=False,
                 reduce_semantic_zero_label=False,
                 **kwargs):
        # The parent is always initialised lazily: ``load_data_list`` reads
        # ``self.semantic_label_map``, so the semantic metainfo below must be
        # in place before ``full_init()`` builds the data list.
        super().__init__(lazy_init=True, **kwargs)
        self.reduce_semantic_zero_label = reduce_semantic_zero_label

        # Label map for custom (subset / reordered) semantic classes.
        new_classes = self._metainfo.get('semantic_classes', None)
        self.semantic_label_map = self.get_semantic_label_map(new_classes)
        self._metainfo.update(
            dict(
                semantic_label_map=self.semantic_label_map,
                reduce_semantic_zero_label=self.reduce_semantic_zero_label))

        # Subset the palette according to the label map, or generate one
        # if it is not defined.
        updated_semantic_palette = self._update_semantic_palette()
        self._metainfo.update(dict(semantic_palette=updated_semantic_palette))

        if not lazy_init:
            self.full_init()

        if self.test_mode:
            assert self._metainfo.get('semantic_classes') is not None, \
                'dataset metainfo `semantic_classes` should be specified when testing'

    @classmethod
    def get_semantic_label_map(cls,
                               new_classes: Optional[Sequence] = None
                               ) -> Union[Dict, None]:
        """Require semantic label mapping.

        The semantic label map is a dictionary whose keys are the old label
        ids (indices into ``cls.METAINFO['semantic_classes']``) and whose
        values are the new label ids (indices into ``new_classes``). Old
        classes that are absent from ``new_classes`` are mapped to 255. The
        map is not None if and only if neither the old nor the new classes
        are None and the two class lists differ.

        Args:
            new_classes (list, tuple, optional): The new semantic class names
                from metainfo. Defaults to None.

        Returns:
            dict, optional: Mapping from old to new label ids, or None when
            no remapping is needed.

        Raises:
            ValueError: If ``new_classes`` is not a subset of
                ``cls.METAINFO['semantic_classes']``.
        """
        old_classes = cls.METAINFO.get('semantic_classes', None)
        if (new_classes is None or old_classes is None
                or list(new_classes) == list(old_classes)):
            return None

        if not set(new_classes).issubset(old_classes):
            raise ValueError(
                f'new semantic_classes {new_classes} is not a '
                f'subset of semantic_classes {old_classes} in METAINFO.')

        # Dropped classes get 255 (presumably the ignore index downstream —
        # TODO confirm); kept classes get their position in ``new_classes``.
        return {
            old_id: new_classes.index(name) if name in new_classes else 255
            for old_id, name in enumerate(old_classes)
        }

    def _update_semantic_palette(self) -> list:
        """Update the semantic palette after loading metainfo.

        - If the palette length equals the number of semantic classes, the
          palette is returned unchanged.
        - If no palette is defined (missing, None or empty), a random one is
          generated with a fixed seed (42) while leaving the global NumPy
          random state as it was.
        - If the classes were customised (``semantic_label_map`` is set), the
          palette entries of the kept classes are returned, ordered by their
          new label id.

        Returns:
            Sequence: Palette for the current dataset.

        Raises:
            ValueError: If the palette cannot be matched to the classes.
        """
        # An explicit ``None`` in metainfo is treated like a missing key
        # instead of crashing on ``len(None)``.
        palette = self._metainfo.get('semantic_palette', None)
        if palette is None:
            palette = []
        classes = self._metainfo.get('semantic_classes', None)
        if classes is None:
            classes = []

        # The palette already matches the classes.
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Save the random state before seeding and restore it afterwards,
            # so seeding here does not remove randomness elsewhere while the
            # generated palette stays the same on every call.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
            return new_palette

        if (len(palette) >= len(classes)
                and self.semantic_label_map is not None):
            # Return the subset of the palette for the retained classes,
            # ordered by their new label id.
            new_palette = [
                palette[old_id]
                for old_id, new_id in sorted(
                    self.semantic_label_map.items(), key=lambda x: x[1])
                if new_id != 255
            ]
            return type(palette)(new_palette)

        raise ValueError('palette does not match classes '
                         f'as metainfo is {self._metainfo}.')

    def _attach_label_meta(self, data_info: dict) -> dict:
        """Add the label-processing fields shared by every data info."""
        data_info['label_map'] = self.label_map
        data_info['format_seg_map'] = self.format_seg_map
        data_info['reduce_zero_label'] = self.reduce_zero_label
        data_info['semantic_label_map'] = self.semantic_label_map
        data_info['reduce_semantic_zero_label'] = \
            self.reduce_semantic_zero_label
        data_info['seg_fields'] = []
        return data_info

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        If ``self.ann_file`` is a file, each non-blank line must hold five
        whitespace-separated names (without suffix): image ``from``, image
        ``to``, binary change label, semantic ``from`` label and semantic
        ``to`` label. Otherwise ``img_path_from`` is scanned recursively for
        files ending in ``img_suffix``; ``img_path_to`` must contain the same
        relative file names, and the result is sorted by image path.

        Returns:
            list[dict]: All data info of dataset.

        Raises:
            ValueError: If an annotation-file line does not have five names.
        """
        data_list = []
        img_dir_from = self.data_prefix.get('img_path_from', None)
        img_dir_to = self.data_prefix.get('img_path_to', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        ann_dir_from = self.data_prefix.get('seg_map_path_from', None)
        ann_dir_to = self.data_prefix.get('seg_map_path_to', None)

        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                # ``split()`` tolerates repeated spaces/tabs; blank lines
                # (e.g. trailing empty lines) are skipped.
                data_names = line.split()
                if not data_names:
                    continue
                if len(data_names) != 5:
                    raise ValueError(
                        f'Each line of {self.ann_file} must contain 5 '
                        'names (img_from, img_to, binary label, '
                        'semantic_from label, semantic_to label), '
                        f'got {len(data_names)}: {line!r}')
                (img_name_from, img_name_to, ann_name, ann_name_from,
                 ann_name_to) = data_names
                data_info = dict(img_path=[
                    osp.join(img_dir_from, img_name_from + self.img_suffix),
                    osp.join(img_dir_to, img_name_to + self.img_suffix),
                ])
                if ann_dir is not None:
                    data_info['seg_map_path'] = osp.join(
                        ann_dir, ann_name + self.seg_map_suffix)
                    data_info['seg_map_path_from'] = osp.join(
                        ann_dir_from, ann_name_from + self.seg_map_suffix)
                    data_info['seg_map_path_to'] = osp.join(
                        ann_dir_to, ann_name_to + self.seg_map_suffix)
                data_list.append(self._attach_label_meta(data_info))
        else:
            # List each directory once and reuse the listing for both the
            # correspondence check and the iteration below.
            img_files_from = sorted(
                fileio.list_dir_or_file(
                    dir_path=img_dir_from,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args))
            img_files_to = sorted(
                fileio.list_dir_or_file(
                    dir_path=img_dir_to,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args))
            assert img_files_from == img_files_to, \
                'The images in `img_path_from` and `img_path_to` are not ' \
                'one-to-one correspondence'
            for img in img_files_from:
                data_info = dict(img_path=[
                    osp.join(img_dir_from, img),
                    osp.join(img_dir_to, img),
                ])
                if ann_dir is not None:
                    seg_map = img.replace(self.img_suffix,
                                          self.seg_map_suffix)
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                    data_info['seg_map_path_from'] = osp.join(
                        ann_dir_from, seg_map)
                    data_info['seg_map_path_to'] = osp.join(
                        ann_dir_to, seg_map)
                data_list.append(self._attach_label_meta(data_info))
            data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list