File size: 4,417 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List

import cv2
import numpy as np
import torch
from mmcv.ops import contour_expand
from mmengine.structures import InstanceData

from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .pan_postprocessor import PANPostprocessor


@MODELS.register_module()
class PSEPostprocessor(PANPostprocessor):
    """Decoding predictions of PSENet to instances. This is partially adapted
    from https://github.com/whai362/PSENet.

    A decoded instance is kept only when it satisfies both
    ``min_text_area`` and ``score_threshold``.

    Args:
        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
            Defaults to 'poly'.
        rescale_fields (list[str]): The bbox/polygon field names to
            be rescaled. If None, no rescaling will be performed. Defaults to
            ['polygons'].
        min_kernel_confidence (float): The minimal kernel confidence.
            Defaults to 0.5.
        score_threshold (float): The minimal text average confidence.
            Defaults to 0.3.
        min_kernel_area (int): The minimal text kernel area. Defaults to 0.
        min_text_area (int): The minimal text instance region area.
            Defaults to 16.
        downsample_ratio (float): Downsample ratio. Defaults to 0.25.
    """

    def __init__(self,
                 text_repr_type: str = 'poly',
                 rescale_fields: List[str] = ['polygons'],
                 min_kernel_confidence: float = 0.5,
                 score_threshold: float = 0.3,
                 min_kernel_area: int = 0,
                 min_text_area: int = 16,
                 downsample_ratio: float = 0.25) -> None:
        # The default list object is shared by every call that relies on the
        # default; hand the parent a private copy so it can never be mutated
        # through one instance and leak into another.
        if isinstance(rescale_fields, list):
            rescale_fields = list(rescale_fields)
        super().__init__(
            text_repr_type=text_repr_type,
            rescale_fields=rescale_fields,
            min_kernel_confidence=min_kernel_confidence,
            score_threshold=score_threshold,
            min_text_area=min_text_area,
            downsample_ratio=downsample_ratio)
        self.min_kernel_area = min_kernel_area

    def get_text_instances(self, pred_results: torch.Tensor,
                           data_sample: TextDetDataSample,
                           **kwargs) -> TextDetDataSample:
        """Decode the prediction maps of one image into text instances.

        Args:
            pred_results (torch.Tensor): Prediction results of an image which
                is a tensor of shape :math:`(N, H, W)`. Channel 0 is used as
                the text map and the last channel as the seed kernel.
            data_sample (TextDetDataSample): Datasample of an image.

        Returns:
            TextDetDataSample: The input ``data_sample`` (modified in place)
            with predictions filled in.
            Polygons and results are saved in
            ``TextDetDataSample.pred_instances.polygons``. The confidence
            scores are saved in ``TextDetDataSample.pred_instances.scores``.
            The ``scale_factor`` metainfo is overwritten with the original
            value multiplied by ``downsample_ratio``.
        """
        assert pred_results.dim() == 3

        pred_results = torch.sigmoid(pred_results)  # text confidence

        masks = pred_results > self.min_kernel_confidence
        text_mask = masks[0, :, :]
        # The slice starts at 0, so the (self-masked) text map is kept as the
        # first entry of ``kernel_masks``; every map is restricted to pixels
        # that belong to the text mask.
        kernel_masks = masks[0:, :, :] * text_mask
        kernel_masks = kernel_masks.data.cpu().numpy().astype(np.uint8)

        score = pred_results[0, :, :]
        score = score.data.cpu().numpy().astype(np.float32)

        # Seed labels come from the connected components of the last kernel.
        region_num, labels = cv2.connectedComponents(
            kernel_masks[-1], connectivity=4)

        # NOTE(review): presumably grows the seed labels through the other
        # kernel masks (progressive scale expansion) - confirm against the
        # mmcv ``contour_expand`` documentation.
        labels = contour_expand(kernel_masks, labels, self.min_kernel_area,
                                region_num)
        labels = np.array(labels)
        # ``np.max`` raises on an empty array; treat it as "no instances".
        label_num = int(np.max(labels)) if labels.size else 0

        polygons = []
        scores = []
        for i in range(1, label_num + 1):
            # Build the membership mask once and reuse it for both the point
            # list and the score lookup.
            instance_mask = labels == i
            # (row, col) indices flipped to (x, y) point order.
            points = np.array(np.where(instance_mask)).transpose(
                (1, 0))[:, ::-1]
            area = points.shape[0]
            # Skip unused label ids (avoids a mean over an empty selection)
            # and regions smaller than the minimal text area.
            if area == 0 or area < self.min_text_area:
                continue

            score_instance = np.mean(score[instance_mask])
            # Both limits are minimums, so an instance must pass the
            # confidence check as well as the area check above.
            if not score_instance > self.score_threshold:
                continue

            polygon = self._points2boundary(points)
            if polygon:
                polygons.append(polygon)
                scores.append(score_instance)

        pred_instances = InstanceData()
        pred_instances.polygons = polygons
        pred_instances.scores = torch.FloatTensor(scores)
        data_sample.pred_instances = pred_instances
        # Fold the downsample ratio into the stored scale factor.
        scale_factor = data_sample.scale_factor
        scale_factor = tuple(factor * self.downsample_ratio
                             for factor in scale_factor)
        data_sample.set_metainfo(dict(scale_factor=scale_factor))

        return data_sample