from torch import nn

from mmseg.core import add_prefix
from mmseg.ops import resize
from .. import builder
from ..builder import SEGMENTORS
from .encoder_decoder import EncoderDecoder


@SEGMENTORS.register_module()
class CascadeEncoderDecoder(EncoderDecoder):
    """Cascade Encoder Decoder segmentors.

    CascadeEncoderDecoder is almost the same as EncoderDecoder, except that
    its decode heads are cascaded: the output of the previous decode head is
    passed as an additional input to the next decode head.

    Args:
        num_stages (int): Number of cascaded decode heads.
        backbone: Backbone config, forwarded to ``EncoderDecoder``.
        decode_head (list): One head config per stage; its length must equal
            ``num_stages``.
        neck: Forwarded to ``EncoderDecoder``. Defaults to None.
        auxiliary_head: Forwarded to ``EncoderDecoder``. Defaults to None.
        train_cfg: Forwarded to ``EncoderDecoder`` and passed to the decode
            heads during training. Defaults to None.
        test_cfg: Forwarded to ``EncoderDecoder`` and passed to the decode
            heads when calling ``forward_test``. Defaults to None.
        pretrained: Forwarded to ``EncoderDecoder``. Defaults to None.
    """

    def __init__(self,
                 num_stages,
                 backbone,
                 decode_head,
                 neck=None,
                 auxiliary_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # ``num_stages`` is stored before the parent constructor runs.
        # NOTE(review): presumably the parent ``__init__`` invokes
        # ``_init_decode_head``, which reads ``self.num_stages`` - confirm
        # against ``EncoderDecoder`` before reordering these statements.
        self.num_stages = num_stages
        super(CascadeEncoderDecoder, self).__init__(
            backbone=backbone,
            decode_head=decode_head,
            neck=neck,
            auxiliary_head=auxiliary_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

    def _init_decode_head(self, decode_head):
        """Initialize ``decode_head``.

        Builds one head per stage into an ``nn.ModuleList`` and takes
        ``align_corners`` and ``num_classes`` from the last head.

        Args:
            decode_head (list): Head configs, one per stage.
        """
        assert isinstance(decode_head, list), \
            'decode_head must be a list of head configs'
        assert len(decode_head) == self.num_stages, \
            'len(decode_head) must be equal to num_stages'
        self.decode_head = nn.ModuleList()
        for i in range(self.num_stages):
            self.decode_head.append(builder.build_head(decode_head[i]))
        # The last stage produces the final prediction, so its settings are
        # the ones exposed on the segmentor.
        self.align_corners = self.decode_head[-1].align_corners
        self.num_classes = self.decode_head[-1].num_classes

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone and heads.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        self.backbone.init_weights(pretrained=pretrained)
        for i in range(self.num_stages):
            self.decode_head[i].init_weights()
        if self.with_auxiliary_head:
            # The auxiliary head may be a single module or a list of them.
            if isinstance(self.auxiliary_head, nn.ModuleList):
                for aux_head in self.auxiliary_head:
                    aux_head.init_weights()
            else:
                self.auxiliary_head.init_weights()

    def encode_decode(self, img, img_metas):
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input.

        The first head is called without a previous output; every later head
        also receives the output of the head before it. The final output is
        bilinearly resized to ``img.shape[2:]``.
        """
        x = self.extract_feat(img)
        out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg)
        for i in range(1, self.num_stages):
            out = self.decode_head[i].forward_test(x, out, img_metas,
                                                   self.test_cfg)
        out = resize(
            input=out,
            size=img.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        return out

    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
        """Run forward function and calculate loss for decode head in
        training.

        Returns:
            dict: Losses of every stage; the keys of stage ``i`` are
                prefixed with ``decode_{i}``.
        """
        losses = dict()

        loss_decode = self.decode_head[0].forward_train(
            x, img_metas, gt_semantic_seg, self.train_cfg)
        losses.update(add_prefix(loss_decode, 'decode_0'))

        prev_outputs = None
        for i in range(1, self.num_stages):
            # forward test again, maybe unnecessary for most methods.
            if i == 1:
                # Head 0 is the only head that takes no previous output.
                prev_outputs = self.decode_head[0].forward_test(
                    x, img_metas, self.test_cfg)
            else:
                # Heads after the first are cascaded: feed them the output
                # of the stage before, exactly as ``encode_decode`` does.
                # Omitting it would shift ``img_metas``/``test_cfg`` into
                # the wrong positional slots.
                prev_outputs = self.decode_head[i - 1].forward_test(
                    x, prev_outputs, img_metas, self.test_cfg)
            loss_decode = self.decode_head[i].forward_train(
                x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
            losses.update(add_prefix(loss_decode, f'decode_{i}'))

        return losses