# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List

from torch import Tensor

from mmseg.utils import ConfigType
from .decode_head import BaseDecodeHead


class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode head used in
    :class:`CascadeEncoderDecoder`.

    Unlike a plain decode head, every entry point here takes an extra
    ``prev_output`` argument (the output of the previous decode head) and
    forwards it, together with ``inputs``, to :meth:`forward`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs: List[Tensor], prev_output: Tensor):
        """Placeholder of forward function.

        Subclasses must implement this method.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
        """

    def loss(self, inputs: List[Tensor], prev_output: Tensor,
             batch_data_samples: List[dict], train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.
            train_cfg (dict): The training config. Not used by this base
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs, prev_output)
        # The loss computation itself is delegated to ``loss_by_feat``.
        return self.loss_by_feat(seg_logits, batch_data_samples)

    def predict(self, inputs: List[Tensor], prev_output: Tensor,
                batch_img_metas: List[dict], tese_cfg: ConfigType) -> Tensor:
        """Forward function for testing.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_img_metas (List[dict]): List Image info where each dict
                may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            tese_cfg (dict): The testing config. The parameter name is a
                historical misspelling of ``test_cfg``; it is kept so that
                existing keyword callers keep working. Not used by this
                base implementation.

        Returns:
            Tensor: Output segmentation map.
        """
        seg_logits = self.forward(inputs, prev_output)
        # Post-processing of the logits is delegated to ``predict_by_feat``.
        return self.predict_by_feat(seg_logits, batch_img_metas)