# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Union

from mmengine.model import BaseModule
from torch import Tensor

from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig
from ..utils import unpack_gt_instances


class BaseMaskHead(BaseModule, metaclass=ABCMeta):
    """Base class for mask heads used in One-Stage Instance Segmentation."""

    def __init__(self, init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)

    @abstractmethod
    def loss_by_feat(self, *args, **kwargs):
        """Calculate the loss based on the features extracted by the mask
        head."""
        pass

    @abstractmethod
    def predict_by_feat(self, *args, **kwargs):
        """Transform a batch of output features extracted from the head into
        mask results."""
        pass
    def loss(self,
             x: Union[List[Tensor], Tuple[Tensor]],
             batch_data_samples: SampleList,
             positive_infos: OptInstanceList = None,
             **kwargs) -> dict:
        """Perform forward propagation and loss calculation of the mask head
        on the features of the upstream network.

        Args:
            x (list[Tensor] | tuple[Tensor]): Features from the FPN, each
                with shape (B, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item
                contains the meta information and corresponding annotations
                of one image.
            positive_infos (list[:obj:`InstanceData`], optional): Information
                about positive samples. Used when the label assignment is
                done outside the MaskHead, e.g. by the bbox head in YOLACT
                or CondInst. It is None when the label assignment is done
                inside the MaskHead, as in SOLO or SOLOv2. All values in it
                should have shape (num_positive_samples, *).

        Returns:
            dict: A dictionary of loss components.
        """
        if positive_infos is None:
            outs = self(x)
        else:
            outs = self(x, positive_infos)
        assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \
                                        'even if only one item is returned'
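        # Split the data samples into ground-truth instances, ignored
        # instances and per-image meta information.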
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
            = outputs
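        # Pad each image's ground-truth masks to the padded batch input
        # shape so they align spatially with the network inputs.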
        for gt_instances, img_metas in zip(batch_gt_instances,
                                           batch_img_metas):
            img_shape = img_metas['batch_input_shape']
            gt_masks = gt_instances.masks.pad(img_shape)
            gt_instances.masks = gt_masks

        losses = self.loss_by_feat(
            *outs,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas,
            positive_infos=positive_infos,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            **kwargs)
        return losses
    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False,
                results_list: OptInstanceList = None,
                **kwargs) -> InstanceList:
        """Test function without test-time augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each a 4D tensor.
            batch_data_samples (list[:obj:`DetDataSample`]): The data
                samples. They usually include information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results to the
                original image scale. Defaults to False.
            results_list (list[:obj:`InstanceData`], optional): Detection
                results of each image after post-processing. Only exists
                when there is a `bbox_head`, as in `YOLACT` or `CondInst`.

        Returns:
            list[:obj:`InstanceData`]: Instance segmentation results of each
            image after post-processing. Each item usually contains the
            following keys:

            - scores (Tensor): Classification scores with shape
              (num_instances, ).
            - labels (Tensor): Labels with shape (num_instances, ).
            - masks (Tensor): Processed mask results with shape
              (num_instances, h, w).
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
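        # When a bbox head has already produced detection results (e.g.
        # YOLACT or CondInst), they are fed to the forward pass as well;
        # otherwise the head runs on the features alone.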
        if results_list is None:
            outs = self(x)
        else:
            outs = self(x, results_list)

        results_list = self.predict_by_feat(
            *outs,
            batch_img_metas=batch_img_metas,
            rescale=rescale,
            results_list=results_list,
            **kwargs)
        return results_list
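

# A minimal, hypothetical subclass sketch illustrating the contract that
# `BaseMaskHead` expects from concrete heads: `forward()` must return a
# tuple, `loss_by_feat()` turns that tuple plus ground-truth info into a
# loss dict, and `predict_by_feat()` turns it into per-image results.
# `ToyMaskHead`, its single conv layer and the dummy loss/results are
# placeholders for illustration only; real heads such as the SOLO or
# CondInst mask heads implement far richer logic.
from mmengine.structures import InstanceData
from torch import nn


class ToyMaskHead(BaseMaskHead):

    def __init__(self, in_channels: int = 256, **kwargs) -> None:
        super().__init__(**kwargs)
        self.mask_conv = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x: Tuple[Tensor]) -> tuple:
        # `loss()` and `predict()` assert that forward results are a tuple.
        mask_preds = [self.mask_conv(feat) for feat in x]
        return (mask_preds, )

    def loss_by_feat(self, mask_preds: List[Tensor], batch_gt_instances,
                     batch_img_metas, **kwargs) -> dict:
        # A real head would compare `mask_preds` against the padded
        # ground-truth masks prepared in `loss()`; a zero placeholder loss
        # keeps the sketch runnable.
        return dict(loss_mask=mask_preds[0].sum() * 0)

    def predict_by_feat(self, mask_preds: List[Tensor], batch_img_metas,
                        rescale: bool = False, **kwargs) -> InstanceList:
        # A real head would decode, threshold and optionally rescale the
        # masks; here each image gets an empty `InstanceData`.
        return [InstanceData() for _ in batch_img_metas]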