File size: 5,117 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmengine.model import BaseModule
from torch import Tensor

from mmocr.registry import MODELS
from mmocr.utils.typing_utils import DetSampleList


@MODELS.register_module()
class BaseTextDetHead(BaseModule):
    """Base head for text detection, build the loss and postprocessor.

    This class only wires the pieces together; it does not define
    ``forward`` itself. Every public method below invokes the head as
    ``self(x, data_samples)``, so a subclass is expected to provide a
    ``forward(x, data_samples)`` implementation.

    1. The ``init_weights`` method is used to initialize head's
    model parameters. After detector initialization, ``init_weights``
    is triggered when ``detector.init_weights()`` is called externally.

    2. The ``loss`` method is used to calculate the loss of head,
    which includes two steps: (1) the head model performs forward
    propagation to obtain its outputs (2) the ``module_loss`` module
    is called on those outputs to calculate the loss.

    .. code:: text

        loss(): forward() -> module_loss()

    3. The ``predict`` method is used to predict detection results,
    which includes two steps: (1) the head model performs forward
    propagation to obtain its outputs (2) the ``postprocessor`` is
    called on those outputs to produce the detection results.

    .. code:: text

        predict(): forward() -> postprocessor()

    4. The ``loss_and_predict`` method is used to return loss and detection
    results at the same time. It will call head's ``forward``,
    ``module_loss`` and ``postprocessor`` in order.

    .. code:: text

        loss_and_predict(): forward() -> module_loss() -> postprocessor()


    Args:
        module_loss (dict, optional): Config to build the loss module via
            ``MODELS.build``. Defaults to None, in which case
            ``self.module_loss`` is None.
        postprocessor (dict, optional): Config to build the postprocessor
            via ``MODELS.build``. Defaults to None, in which case
            ``self.postprocessor`` is None.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.

    Raises:
        AssertionError: If ``module_loss`` or ``postprocessor`` is given
            but is not a dict.
    """

    def __init__(self,
                 module_loss: Optional[Dict] = None,
                 postprocessor: Optional[Dict] = None,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)

        # A config is swapped for the object built from it; ``None`` is
        # stored unchanged so the attribute always exists on the head.
        if module_loss is not None:
            assert isinstance(module_loss, dict), (
                f'module_loss must be a dict, but got {type(module_loss)}')
            module_loss = MODELS.build(module_loss)
        self.module_loss = module_loss

        if postprocessor is not None:
            assert isinstance(postprocessor, dict), (
                f'postprocessor must be a dict, but got '
                f'{type(postprocessor)}')
            postprocessor = MODELS.build(postprocessor)
        self.postprocessor = postprocessor

    def loss(self, x: Tuple[Tensor], data_samples: DetSampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network.
            data_samples (List[:obj:`DetDataSample`]): The data samples.
                They are passed both to the forward pass and to
                ``module_loss``.

        Returns:
            dict: A dictionary of loss components, as returned by
            ``module_loss``.
        """
        outs = self(x, data_samples)
        return self.module_loss(outs, data_samples)

    def loss_and_predict(self, x: Tuple[Tensor], data_samples: DetSampleList
                         ) -> Tuple[dict, DetSampleList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.

        The forward pass runs once and its outputs are shared by the loss
        and the postprocessor.

        Args:
            x (tuple[Tensor]): Features from the upstream network.
            data_samples (list[:obj:`DetDataSample`]): The data samples.
                They are passed to the forward pass, to ``module_loss``
                and to ``postprocessor``.

        Returns:
            tuple: the return value is a tuple contains:

                - losses (dict): A dictionary of loss components, as
                  returned by ``module_loss``.
                - predictions: Detection results as returned by
                  ``postprocessor``.
        """
        outs = self(x, data_samples)
        losses = self.module_loss(outs, data_samples)

        # Unlike ``predict``, the head's ``training`` flag is forwarded as
        # the third positional argument of the postprocessor here.
        predictions = self.postprocessor(outs, data_samples, self.training)
        return losses, predictions

    def predict(self, x: Tuple[Tensor],
                data_samples: DetSampleList) -> DetSampleList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network.
            data_samples (List[:obj:`DetDataSample`]): The data samples.
                They are passed both to the forward pass and to
                ``postprocessor``.

        Returns:
            DetSampleList: Detection results as returned by
            ``postprocessor``.
        """
        outs = self(x, data_samples)
        return self.postprocessor(outs, data_samples)