Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional, Sequence

from mmocr.models.textdet.postprocessors.base import BaseTextDetPostProcessor
from mmocr.registry import MODELS
from ..utils import bezier2poly
# Register the class in the MODELS registry so it can be built from configs.
@MODELS.register_module()
class ABCNetPostprocessor(BaseTextDetPostProcessor):
    """Post-processing methods for ABCNet text spotting.

    Converts the Bezier-curve predictions of the detection branch into
    polygons, optionally rescales the configured ``pred_instances`` fields,
    and attaches the texts predicted by the recognition branch to the
    corresponding detection instances.

    Args:
        text_repr_type (str): Text representation type, 'poly' or 'quad'.
            Forwarded to :class:`BaseTextDetPostProcessor`. Defaults to
            'poly'.
        rescale_fields (Sequence[str], optional): Names of the
            ``pred_instances`` fields to rescale with the sample's
            ``scale_factor``. Every name must be present in
            ``pred_instances`` when rescaling happens. ``None`` or an empty
            sequence disables rescaling. Defaults to
            ``('beziers', 'polygons')``.
    """

    def __init__(
        self,
        text_repr_type: str = 'poly',
        rescale_fields: Optional[Sequence[str]] = ('beziers', 'polygons'),
    ) -> None:
        # Immutable default + fresh copy: a list default would be shared (and
        # mutable) across all instances, while ``_process_single`` asserts
        # that ``self.rescale_fields`` is a ``list``.
        if rescale_fields is not None:
            rescale_fields = list(rescale_fields)
        super().__init__(
            text_repr_type=text_repr_type, rescale_fields=rescale_fields)

    def merge_predict(self, spotting_data_samples, recog_data_samples):
        """Attach recognized texts to the spotting data samples, in order.

        ``recog_data_samples`` is consumed as one flat sequence aligned with
        the concatenation of all ``pred_instances`` of
        ``spotting_data_samples``. The first ``len(pred_instances)`` texts go
        to the first sample, the next ones to the second sample, and so on.

        Args:
            spotting_data_samples (list[TextDetDataSample]): Detection data
                samples whose ``pred_instances`` receive a ``texts`` field.
            recog_data_samples (list): Recognition data samples. The text of
                each one is read from ``pred_text.item``.

        Returns:
            list[TextDetDataSample]: ``spotting_data_samples``, modified in
            place.
        """
        texts = [ds.pred_text.item for ds in recog_data_samples]
        start = 0
        for spotting_data_sample in spotting_data_samples:
            end = start + len(spotting_data_sample.pred_instances)
            spotting_data_sample.pred_instances.texts = texts[start:end]
            start = end
        return spotting_data_samples

    def __call__(self,
                 spotting_data_samples,
                 recog_data_samples,
                 training: bool = False):
        """Post-process spotting predictions and merge in recognized texts.

        Args:
            spotting_data_samples (list[TextDetDataSample]): Batch of
                detection data samples, each with ``pred_instances.beziers``
                populated.
            recog_data_samples (list): Recognition data samples for all
                detected instances of the batch, in instance order (see
                :meth:`merge_predict`).
            training (bool): Whether the model is in training mode. Not used
                by this method. Defaults to False.

        Returns:
            list[TextDetDataSample]: Batch of post-processed data samples
            whose ``pred_instances`` hold ``polygons`` and ``texts``.
        """
        processed = [
            self._process_single(data_sample)
            for data_sample in spotting_data_samples
        ]
        return self.merge_predict(processed, recog_data_samples)

    def _process_single(self, data_sample):
        """Post-process the predictions of a single image.

        Args:
            data_sample (TextDetDataSample): Datasample of an image with
                ``pred_instances.beziers`` populated.

        Returns:
            TextDetDataSample: The sample returned by
            :meth:`get_text_instances`. If ``rescale_fields`` is non-empty,
            it is then passed through ``self.rescale`` with
            ``data_sample.scale_factor``.
        """
        # Polygons are derived from the beziers first, so that 'polygons'
        # exists in pred_instances and can be rescaled alongside 'beziers'.
        data_sample = self.get_text_instances(data_sample)
        if self.rescale_fields:
            assert isinstance(self.rescale_fields, list)
            assert set(self.rescale_fields).issubset(
                set(data_sample.pred_instances.keys()))
            data_sample = self.rescale(data_sample, data_sample.scale_factor)
        return data_sample

    def get_text_instances(self, data_sample, **kwargs):
        """Convert the Bezier-curve predictions of one image into polygons.

        Args:
            data_sample (TextDetDataSample): Datasample of an image with
                ``pred_instances.beziers`` populated.
            **kwargs: Not used by this implementation.

        Returns:
            TextDetDataSample: The data sample after ``.cpu().numpy()``, with
            ``pred_instances.polygons`` set to ``bezier2poly`` applied to each
            entry of ``pred_instances.beziers``.
        """
        data_sample = data_sample.cpu().numpy()
        beziers = data_sample.pred_instances.beziers
        data_sample.pred_instances.polygons = [
            bezier2poly(bezier) for bezier in beziers
        ]
        return data_sample