Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp | |
import tempfile | |
import numpy as np | |
import pytest | |
from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets | |
def _create_dummy_dict_file(dict_file):
    """Write a dummy character dictionary holding the digits '0'-'9'.

    Each character is written on its own line, followed by a newline.

    Args:
        dict_file (str): Path of the file to create. An existing file at
            this path is overwritten.
    """
    chars = list('0123456789')
    # An explicit encoding keeps the written bytes independent of the
    # platform's locale default.
    with open(dict_file, 'w', encoding='utf-8') as fw:
        fw.writelines(char + '\n' for char in chars)
def test_ocr_segm_targets():
    """Check ``OCRSegTargets`` init, ``generate_kernels`` and ``__call__``.

    The test covers three areas:

    1. Construction raises ``AssertionError`` for a ``None`` label
       convertor, a string as the second argument and ``2`` as the third
       argument.
    2. ``generate_kernels`` raises ``AssertionError`` for a non-tuple image
       size, a flat (non-nested) box list and a scalar char index, and
       otherwise returns the expected 8x10 target for both ``binary=True``
       and ``binary=False``.
    3. Calling the transform on a results dict registers ``gt_kernels`` in
       ``mask_fields`` and yields the same two targets.
    """
    # The context manager removes the temporary directory even when one of
    # the assertions below fails, so a failing run leaves no files behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy dict file
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        # dummy label convertor
        label_convertor = dict(
            type='SegConvertor',
            dict_file=dict_file,
            with_unknown=True,
            lower=True)

        # test init: every invalid argument must trip an assertion
        with pytest.raises(AssertionError):
            OCRSegTargets(None, 0.5, 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, '1by2', 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, 0.5, 2)

        ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)

        # test generate kernels
        img_size = (8, 8)
        pad_size = (8, 10)
        char_boxes = [[2, 2, 6, 6]]
        char_idxs = [2]

        # img_size given as a bare int
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes,
                                         char_idxs, 0.5, True)
        # char_boxes given as a flat list instead of a list of boxes
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
                                         char_idxs, 0.5, True)
        # char_idxs given as a bare int
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2,
                                         0.5, True)

        # binary=True: the kernel region is marked with 1. The two columns
        # beyond the image width (8) up to the pad width (10) hold 255 --
        # presumably the pad/ignore value; confirm against OCRSegTargets.
        attn_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
        expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(attn_tgt,
                           np.array(expect_attn_tgt, dtype=np.int32))

        # binary=False: the kernel region carries the char index (2)
        segm_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
        expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(segm_tgt,
                           np.array(expect_segm_tgt, dtype=np.int32))

        # test __call__
        # NOTE(review): the rect [1, 1, 3, 3] on a 4x4 image presumably gets
        # rescaled to the 8x8 resize shape, reproducing the [2, 2, 6, 6] box
        # used above -- confirm against OCRSegTargets.__call__.
        results = {
            'img_shape': (4, 4, 3),
            'resize_shape': (8, 8, 3),
            'pad_shape': (8, 10),
            'ann_info': {
                'char_rects': [[1, 1, 3, 3]],
                'chars': ['1'],
            },
        }

        results = ocr_seg_tgt(results)
        assert results['mask_fields'] == ['gt_kernels']
        assert np.allclose(results['gt_kernels'].masks[0],
                           np.array(expect_attn_tgt, dtype=np.int32))
        assert np.allclose(results['gt_kernels'].masks[1],
                           np.array(expect_segm_tgt, dtype=np.int32))