# TTP/opencd/models/decode_heads/multi_head.py
# KyanChen's picture
# Upload 1861 files
# 3b96cb1
# Copyright (c) Open-CD. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Tuple
from mmengine.model import BaseModule
from mmengine.structures import PixelData
from torch import Tensor, nn
# from mmseg.models import builder
from mmseg.models.utils import resize
from mmseg.structures import SegDataSample
from mmseg.utils import ConfigType, SampleList, add_prefix
from opencd.registry import MODELS
@MODELS.register_module()
class MultiHeadDecoder(BaseModule):
    """Base class for MultiHeadDecoder.

    Bundles one binary change detection head with an optional pair of
    semantic change detection heads (`from` and `to` branches) and combines
    their losses / predictions.

    Args:
        binary_cd_head (dict): The decode head for binary change detection
            branch.
        binary_cd_neck (dict): The feature fusion part for binary
            change detection branch. If None, ``self.binary_cd_neck`` is
            not created. Default: None
        semantic_cd_head (dict): The decode head for semantic change
            detection `from` branch. If None, no semantic head attribute is
            created. Default: None
        semantic_cd_head_aux (dict): The decode head for semantic change
            detection `to` branch. If None, the siamese semantic head will
            be used. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    # Keys every ``forward`` implementation must return, in this order.
    _SEG_LOGIT_KEYS = ('seg_logits', 'seg_logits_from', 'seg_logits_to')

    def __init__(self,
                 binary_cd_head,
                 binary_cd_neck=None,
                 semantic_cd_head=None,
                 semantic_cd_head_aux=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.binary_cd_head = MODELS.build(binary_cd_head)
        # True while the `to` branch shares the `from` branch's head.
        self.siamese_semantic_head = True
        if binary_cd_neck is not None:
            self.binary_cd_neck = MODELS.build(binary_cd_neck)
        if semantic_cd_head is not None:
            self.semantic_cd_head = MODELS.build(semantic_cd_head)
            if semantic_cd_head_aux is not None:
                self.siamese_semantic_head = False
                self.semantic_cd_head_aux = MODELS.build(semantic_cd_head_aux)
            else:
                # Siamese mode: both branches use the very same module.
                self.semantic_cd_head_aux = self.semantic_cd_head

    # NOTE(review): `@abstractmethod` only blocks instantiation when the
    # class uses an ABCMeta-derived metaclass; none is declared on this
    # class, so confirm whether the base class provides one.
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function.

        The return value should be a dict() containing:
        `seg_logits`, `seg_logits_from` and `seg_logits_to`.

        For example:
            return dict(
                seg_logits=out,
                seg_logits_from=out1,
                seg_logits_to=out2)
        """
        pass

    def _check_seg_logit_keys(self, seg_logits: dict) -> None:
        """Assert that ``seg_logits`` has exactly the expected keys, in order.

        Raises:
            AssertionError: If the keys differ from ``_SEG_LOGIT_KEYS``.
        """
        assert list(self._SEG_LOGIT_KEYS) == list(seg_logits.keys()), (
            "`seg_logits`, `seg_logits_from` "
            "and `seg_logits_to` should be contained.")

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config. Not used by this
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses

    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> List[Tensor]:
        """Forward function for testing.

        Runs ``forward`` and hands the result to ``predict_by_feat``.
        ``test_cfg`` is not used here; any extra ``kwargs`` are forwarded to
        ``predict_by_feat`` unchanged.
        """
        seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)

    def predict_by_feat(self, seg_logits: dict,
                        batch_img_metas: List[dict]) -> dict:
        """Transform a batch of output seg_logits to the input shape.

        Args:
            seg_logits (dict): The output from decode head forward function,
                keyed by `seg_logits`, `seg_logits_from` and `seg_logits_to`.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc. Only the `img_shape` of the
                first entry is used as the resize target.

        Returns:
            dict: The same ``seg_logits`` dict, with every entry replaced
            in place by its resized logits map.
        """
        self._check_seg_logit_keys(seg_logits)
        # NOTE(review): this requires both semantic heads to exist; a decoder
        # built without `semantic_cd_head` cannot use this method as written.
        self.align_corners = {
            'seg_logits': self.binary_cd_head.align_corners,
            'seg_logits_from': self.semantic_cd_head.align_corners,
            'seg_logits_to': self.semantic_cd_head_aux.align_corners}
        # Only existing keys are reassigned, so mutating while iterating
        # does not change the dict's size.
        for seg_name, seg_logit in seg_logits.items():
            seg_logits[seg_name] = resize(
                input=seg_logit,
                size=batch_img_metas[0]['img_shape'],
                mode='bilinear',
                align_corners=self.align_corners[seg_name])
        return seg_logits

    def get_sub_batch_data_samples(self, batch_data_samples: SampleList,
                                   sub_metainfo_name: str,
                                   sub_data_name: str) -> list:
        """Build per-branch data samples from the combined data samples.

        For every input sample a fresh ``SegDataSample`` is created whose
        ``gt_sem_seg`` wraps the data stored under ``sub_data_name`` and
        whose ``seg_map_path`` meta entry is taken from
        ``sub_metainfo_name``.

        Args:
            batch_data_samples (list[:obj:`SegDataSample`]): The combined
                data samples of the batch.
            sub_metainfo_name (str): Meta key whose value becomes the new
                sample's `seg_map_path`.
            sub_data_name (str): Data field whose ``.data`` becomes the new
                sample's `gt_sem_seg`.

        Returns:
            list: One new ``SegDataSample`` per input sample, in order.
        """
        sub_batch_sample_list = []
        for src_sample in batch_data_samples:
            src_metainfo = src_sample.metainfo
            data_sample = SegDataSample()
            data_sample.gt_sem_seg = PixelData(
                data=src_sample.get(sub_data_name).data)
            # Drop every `seg_map_path*` entry so that only the selected
            # branch's path survives under the canonical key.
            img_meta = {
                key: value for key, value in src_metainfo.items()
                if 'seg_map_path' not in key
            }
            img_meta['seg_map_path'] = src_metainfo.get(sub_metainfo_name)
            data_sample.set_metainfo(img_meta)
            sub_batch_sample_list.append(data_sample)
        return sub_batch_sample_list

    def loss_by_feat(self, seg_logits: dict,
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (dict): Output of ``forward`` keyed by `seg_logits`,
                `seg_logits_from` and `seg_logits_to`.
            batch_data_samples (list[:obj:`SegDataSample`]): The combined
                data samples of the batch.

        Returns:
            dict: Losses of the binary head prefixed with `binary_cd` and,
            when a semantic head was configured, the losses of both semantic
            branches prefixed with `semantic_cd_from` / `semantic_cd_to`.
        """
        self._check_seg_logit_keys(seg_logits)
        losses = dict()

        binary_cd_loss_decode = self.binary_cd_head.loss_by_feat(
            seg_logits['seg_logits'],
            self.get_sub_batch_data_samples(
                batch_data_samples,
                sub_metainfo_name='seg_map_path',
                sub_data_name='gt_sem_seg'))
        losses.update(add_prefix(binary_cd_loss_decode, 'binary_cd'))

        # The attribute is only created when a semantic head was configured,
        # so a default is required to avoid an AttributeError here.
        if getattr(self, 'semantic_cd_head', None) is not None:
            semantic_cd_loss_decode_from = self.semantic_cd_head.loss_by_feat(
                seg_logits['seg_logits_from'],
                self.get_sub_batch_data_samples(
                    batch_data_samples,
                    sub_metainfo_name='seg_map_path_from',
                    sub_data_name='gt_sem_seg_from'))
            losses.update(
                add_prefix(semantic_cd_loss_decode_from, 'semantic_cd_from'))

            semantic_cd_loss_decode_to = self.semantic_cd_head_aux.loss_by_feat(
                seg_logits['seg_logits_to'],
                self.get_sub_batch_data_samples(
                    batch_data_samples,
                    sub_metainfo_name='seg_map_path_to',
                    sub_data_name='gt_sem_seg_to'))
            losses.update(
                add_prefix(semantic_cd_loss_decode_to, 'semantic_cd_to'))

        return losses