# Copyright (c) Open-CD. All rights reserved. from abc import ABCMeta, abstractmethod from typing import List, Tuple from mmengine.model import BaseModule from mmengine.structures import PixelData from torch import Tensor, nn # from mmseg.models import builder from mmseg.models.utils import resize from mmseg.structures import SegDataSample from mmseg.utils import ConfigType, SampleList, add_prefix from opencd.registry import MODELS @MODELS.register_module() class MultiHeadDecoder(BaseModule): """Base class for MultiHeadDecoder. Args: binary_cd_head (dict): The decode head for binary change detection branch. binary_cd_neck (dict): The feature fusion part for binary \ change detection branch semantic_cd_head (dict): The decode head for semantic change \ detection `from` branch. semantic_cd_head_aux (dict): The decode head for semantic change \ detection `to` branch. If None, the siamese semantic head will \ be used. Default: None init_cfg (dict or list[dict], optional): Initialization config dict. """ def __init__(self, binary_cd_head, binary_cd_neck=None, semantic_cd_head=None, semantic_cd_head_aux=None, init_cfg=None): super().__init__(init_cfg) self.binary_cd_head = MODELS.build(binary_cd_head) self.siamese_semantic_head = True if binary_cd_neck is not None: self.binary_cd_neck = MODELS.build(binary_cd_neck) if semantic_cd_head is not None: self.semantic_cd_head = MODELS.build(semantic_cd_head) if semantic_cd_head_aux is not None: self.siamese_semantic_head = False self.semantic_cd_head_aux = MODELS.build(semantic_cd_head_aux) else: self.semantic_cd_head_aux = self.semantic_cd_head @abstractmethod def forward(self, inputs): """Placeholder of forward function. The return value should be a dict() containing: `seg_logits`, `seg_logits_from` and `seg_logits_to`. For example: return dict( seg_logits=out, seg_logits_from=out1, seg_logits_to=out2) """ pass def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList, train_cfg: ConfigType) -> dict: """Forward function for training. Args: inputs (Tuple[Tensor]): List of multi-level img features. batch_data_samples (list[:obj:`SegDataSample`]): The seg data samples. It usually includes information such as `img_metas` or `gt_semantic_seg`. train_cfg (dict): The training config. Returns: dict[str, Tensor]: a dictionary of loss components """ seg_logits = self.forward(inputs) losses = self.loss_by_feat(seg_logits, batch_data_samples) return losses def predict(self, inputs, batch_img_metas: List[dict], test_cfg, **kwargs) -> List[Tensor]: """Forward function for testing.""" seg_logits = self.forward(inputs) return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) def predict_by_feat(self, seg_logits: Tensor, batch_img_metas: List[dict]) -> Tensor: """Transform a batch of output seg_logits to the input shape. Args: seg_logits (Tensor): The output from decode head forward function. batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. Returns: Tensor: Outputs segmentation logits map. """ assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \ == list(seg_logits.keys()), "`seg_logits`, `seg_logits_from` \ and `seg_logits_to` should be contained." self.align_corners = { 'seg_logits': self.binary_cd_head.align_corners, 'seg_logits_from': self.semantic_cd_head.align_corners, 'seg_logits_to': self.semantic_cd_head_aux.align_corners} for seg_name, seg_logit in seg_logits.items(): seg_logits[seg_name] = resize( input=seg_logit, size=batch_img_metas[0]['img_shape'], mode='bilinear', align_corners=self.align_corners[seg_name]) return seg_logits def get_sub_batch_data_samples(self, batch_data_samples: SampleList, sub_metainfo_name: str, sub_data_name: str) -> list: sub_batch_sample_list = [] for i in range(len(batch_data_samples)): data_sample = SegDataSample() gt_sem_seg_data = dict( data=batch_data_samples[i].get(sub_data_name).data) data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) img_meta = {} seg_map_path = batch_data_samples[i].metainfo.get(sub_metainfo_name) for key in batch_data_samples[i].metainfo.keys(): if not 'seg_map_path' in key: img_meta[key] = batch_data_samples[i].metainfo.get(key) img_meta['seg_map_path'] = seg_map_path data_sample.set_metainfo(img_meta) sub_batch_sample_list.append(data_sample) return sub_batch_sample_list def loss_by_feat(self, seg_logits: dict, batch_data_samples: SampleList, **kwargs) -> dict: """Compute segmentation loss.""" assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \ == list(seg_logits.keys()), "`seg_logits`, `seg_logits_from` \ and `seg_logits_to` should be contained." losses = dict() binary_cd_loss_decode = self.binary_cd_head.loss_by_feat( seg_logits['seg_logits'], self.get_sub_batch_data_samples(batch_data_samples, sub_metainfo_name='seg_map_path', sub_data_name='gt_sem_seg')) losses.update(add_prefix(binary_cd_loss_decode, 'binary_cd')) if getattr(self, 'semantic_cd_head'): semantic_cd_loss_decode_from = self.semantic_cd_head.loss_by_feat( seg_logits['seg_logits_from'], self.get_sub_batch_data_samples(batch_data_samples, sub_metainfo_name='seg_map_path_from', sub_data_name='gt_sem_seg_from')) losses.update(add_prefix(semantic_cd_loss_decode_from, 'semantic_cd_from')) semantic_cd_loss_decode_to = self.semantic_cd_head_aux.loss_by_feat( seg_logits['seg_logits_to'], self.get_sub_batch_data_samples(batch_data_samples, sub_metainfo_name='seg_map_path_to', sub_data_name='gt_sem_seg_to')) losses.update(add_prefix(semantic_cd_loss_decode_to, 'semantic_cd_to')) return losses