Spaces:
Running
Running
File size: 5,117 Bytes
9bf4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmengine.model import BaseModule
from torch import Tensor
from mmocr.registry import MODELS
from mmocr.utils.typing_utils import DetSampleList
@MODELS.register_module()
class BaseTextDetHead(BaseModule):
    """Base head for text detection, build the loss and postprocessor.

    1. The ``init_weights`` method is used to initialize head's
       model parameters. After detector initialization, ``init_weights``
       is triggered when ``detector.init_weights()`` is called externally.

    2. The ``loss`` method is used to calculate the loss of head,
       which includes two steps: (1) the head model performs forward
       propagation to obtain the feature maps (2) The ``module_loss`` method
       is called based on the feature maps to calculate the loss.

    .. code:: text

        loss(): forward() -> module_loss()

    3. The ``predict`` method is used to predict detection results,
       which includes two steps: (1) the head model performs forward
       propagation to obtain the feature maps (2) The ``postprocessor``
       method is called based on the feature maps to predict detection
       results including post-processing.

    .. code:: text

        predict(): forward() -> postprocessor()

    4. The ``loss_and_predict`` method is used to return loss and detection
       results at the same time. It will call head's ``forward``,
       ``module_loss`` and ``postprocessor`` methods in order.

    .. code:: text

        loss_and_predict(): forward() -> module_loss() -> postprocessor()

    Args:
        module_loss (dict, optional): Config to build the loss module via
            ``MODELS.build``. When None, ``self.module_loss`` stays None.
            Defaults to None.
        postprocessor (dict, optional): Config to build the postprocessor
            via ``MODELS.build``. When None, ``self.postprocessor`` stays
            None. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization configs,
            forwarded to the parent class. Defaults to None.
    """

    def __init__(self,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        # Both components are optional: a config dict is turned into a
        # module through the registry, while None is stored as-is.
        if module_loss is not None:
            assert isinstance(module_loss, dict)
            self.module_loss = MODELS.build(module_loss)
        else:
            self.module_loss = module_loss

        if postprocessor is not None:
            assert isinstance(postprocessor, dict)
            self.postprocessor = MODELS.build(postprocessor)
        else:
            self.postprocessor = postprocessor

    def loss(self, x: Tuple[Tensor], data_samples: DetSampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            data_samples (List[:obj:`DetDataSample`]): The data samples of
                the batch. They are passed unchanged to both the head's
                forward call and ``module_loss``.

        Returns:
            dict: A dictionary of loss components, as returned by
            ``module_loss``.
        """
        # Calling the module itself dispatches to ``forward``.
        outs = self(x, data_samples)
        losses = self.module_loss(outs, data_samples)
        return losses

    def loss_and_predict(self, x: Tuple[Tensor], data_samples: DetSampleList
                         ) -> Tuple[dict, DetSampleList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.

        Args:
            x (tuple[Tensor]): Features from the upstream network.
            data_samples (list[:obj:`DetDataSample`]): The data samples of
                the batch. They are passed unchanged to the head's forward
                call, ``module_loss`` and ``postprocessor``.

        Returns:
            tuple: A tuple containing:

            - losses (dict): A dictionary of loss components, as returned
              by ``module_loss``.
            - predictions (DetSampleList): Detection results of each image
              after the post process, as returned by ``postprocessor``.
        """
        # A single forward pass feeds both the loss and the postprocessor.
        outs = self(x, data_samples)
        losses = self.module_loss(outs, data_samples)
        # The current training flag is forwarded as the third positional
        # argument of the postprocessor.
        predictions = self.postprocessor(outs, data_samples, self.training)
        return losses, predictions

    def predict(self, x: Tuple[torch.Tensor],
                data_samples: DetSampleList) -> DetSampleList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            data_samples (List[:obj:`DetDataSample`]): The data samples of
                the batch. They are passed unchanged to both the head's
                forward call and ``postprocessor``.

        Returns:
            DetSampleList: Detection results of each image after the post
            process, as returned by ``postprocessor``.
        """
        outs = self(x, data_samples)
        predictions = self.postprocessor(outs, data_samples)
        return predictions
|