File size: 7,279 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Copyright (c) Open-CD. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Tuple

from mmengine.model import BaseModule
from mmengine.structures import PixelData
from torch import Tensor, nn

# from mmseg.models import builder
from mmseg.models.utils import resize
from mmseg.structures import SegDataSample
from mmseg.utils import ConfigType, SampleList, add_prefix
from opencd.registry import MODELS


@MODELS.register_module()
class MultiHeadDecoder(BaseModule):
    """Base class for MultiHeadDecoder.

    Args:
        binary_cd_head (dict): The decode head for binary change detection branch.
        binary_cd_neck (dict): The feature fusion part for binary \
            change detection branch
        semantic_cd_head (dict): The decode head for semantic change \
            detection `from` branch.
        semantic_cd_head_aux (dict): The decode head for semantic change \
            detection `to` branch. If None, the siamese semantic head will \
            be used. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 binary_cd_head,
                 binary_cd_neck=None,
                 semantic_cd_head=None,
                 semantic_cd_head_aux=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.binary_cd_head = MODELS.build(binary_cd_head)
        self.siamese_semantic_head = True
        if binary_cd_neck is not None:
            self.binary_cd_neck = MODELS.build(binary_cd_neck)
        if semantic_cd_head is not None:
            self.semantic_cd_head = MODELS.build(semantic_cd_head)
            if semantic_cd_head_aux is not None:
                self.siamese_semantic_head = False
                self.semantic_cd_head_aux = MODELS.build(semantic_cd_head_aux)
            else:
                self.semantic_cd_head_aux = self.semantic_cd_head

    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function.
        The return value should be a dict() containing: 
        `seg_logits`, `seg_logits_from` and `seg_logits_to`.
        
        For example:
            return dict(
                seg_logits=out,
                seg_logits_from=out1, 
                seg_logits_to=out2)
        """
        pass

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses
    
    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> List[Tensor]:
        """Forward function for testing."""
        seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
        
    def predict_by_feat(self, seg_logits: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Transform a batch of output seg_logits to the input shape.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \
            == list(seg_logits.keys()), "`seg_logits`, `seg_logits_from` \
            and `seg_logits_to` should be contained."

        self.align_corners = {
            'seg_logits': self.binary_cd_head.align_corners,
            'seg_logits_from': self.semantic_cd_head.align_corners,
            'seg_logits_to': self.semantic_cd_head_aux.align_corners}

        for seg_name, seg_logit in seg_logits.items():
            seg_logits[seg_name] = resize(
                input=seg_logit,
                size=batch_img_metas[0]['img_shape'],
                mode='bilinear',
                align_corners=self.align_corners[seg_name])
        return seg_logits
    
    def get_sub_batch_data_samples(self, batch_data_samples: SampleList, 
                                   sub_metainfo_name: str,
                                   sub_data_name: str) -> list:
        sub_batch_sample_list = []
        for i in range(len(batch_data_samples)):
            data_sample = SegDataSample()

            gt_sem_seg_data = dict(
                data=batch_data_samples[i].get(sub_data_name).data)
            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)

            img_meta = {}
            seg_map_path = batch_data_samples[i].metainfo.get(sub_metainfo_name)
            for key in batch_data_samples[i].metainfo.keys():
                if not 'seg_map_path' in key:
                    img_meta[key] = batch_data_samples[i].metainfo.get(key)
            img_meta['seg_map_path'] = seg_map_path
            data_sample.set_metainfo(img_meta)

            sub_batch_sample_list.append(data_sample)
        return sub_batch_sample_list

    def loss_by_feat(self, seg_logits: dict,
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute segmentation loss."""
        assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \
            == list(seg_logits.keys()), "`seg_logits`, `seg_logits_from` \
            and `seg_logits_to` should be contained."

        losses = dict()
        binary_cd_loss_decode = self.binary_cd_head.loss_by_feat(
            seg_logits['seg_logits'],
            self.get_sub_batch_data_samples(batch_data_samples,
                                            sub_metainfo_name='seg_map_path',
                                            sub_data_name='gt_sem_seg'))
        losses.update(add_prefix(binary_cd_loss_decode, 'binary_cd'))

        if getattr(self, 'semantic_cd_head'):
            semantic_cd_loss_decode_from = self.semantic_cd_head.loss_by_feat(
                seg_logits['seg_logits_from'],
                self.get_sub_batch_data_samples(batch_data_samples,
                                                sub_metainfo_name='seg_map_path_from',
                                                sub_data_name='gt_sem_seg_from'))
            losses.update(add_prefix(semantic_cd_loss_decode_from, 'semantic_cd_from'))

            semantic_cd_loss_decode_to = self.semantic_cd_head_aux.loss_by_feat(
                seg_logits['seg_logits_to'],
                self.get_sub_batch_data_samples(batch_data_samples,
                                                sub_metainfo_name='seg_map_path_to',
                                                sub_data_name='gt_sem_seg_to'))
            losses.update(add_prefix(semantic_cd_loss_decode_to, 'semantic_cd_to'))

        return losses