Spaces:
Runtime error
Runtime error
File size: 7,279 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
# Copyright (c) Open-CD. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Tuple
from mmengine.model import BaseModule
from mmengine.structures import PixelData
from torch import Tensor, nn
# from mmseg.models import builder
from mmseg.models.utils import resize
from mmseg.structures import SegDataSample
from mmseg.utils import ConfigType, SampleList, add_prefix
from opencd.registry import MODELS
@MODELS.register_module()
class MultiHeadDecoder(BaseModule):
    """Base class for MultiHeadDecoder.

    Bundles one binary change-detection head with an optional pair of
    semantic change-detection heads (`from` / `to` branches). Subclasses
    implement :meth:`forward`; this class provides the loss and prediction
    plumbing around it.

    Args:
        binary_cd_head (dict): The decode head for binary change detection
            branch.
        binary_cd_neck (dict): The feature fusion part for binary change
            detection branch. Default: None
        semantic_cd_head (dict): The decode head for semantic change
            detection `from` branch. Default: None
        semantic_cd_head_aux (dict): The decode head for semantic change
            detection `to` branch. If None, the siamese semantic head will
            be used. Only consulted when `semantic_cd_head` is given.
            Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 binary_cd_head,
                 binary_cd_neck=None,
                 semantic_cd_head=None,
                 semantic_cd_head_aux=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.binary_cd_head = MODELS.build(binary_cd_head)
        # True while the `to` branch shares the `from` branch's head.
        self.siamese_semantic_head = True
        # Optional sub-modules are only set as attributes when configured,
        # so their absence must be probed with a defaulted getattr/hasattr.
        if binary_cd_neck is not None:
            self.binary_cd_neck = MODELS.build(binary_cd_neck)
        if semantic_cd_head is not None:
            self.semantic_cd_head = MODELS.build(semantic_cd_head)
            if semantic_cd_head_aux is not None:
                self.siamese_semantic_head = False
                self.semantic_cd_head_aux = MODELS.build(semantic_cd_head_aux)
            else:
                # Siamese case: both branches use the very same module.
                self.semantic_cd_head_aux = self.semantic_cd_head

    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function.

        The return value should be a dict() containing, in this order:
        `seg_logits`, `seg_logits_from` and `seg_logits_to`.

        For example:
            return dict(
                seg_logits=out,
                seg_logits_from=out1,
                seg_logits_to=out2)
        """
        pass

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config. Not used by this
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        return self.loss_by_feat(seg_logits, batch_data_samples)

    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> dict:
        """Forward function for testing.

        Args:
            inputs: Features passed straight to :meth:`forward`.
            batch_img_metas (list[dict]): Meta information of each image.
            test_cfg: The testing config. Not used by this implementation.
            **kwargs: Forwarded unchanged to :meth:`predict_by_feat`. The
                implementation in this class accepts no extra keyword
                arguments, so passing any raises ``TypeError`` unless a
                subclass overrides :meth:`predict_by_feat`.

        Returns:
            dict: The resized logits returned by :meth:`predict_by_feat`.
        """
        seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)

    def predict_by_feat(self, seg_logits: dict,
                        batch_img_metas: List[dict]) -> dict:
        """Transform a batch of output seg_logits to the input shape.

        Args:
            seg_logits (dict): The output from decode head forward function,
                keyed by `seg_logits`, `seg_logits_from` and
                `seg_logits_to` (in that order).
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc. Only the `img_shape` of the
                first entry is used as the resize target.

        Returns:
            dict: The same dict object, with every entry replaced by its
            resized logits map.
        """
        assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \
            == list(seg_logits.keys()), (
                "`seg_logits`, `seg_logits_from` "
                "and `seg_logits_to` should be contained.")

        # Kept as an instance attribute (not a local) because the original
        # implementation exposed it on `self`.
        self.align_corners = {
            'seg_logits': self.binary_cd_head.align_corners,
            'seg_logits_from': self.semantic_cd_head.align_corners,
            'seg_logits_to': self.semantic_cd_head_aux.align_corners}

        target_size = batch_img_metas[0]['img_shape']
        # Only existing keys are reassigned, so mutating the dict while
        # iterating over it is safe.
        for seg_name, seg_logit in seg_logits.items():
            seg_logits[seg_name] = resize(
                input=seg_logit,
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners[seg_name])
        return seg_logits

    def get_sub_batch_data_samples(self, batch_data_samples: SampleList,
                                   sub_metainfo_name: str,
                                   sub_data_name: str) -> list:
        """Build per-branch data samples from the combined data samples.

        For every input sample a fresh :obj:`SegDataSample` is created whose
        `gt_sem_seg` holds the data stored under `sub_data_name`, and whose
        metainfo is a copy of the input metainfo with every key containing
        `seg_map_path` dropped and `seg_map_path` set to the value stored
        under `sub_metainfo_name`.

        Args:
            batch_data_samples (list[:obj:`SegDataSample`]): The combined
                data samples.
            sub_metainfo_name (str): Metainfo key whose value becomes the
                new sample's `seg_map_path`.
            sub_data_name (str): Name of the data field copied into the new
                sample's `gt_sem_seg`.

        Returns:
            list[:obj:`SegDataSample`]: One new sample per input sample.
        """
        sub_batch_sample_list = []
        for src_sample in batch_data_samples:
            data_sample = SegDataSample()
            data_sample.gt_sem_seg = PixelData(
                data=src_sample.get(sub_data_name).data)

            # Read the metainfo once instead of once per key.
            src_metainfo = src_sample.metainfo
            img_meta = {
                key: value for key, value in src_metainfo.items()
                if 'seg_map_path' not in key
            }
            img_meta['seg_map_path'] = src_metainfo.get(sub_metainfo_name)
            data_sample.set_metainfo(img_meta)

            sub_batch_sample_list.append(data_sample)
        return sub_batch_sample_list

    def loss_by_feat(self, seg_logits: dict,
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (dict): The output of :meth:`forward`, keyed by
                `seg_logits`, `seg_logits_from` and `seg_logits_to`
                (in that order).
            batch_data_samples (list[:obj:`SegDataSample`]): The combined
                data samples holding the ground truth of every branch.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            dict: Loss components prefixed with `binary_cd`, and, when a
            semantic head is configured, `semantic_cd_from` and
            `semantic_cd_to`.
        """
        assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \
            == list(seg_logits.keys()), (
                "`seg_logits`, `seg_logits_from` "
                "and `seg_logits_to` should be contained.")

        losses = dict()

        binary_cd_loss_decode = self.binary_cd_head.loss_by_feat(
            seg_logits['seg_logits'],
            self.get_sub_batch_data_samples(
                batch_data_samples,
                sub_metainfo_name='seg_map_path',
                sub_data_name='gt_sem_seg'))
        losses.update(add_prefix(binary_cd_loss_decode, 'binary_cd'))

        # `semantic_cd_head` is never assigned when it was not configured,
        # so a defaulted lookup is required; a bare getattr would raise
        # AttributeError instead of skipping the semantic branch.
        if getattr(self, 'semantic_cd_head', None) is not None:
            semantic_cd_loss_decode_from = self.semantic_cd_head.loss_by_feat(
                seg_logits['seg_logits_from'],
                self.get_sub_batch_data_samples(
                    batch_data_samples,
                    sub_metainfo_name='seg_map_path_from',
                    sub_data_name='gt_sem_seg_from'))
            losses.update(
                add_prefix(semantic_cd_loss_decode_from, 'semantic_cd_from'))

            semantic_cd_loss_decode_to = self.semantic_cd_head_aux.loss_by_feat(
                seg_logits['seg_logits_to'],
                self.get_sub_batch_data_samples(
                    batch_data_samples,
                    sub_metainfo_name='seg_map_path_to',
                    sub_data_name='gt_sem_seg_to'))
            losses.update(
                add_prefix(semantic_cd_loss_decode_to, 'semantic_cd_to'))

        return losses