Spaces:
Runtime error
Runtime error
File size: 3,866 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple, Union
import torch.nn.functional as F
from mmengine.model import BaseModule
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptMultiConfig
@MODELS.register_module()
class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
    """Base module of Semantic Head.

    Subclasses must implement :meth:`forward` and :meth:`loss`;
    :meth:`predict` is provided and relies on ``forward`` returning a dict
    with a ``'seg_preds'`` entry.

    Args:
        num_classes (int): the number of classes.
        seg_rescale_factor (float): the rescale factor for ``gt_sem_seg``,
            which equals to ``1 / output_strides``. The output_strides is
            for ``seg_preds``. Defaults to 1 / 4.
        loss_seg (Union[:obj:`ConfigDict`, dict]): config of the loss of the
            semantic head, built with ``MODELS.build``. Defaults to
            ``dict(type='CrossEntropyLoss', ignore_index=255,
            loss_weight=1.0)``.
        init_cfg (Optional[Union[:obj:`ConfigDict`, dict]]): the
            initialization config. Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 seg_rescale_factor: float = 1 / 4.,
                 loss_seg: ConfigType = dict(
                     type='CrossEntropyLoss',
                     ignore_index=255,
                     loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # Instantiate the loss module from its registry config.
        self.loss_seg = MODELS.build(loss_seg)
        self.num_classes = num_classes
        self.seg_rescale_factor = seg_rescale_factor

    @abstractmethod
    def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Dict[str, Tensor]:
        """Placeholder of forward function.

        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.

        Returns:
            Dict[str, Tensor]: A dictionary, including features
            and predicted scores. Required keys: 'seg_preds'
            and 'feats'.
        """
        pass

    @abstractmethod
    def loss(self, x: Union[Tensor, Tuple[Tensor]],
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Compute the loss of the semantic head.

        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            Dict[str, Tensor]: The loss of semantic head.
        """
        pass

    def predict(self,
                x: Union[Tensor, Tuple[Tensor]],
                batch_img_metas: List[dict],
                rescale: bool = False) -> List[Tensor]:
        """Predict semantic segmentation logits without test-time
        augmentation.

        The ``'seg_preds'`` output of :meth:`forward` is bilinearly
        resized to ``batch_img_metas[0]['batch_input_shape']`` and split
        into one tensor per image.

        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_img_metas (List[dict]): List of image information. Each
                entry must provide ``'batch_input_shape'`` (only the first
                entry's is read), and, when ``rescale`` is True, also
                ``'img_shape'`` and ``'ori_shape'``.
            rescale (bool): Whether to crop each prediction to its
                ``img_shape`` and resize it to its ``ori_shape``.
                Defaults to False.

        Returns:
            list[Tensor]: semantic segmentation logits, one per image.
        """
        seg_preds = self.forward(x)['seg_preds']
        # Resize the whole batch at once to the shared batch input shape.
        seg_preds = F.interpolate(
            seg_preds,
            size=batch_img_metas[0]['batch_input_shape'],
            mode='bilinear',
            align_corners=False)
        # One prediction per image; indexing (rather than iterating the
        # batch tensor) keeps the original IndexError if the batch is
        # smaller than the number of metas.
        seg_preds = [seg_preds[i] for i in range(len(batch_img_metas))]
        if not rescale:
            return seg_preds

        seg_pred_list = []
        for seg_pred, img_meta in zip(seg_preds, batch_img_metas):
            # Keep only the top-left ``img_shape`` region; presumably the
            # remainder is batch padding added on the bottom/right.
            img_h, img_w = img_meta['img_shape']
            seg_pred = seg_pred[:, :img_h, :img_w]
            # Resize back to the original image resolution; F.interpolate
            # needs a batch dim, so add one and strip it afterwards.
            ori_h, ori_w = img_meta['ori_shape']
            seg_pred = F.interpolate(
                seg_pred[None],
                size=(ori_h, ori_w),
                mode='bilinear',
                align_corners=False)[0]
            seg_pred_list.append(seg_pred)
        return seg_pred_list
|