File size: 9,145 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# Copyright (c) Open-CD. All rights reserved.
import copy
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union

import mmengine
import mmengine.fileio as fileio
import numpy as np

from mmseg.registry import DATASETS
from .basecddataset import _BaseCDDataset


@DATASETS.register_module()
class BaseSCDDataset(_BaseCDDataset):
    """Change-detection dataset that also carries per-date semantic labels.

    On top of what ``_BaseCDDataset`` provides, this class maintains a second
    set of label metadata (``semantic_classes`` / ``semantic_palette`` /
    ``semantic_label_map``) and, for every sample, records an image pair
    together with three annotation paths: ``seg_map_path``,
    ``seg_map_path_from`` and ``seg_map_path_to``.

    Args:
        lazy_init (bool): If True, ``full_init()`` is not called by the
            constructor. Defaults to False.
        reduce_semantic_zero_label (bool): Flag stored in the metainfo and
            copied into every data info under the same key. It is not
            interpreted in this class (presumably consumed by the annotation
            loading transform -- confirm against the pipeline).
            Defaults to False.
        **kwargs: Forwarded unchanged to ``_BaseCDDataset.__init__``.
    """

    def __init__(self,
                 lazy_init: bool = False,
                 reduce_semantic_zero_label: bool = False,
                 **kwargs) -> None:
        # The parent is always built lazily: the semantic label map and
        # palette set up below are read by ``load_data_list``, so they have
        # to exist before ``full_init()`` is triggered.
        super().__init__(lazy_init=True, **kwargs)

        self.reduce_semantic_zero_label = reduce_semantic_zero_label

        # Get label map for semantic custom classes
        new_classes = self._metainfo.get('semantic_classes', None)
        self.semantic_label_map = self.get_semantic_label_map(new_classes)
        self._metainfo.update(
            dict(
                semantic_label_map=self.semantic_label_map,
                reduce_semantic_zero_label=self.reduce_semantic_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_semantic_palette = self._update_semantic_palette()
        self._metainfo.update(dict(semantic_palette=updated_semantic_palette))

        if not lazy_init:
            self.full_init()

        if self.test_mode:
            assert self._metainfo.get('semantic_classes') is not None, \
                'dataset metainfo `semantic_classes` should be specified when testing'

    @classmethod
    def get_semantic_label_map(
            cls,
            new_classes: Optional[Sequence] = None) -> Union[Dict, None]:
        """Build the mapping from old to new semantic label ids.

        The returned ``label_map`` is a dictionary whose keys are the label
        ids defined by ``cls.METAINFO['semantic_classes']`` and whose values
        are the ids of the same class names in ``new_classes``. Classes that
        are dropped in ``new_classes`` are mapped to 255. A mapping is only
        produced when both the old and the new classes are defined and they
        differ; otherwise ``None`` is returned.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
                new classes in self._metainfo

        Raises:
            ValueError: If ``new_classes`` contains a name that is not part
                of ``cls.METAINFO['semantic_classes']``.
        """
        old_classes = cls.METAINFO.get('semantic_classes', None)
        if new_classes is None or old_classes is None:
            return None

        # Normalise once so membership tests and ``index`` work for any
        # sequence type handed in through the metainfo.
        new_names = list(new_classes)
        if new_names == list(old_classes):
            return None

        if not set(new_names).issubset(old_classes):
            raise ValueError(
                f'new semantic_classes {new_classes} is not a '
                f'subset of semantic_classes {old_classes} in METAINFO.')

        return {
            old_id: new_names.index(name) if name in new_names else 255
            for old_id, name in enumerate(old_classes)
        }

    def _update_semantic_palette(self) -> list:
        """Update the semantic palette after loading metainfo.

        If length of palette is equal to classes, just return the palette.
        If palette is not defined, a palette is generated from a fixed seed.
        If classes is updated by customer, it will return the subset of
        palette.

        Returns:
            Sequence: Palette for current dataset.

        Raises:
            ValueError: If a palette is given but can neither be used as is
                nor be reduced through ``self.semantic_label_map``.
        """
        palette = self._metainfo.get('semantic_palette', [])
        classes = self._metainfo.get('semantic_classes', [])

        # Palette already matches the classes.
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # A private generator seeded with 42 yields the same colours as
            # seeding the global NumPy generator would, but leaves the global
            # random state untouched, so nothing has to be saved/restored.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            rng = np.random.RandomState(42)
            # The upper bound of ``randint`` is exclusive: values are 0..254.
            return rng.randint(0, 255, size=(len(classes), 3)).tolist()

        if (len(palette) >= len(classes)
                and self.semantic_label_map is not None):
            # Keep only the colours of surviving classes, ordered by their
            # new label id, and hand them back in the palette's own type.
            kept = sorted(
                (item for item in self.semantic_label_map.items()
                 if item[1] != 255),
                key=lambda item: item[1])
            return type(palette)([palette[old_id] for old_id, _ in kept])

        raise ValueError('palette does not match classes '
                         f'as metainfo is {self._metainfo}.')

    def _make_scd_data_info(
            self,
            img_paths: List[str],
            seg_map_paths: Optional[Dict[str, str]] = None) -> dict:
        """Assemble the data info dict of one sample.

        Args:
            img_paths (list[str]): Paths of the "from" and "to" images, in
                that order.
            seg_map_paths (dict, optional): Annotation paths keyed by
                ``seg_map_path``, ``seg_map_path_from`` and
                ``seg_map_path_to``. Omitted from the result when None.

        Returns:
            dict: Data info holding the paths plus the label settings of
                this dataset.
        """
        data_info = dict(img_path=img_paths)
        if seg_map_paths is not None:
            data_info.update(seg_map_paths)
        data_info['label_map'] = self.label_map
        data_info['format_seg_map'] = self.format_seg_map
        data_info['reduce_zero_label'] = self.reduce_zero_label
        data_info['semantic_label_map'] = self.semantic_label_map
        data_info['reduce_semantic_zero_label'] = \
            self.reduce_semantic_zero_label
        data_info['seg_fields'] = []
        return data_info

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        When ``self.ann_file`` is an existing file, every non-blank line must
        hold five whitespace-separated names (without suffix): image "from",
        image "to", binary label, semantic "from" label and semantic "to"
        label. Otherwise the ``img_path_from`` directory is scanned
        recursively for files ending with ``self.img_suffix``; the
        ``img_path_to`` directory must contain exactly the same relative
        paths, and annotation names are derived by swapping the image suffix
        for ``self.seg_map_suffix``.

        Annotation paths are only recorded when the ``seg_map_path`` prefix
        is set; in that case the ``seg_map_path_from`` and ``seg_map_path_to``
        prefixes are joined as well and therefore have to be set too.

        Returns:
            list[dict]: All data info of dataset.

        Raises:
            ValueError: If a line of the annotation file does not consist of
                exactly five names.
        """
        img_dir_from = self.data_prefix.get('img_path_from', None)
        img_dir_to = self.data_prefix.get('img_path_to', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        ann_dir_from = self.data_prefix.get('seg_map_path_from', None)
        ann_dir_to = self.data_prefix.get('seg_map_path_to', None)

        data_list = []
        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line_no, line in enumerate(lines, start=1):
                # ``split()`` tolerates tabs and repeated spaces.
                data_names = line.split()
                if not data_names:
                    # Blank lines carry no sample.
                    continue
                if len(data_names) != 5:
                    raise ValueError(
                        f'line {line_no} of {self.ann_file} should contain '
                        f'5 names (img_from, img_to, label, label_from, '
                        f'label_to), but got {len(data_names)}: {line!r}')
                (img_name_from, img_name_to, ann_name, ann_name_from,
                 ann_name_to) = data_names

                img_paths = [
                    osp.join(img_dir_from, img_name_from + self.img_suffix),
                    osp.join(img_dir_to, img_name_to + self.img_suffix),
                ]
                seg_map_paths = None
                if ann_dir is not None:
                    seg_map_paths = dict(
                        seg_map_path=osp.join(
                            ann_dir, ann_name + self.seg_map_suffix),
                        seg_map_path_from=osp.join(
                            ann_dir_from,
                            ann_name_from + self.seg_map_suffix),
                        seg_map_path_to=osp.join(
                            ann_dir_to, ann_name_to + self.seg_map_suffix))
                data_list.append(
                    self._make_scd_data_info(img_paths, seg_map_paths))
            return data_list

        # Directory mode. The "from" listing is materialised once and reused
        # for both the pairing check and the iteration below, so the
        # (possibly remote) directory is not scanned twice.
        file_list_from = list(
            fileio.list_dir_or_file(
                dir_path=img_dir_from,
                list_dir=False,
                suffix=self.img_suffix,
                recursive=True,
                backend_args=self.backend_args))
        file_list_to = fileio.list_dir_or_file(
            dir_path=img_dir_to,
            list_dir=False,
            suffix=self.img_suffix,
            recursive=True,
            backend_args=self.backend_args)

        assert sorted(file_list_from) == sorted(file_list_to), \
            'The images in `img_path_from` and `img_path_to` are not ' \
            'one-to-one correspondence'

        suffix_len = len(self.img_suffix)
        for img in file_list_from:
            img_paths = [osp.join(img_dir_from, img),
                         osp.join(img_dir_to, img)]
            seg_map_paths = None
            if ann_dir is not None:
                # Swap only the trailing suffix: ``str.replace`` would also
                # rewrite matches inside directory names. Slicing from the
                # front keeps this correct for an empty suffix as well.
                seg_map = img[:len(img) - suffix_len] + self.seg_map_suffix
                seg_map_paths = dict(
                    seg_map_path=osp.join(ann_dir, seg_map),
                    seg_map_path_from=osp.join(ann_dir_from, seg_map),
                    seg_map_path_to=osp.join(ann_dir_to, seg_map))
            data_list.append(
                self._make_scd_data_info(img_paths, seg_map_paths))

        # Directory listings come in backend order; sort for determinism.
        data_list.sort(key=lambda info: info['img_path'])
        return data_list