# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData, LabelData

from mmocr.registry import TRANSFORMS
from mmocr.structures import (KIEDataSample, TextDetDataSample,
                              TextRecogDataSample)


@TRANSFORMS.register_module()
class PackTextDetInputs(BaseTransform):
"""Pack the inputs data for text detection. | |
The type of outputs is `dict`: | |
- inputs: image converted to tensor, whose shape is (C, H, W). | |
- data_samples: Two components of ``TextDetDataSample`` will be updated: | |
- gt_instances (InstanceData): Depending on annotations, a subset of the | |
following keys will be updated: | |
- bboxes (torch.Tensor((N, 4), dtype=torch.float32)): The groundtruth | |
of bounding boxes in the form of [x1, y1, x2, y2]. Renamed from | |
'gt_bboxes'. | |
- labels (torch.LongTensor(N)): The labels of instances. | |
Renamed from 'gt_bboxes_labels'. | |
- polygons(list[np.array((2k,), dtype=np.float32)]): The | |
groundtruth of polygons in the form of [x1, y1,..., xk, yk]. Each | |
element in polygons may have different number of points. Renamed from | |
'gt_polygons'. Using numpy instead of tensor is that polygon usually | |
is not the output of model and operated on cpu. | |
- ignored (torch.BoolTensor((N,))): The flag indicating whether the | |
corresponding instance should be ignored. Renamed from | |
'gt_ignored'. | |
- texts (list[str]): The groundtruth texts. Renamed from 'gt_texts'. | |
- metainfo (dict): 'metainfo' is always populated. The contents of the | |
'metainfo' depends on ``meta_keys``. By default it includes: | |
- "img_path": Path to the image file. | |
- "img_shape": Shape of the image input to the network as a tuple | |
(h, w). Note that the image may be zero-padded afterward on the | |
bottom/right if the batch tensor is larger than this shape. | |
- "scale_factor": A tuple indicating the ratio of width and height | |
of the preprocessed image to the original one. | |
- "ori_shape": Shape of the preprocessed image as a tuple | |
(h, w). | |
- "pad_shape": Image shape after padding (if any Pad-related | |
transform involved) as a tuple (h, w). | |
- "flip": A boolean indicating if the image has been flipped. | |
- ``flip_direction``: the flipping direction. | |
Args: | |
meta_keys (Sequence[str], optional): Meta keys to be converted to | |
the metainfo of ``TextDetSample``. Defaults to ``('img_path', | |
'ori_shape', 'img_shape', 'scale_factor', 'flip', | |
'flip_direction')``. | |
""" | |
    mapping_table = {
        'gt_bboxes': 'bboxes',
        'gt_bboxes_labels': 'labels',
        'gt_polygons': 'polygons',
        'gt_texts': 'texts',
        'gt_ignored': 'ignored'
    }

    def __init__(self,
                 meta_keys=('img_path', 'ori_shape', 'img_shape',
                            'scale_factor', 'flip', 'flip_direction')):
        self.meta_keys = meta_keys

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
            - 'data_samples' (obj:`TextDetDataSample`): The annotation info of
              the sample.
        """
        packed_results = dict()
        if 'img' in results:
            img = results['img']
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            # A simple trick to speed up formatting by 3-5 times when
            # OMP_NUM_THREADS != 1
            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
            # for more details
            if img.flags.c_contiguous:
                img = to_tensor(img)
                img = img.permute(2, 0, 1).contiguous()
            else:
                img = np.ascontiguousarray(img.transpose(2, 0, 1))
                img = to_tensor(img)
            packed_results['inputs'] = img

        data_sample = TextDetDataSample()
        instance_data = InstanceData()
        for key in self.mapping_table.keys():
            if key not in results:
                continue
            if key in ['gt_bboxes', 'gt_bboxes_labels', 'gt_ignored']:
                instance_data[self.mapping_table[key]] = to_tensor(
                    results[key])
            else:
                instance_data[self.mapping_table[key]] = results[key]
        data_sample.gt_instances = instance_data

        img_meta = {}
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample

        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys})'
        return repr_str
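

# A minimal usage sketch, not part of the original module: it builds a fake
# `results` dict shaped like the output of a detection data pipeline and
# checks what `PackTextDetInputs` produces. All keys, shapes and values below
# are illustrative assumptions, not values mandated by MMOCR.
def _demo_pack_text_det_inputs():
    results = {
        'img': np.zeros((32, 32, 3), dtype=np.uint8),
        'img_path': 'demo.jpg',
        'ori_shape': (32, 32),
        'img_shape': (32, 32),
        'scale_factor': (1.0, 1.0),
        'flip': False,
        'flip_direction': None,
        'gt_bboxes': np.array([[0., 0., 10., 10.]], dtype=np.float32),
        'gt_bboxes_labels': np.array([0], dtype=np.int64),
        'gt_ignored': np.array([False]),
    }
    packed = PackTextDetInputs().transform(results)
    # The image is converted from (H, W, C) to a (C, H, W) tensor.
    assert packed['inputs'].shape == (3, 32, 32)
    # Annotation keys are renamed according to `mapping_table`.
    assert packed['data_samples'].gt_instances.bboxes.shape == (1, 4)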


@TRANSFORMS.register_module()
class PackTextRecogInputs(BaseTransform):
"""Pack the inputs data for text recognition. | |
The type of outputs is `dict`: | |
- inputs: Image as a tensor, whose shape is (C, H, W). | |
- data_samples: Two components of ``TextRecogDataSample`` will be updated: | |
- gt_text (LabelData): | |
- item(str): The groundtruth of text. Rename from 'gt_texts'. | |
- metainfo (dict): 'metainfo' is always populated. The contents of the | |
'metainfo' depends on ``meta_keys``. By default it includes: | |
- "img_path": Path to the image file. | |
- "ori_shape": Shape of the preprocessed image as a tuple | |
(h, w). | |
- "img_shape": Shape of the image input to the network as a tuple | |
(h, w). Note that the image may be zero-padded afterward on the | |
bottom/right if the batch tensor is larger than this shape. | |
- "valid_ratio": The proportion of valid (unpadded) content of image | |
on the x-axis. It defaults to 1 if not set in pipeline. | |
Args: | |
meta_keys (Sequence[str], optional): Meta keys to be converted to | |
the metainfo of ``TextRecogDataSampel``. Defaults to | |
``('img_path', 'ori_shape', 'img_shape', 'pad_shape', | |
'valid_ratio')``. | |
""" | |

    def __init__(self,
                 meta_keys=('img_path', 'ori_shape', 'img_shape', 'pad_shape',
                            'valid_ratio')):
        self.meta_keys = meta_keys

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
            - 'data_samples' (obj:`TextRecogDataSample`): The annotation info
              of the sample.
        """
        packed_results = dict()
        if 'img' in results:
            img = results['img']
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            # A simple trick to speed up formatting by 3-5 times when
            # OMP_NUM_THREADS != 1
            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
            # for more details
            if img.flags.c_contiguous:
                img = to_tensor(img)
                img = img.permute(2, 0, 1).contiguous()
            else:
                img = np.ascontiguousarray(img.transpose(2, 0, 1))
                img = to_tensor(img)
            packed_results['inputs'] = img

        data_sample = TextRecogDataSample()
        gt_text = LabelData()

        if results.get('gt_texts', None):
            assert len(
                results['gt_texts']
            ) == 1, 'Each image sample should have one text annotation only'
            gt_text.item = results['gt_texts'][0]
        data_sample.gt_text = gt_text

        img_meta = {}
        for key in self.meta_keys:
            if key == 'valid_ratio':
                img_meta[key] = results.get('valid_ratio', 1)
            else:
                img_meta[key] = results[key]
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample

        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys})'
        return repr_str
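

# A minimal usage sketch, not part of the original module: recognition
# samples carry exactly one text annotation, and `valid_ratio` falls back to
# 1 when the pipeline did not set it. The file name and text below are made
# up for illustration, and the final assert assumes mmengine exposes
# metainfo fields as attributes on the data sample.
def _demo_pack_text_recog_inputs():
    results = {
        'img': np.zeros((32, 100, 3), dtype=np.uint8),
        'img_path': 'word.jpg',
        'ori_shape': (32, 100),
        'img_shape': (32, 100),
        'pad_shape': (32, 100),
        'gt_texts': ['hello'],
    }
    packed = PackTextRecogInputs().transform(results)
    assert packed['data_samples'].gt_text.item == 'hello'
    # 'valid_ratio' was absent from `results`, so the default of 1 is used.
    assert packed['data_samples'].valid_ratio == 1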


@TRANSFORMS.register_module()
class PackKIEInputs(BaseTransform):
"""Pack the inputs data for key information extraction. | |
The type of outputs is `dict`: | |
- inputs: image converted to tensor, whose shape is (C, H, W). | |
- data_samples: Two components of ``TextDetDataSample`` will be updated: | |
- gt_instances (InstanceData): Depending on annotations, a subset of the | |
following keys will be updated: | |
- bboxes (torch.Tensor((N, 4), dtype=torch.float32)): The groundtruth | |
of bounding boxes in the form of [x1, y1, x2, y2]. Renamed from | |
'gt_bboxes'. | |
- labels (torch.LongTensor(N)): The labels of instances. | |
Renamed from 'gt_bboxes_labels'. | |
- edge_labels (torch.LongTensor(N, N)): The edge labels. | |
Renamed from 'gt_edges_labels'. | |
- texts (list[str]): The groundtruth texts. Renamed from 'gt_texts'. | |
- metainfo (dict): 'metainfo' is always populated. The contents of the | |
'metainfo' depends on ``meta_keys``. By default it includes: | |
- "img_path": Path to the image file. | |
- "img_shape": Shape of the image input to the network as a tuple | |
(h, w). Note that the image may be zero-padded afterward on the | |
bottom/right if the batch tensor is larger than this shape. | |
- "scale_factor": A tuple indicating the ratio of width and height | |
of the preprocessed image to the original one. | |
- "ori_shape": Shape of the preprocessed image as a tuple | |
(h, w). | |
Args: | |
meta_keys (Sequence[str], optional): Meta keys to be converted to | |
the metainfo of ``TextDetSample``. Defaults to ``('img_path', | |
'ori_shape', 'img_shape', 'scale_factor', 'flip', | |
'flip_direction')``. | |
""" | |
    mapping_table = {
        'gt_bboxes': 'bboxes',
        'gt_bboxes_labels': 'labels',
        'gt_edges_labels': 'edge_labels',
        'gt_texts': 'texts',
    }

    def __init__(self, meta_keys=()):
        self.meta_keys = meta_keys

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): Data for model forwarding.
            - 'data_samples' (obj:`KIEDataSample`): The annotation info of the
              sample.
        """
        packed_results = dict()
        if 'img' in results:
            img = results['img']
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            # A simple trick to speed up formatting by 3-5 times when
            # OMP_NUM_THREADS != 1
            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
            # for more details
            if img.flags.c_contiguous:
                img = to_tensor(img)
                img = img.permute(2, 0, 1).contiguous()
            else:
                img = np.ascontiguousarray(img.transpose(2, 0, 1))
                img = to_tensor(img)
            packed_results['inputs'] = img
        else:
            packed_results['inputs'] = torch.FloatTensor().reshape(0, 0, 0)

        data_sample = KIEDataSample()
        instance_data = InstanceData()
        for key in self.mapping_table.keys():
            if key not in results:
                continue
            if key in ['gt_bboxes', 'gt_bboxes_labels', 'gt_edges_labels']:
                instance_data[self.mapping_table[key]] = to_tensor(
                    results[key])
            else:
                instance_data[self.mapping_table[key]] = results[key]
        data_sample.gt_instances = instance_data

        img_meta = {}
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample

        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys})'
        return repr_str
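

# A minimal usage sketch, not part of the original module: KIE instances
# carry an (N, N) matrix of edge labels in addition to boxes, labels and
# texts. With the default `meta_keys=()`, no metainfo keys are required.
# All values below are illustrative assumptions.
def _demo_pack_kie_inputs():
    results = {
        'img': np.zeros((32, 32, 3), dtype=np.uint8),
        'gt_bboxes': np.array([[0., 0., 10., 10.], [5., 5., 20., 20.]],
                              dtype=np.float32),
        'gt_bboxes_labels': np.array([0, 1], dtype=np.int64),
        'gt_edges_labels': np.zeros((2, 2), dtype=np.int64),
        'gt_texts': ['Name:', 'John'],
    }
    packed = PackKIEInputs().transform(results)
    assert packed['inputs'].shape == (3, 32, 32)
    # One edge label per ordered pair of the N = 2 instances.
    assert packed['data_samples'].gt_instances.edge_labels.shape == (2, 2)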