Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Dict, List, Optional, Tuple, Union | |
import torch | |
from mmengine.model import BaseModule | |
from torch import Tensor | |
from mmocr.registry import MODELS | |
from mmocr.utils.typing_utils import DetSampleList | |
class BaseTextDetHead(BaseModule):
    """Base head for text detection, build the loss and postprocessor.

    1. The ``init_weights`` method is used to initialize head's
       model parameters. After detector initialization, ``init_weights``
       is triggered when ``detector.init_weights()`` is called externally.

    2. The ``loss`` method is used to calculate the loss of head,
       which includes two steps: (1) the head model performs forward
       propagation to obtain the feature maps (2) The ``module_loss`` method
       is called based on the feature maps to calculate the loss.

    .. code:: text

        loss(): forward() -> module_loss()

    3. The ``predict`` method is used to predict detection results,
       which includes two steps: (1) the head model performs forward
       propagation to obtain the feature maps (2) The ``postprocessor``
       method is called based on the feature maps to predict detection
       results including post-processing.

    .. code:: text

        predict(): forward() -> postprocessor()

    4. The ``loss_and_predict`` method is used to return loss and detection
       results at the same time. It will call head's ``forward``,
       ``module_loss`` and ``postprocessor`` methods in order.

    .. code:: text

        loss_and_predict(): forward() -> module_loss() -> postprocessor()

    Args:
        module_loss (dict, optional): Config to build the module loss.
            Defaults to None, in which case ``self.module_loss`` is None.
        postprocessor (dict, optional): Config to build postprocessor.
            Defaults to None, in which case ``self.postprocessor`` is None.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        # Both components are optional; a missing config leaves the
        # attribute as None so subclasses or callers can detect its absence.
        if module_loss is not None:
            assert isinstance(module_loss, dict), \
                'module_loss must be a dict config or None'
            self.module_loss = MODELS.build(module_loss)
        else:
            self.module_loss = module_loss

        if postprocessor is not None:
            assert isinstance(postprocessor, dict), \
                'postprocessor must be a dict config or None'
            self.postprocessor = MODELS.build(postprocessor)
        else:
            self.postprocessor = postprocessor

    def loss(self, x: Tuple[Tensor], data_samples: DetSampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            data_samples (DetSampleList): The data samples of the batch.
                They are passed to both ``forward`` and ``module_loss``.

        Returns:
            dict: A dictionary of loss components.
        """
        outs = self(x, data_samples)
        losses = self.module_loss(outs, data_samples)
        return losses

    def loss_and_predict(self, x: Tuple[Tensor], data_samples: DetSampleList
                         ) -> Tuple[dict, DetSampleList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.

        Args:
            x (tuple[Tensor]): Features from FPN.
            data_samples (DetSampleList): Each item contains the meta
                information of each image and corresponding annotations.

        Returns:
            tuple: the return value is a tuple contains:

            - losses (dict): A dictionary of loss components.
            - predictions (DetSampleList): Detection results of each image
              after the post process.
        """
        outs = self(x, data_samples)
        losses = self.module_loss(outs, data_samples)
        # Unlike ``predict``, the current training flag is forwarded to the
        # postprocessor as its third positional argument.
        predictions = self.postprocessor(outs, data_samples, self.training)
        return losses, predictions

    def predict(self, x: Tuple[Tensor],
                data_samples: DetSampleList) -> DetSampleList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            data_samples (DetSampleList): The data samples of the batch.
                They are passed to both ``forward`` and ``postprocessor``.

        Returns:
            DetSampleList: Detection results of each image
            after the post process.
        """
        outs = self(x, data_samples)
        predictions = self.postprocessor(outs, data_samples)
        return predictions