Spaces:
Runtime error
Runtime error
File size: 4,169 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import pytest
from mmcv import Config
from mmocr.apis.utils import (disable_text_recog_aug_test,
replace_image_to_tensor)
@pytest.mark.parametrize('cfg_file', [
    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
])
def test_disable_text_recog_aug_test(cfg_file):
    """Check that ``disable_text_recog_aug_test`` removes test-time aug.

    The loaded config's test pipeline has ``MultiRotateAugOCR`` at index 1
    (asserted below). Each case rearranges ``cfg.data.test`` into a
    different dataset layout, runs ``disable_text_recog_aug_test`` on the
    ``'test'`` split, and asserts ``MultiRotateAugOCR`` is no longer at
    index 1 of the affected pipelines.

    Args:
        cfg_file (str): Config path, relative to the parent of the
            directory containing this test file.
    """
    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    config_file = os.path.join(tmp_dir, cfg_file)
    cfg = Config.fromfile(config_file)
    # Precondition: every check below is a `!=`, so without this guard the
    # test would pass vacuously if the config stopped using the wrapper.
    assert cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR'
    test = cfg.data.test.datasets[0]

    # cfg.data.test.type is 'OCRDataset'
    cfg1 = copy.deepcopy(cfg)
    test1 = copy.deepcopy(test)
    test1.pipeline = cfg1.data.test.pipeline
    cfg1.data.test = test1
    cfg1 = disable_text_recog_aug_test(cfg1, set_types=['test'])
    assert cfg1.data.test.pipeline[1].type != 'MultiRotateAugOCR'

    # cfg.data.test.type is 'UniformConcatDataset'
    # and cfg.data.test.pipeline is list[dict]
    cfg2 = copy.deepcopy(cfg)
    test2 = copy.deepcopy(test)
    test2.pipeline = cfg2.data.test.pipeline
    cfg2.data.test.datasets = [test2]
    cfg2 = disable_text_recog_aug_test(cfg2, set_types=['test'])
    assert cfg2.data.test.pipeline[1].type != 'MultiRotateAugOCR'
    assert cfg2.data.test.datasets[0].pipeline[1].type != 'MultiRotateAugOCR'

    # cfg.data.test.type is 'ConcatDataset'
    cfg3 = copy.deepcopy(cfg)
    test3 = copy.deepcopy(test)
    test3.pipeline = cfg3.data.test.pipeline
    cfg3.data.test = Config(dict(type='ConcatDataset', datasets=[test3]))
    cfg3 = disable_text_recog_aug_test(cfg3, set_types=['test'])
    assert cfg3.data.test.datasets[0].pipeline[1].type != 'MultiRotateAugOCR'

    # cfg.data.test.type is 'UniformConcatDataset'
    # and cfg.data.test.pipeline is list[list[dict]]
    cfg4 = copy.deepcopy(cfg)
    test4 = copy.deepcopy(test)
    test4.pipeline = cfg4.data.test.pipeline
    # Deep-copy the second group's dataset and pipeline: never hand the
    # dataset dict owned by the shared source `cfg` to the code under test,
    # and keep each group an independent object so both are exercised.
    cfg4.data.test.datasets = [[test4], [copy.deepcopy(test)]]
    cfg4.data.test.pipeline = [
        cfg4.data.test.pipeline,
        copy.deepcopy(cfg4.data.test.pipeline),
    ]
    cfg4 = disable_text_recog_aug_test(cfg4, set_types=['test'])
    # The per-group uniform pipelines are what this case is about.
    for group_pipeline in cfg4.data.test.pipeline:
        assert group_pipeline[1].type != 'MultiRotateAugOCR'
    assert cfg4.data.test.datasets[0][0].pipeline[1].type != \
        'MultiRotateAugOCR'

    # cfg.data.test.type is 'UniformConcatDataset'
    # and cfg.data.test.pipeline is None
    cfg5 = copy.deepcopy(cfg)
    test5 = copy.deepcopy(test)
    test5.pipeline = copy.deepcopy(cfg5.data.test.pipeline)
    cfg5.data.test.datasets = [test5]
    cfg5.data.test.pipeline = None
    cfg5 = disable_text_recog_aug_test(cfg5, set_types=['test'])
    assert cfg5.data.test.datasets[0].pipeline[1].type != 'MultiRotateAugOCR'
@pytest.mark.parametrize('cfg_file', [
    '../configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py',
])
def test_replace_image_to_tensor(cfg_file):
    """Check that ``replace_image_to_tensor`` installs DefaultFormatBundle.

    The loaded config's test pipeline has ``ImageToTensor`` at
    ``pipeline[1]['transforms'][3]`` (asserted below). Each case rearranges
    ``cfg.data.test`` and asserts that, after ``replace_image_to_tensor``
    runs on the ``'test'`` split, that slot holds ``DefaultFormatBundle``
    in both the uniform pipeline(s) and the per-dataset pipelines.

    Args:
        cfg_file (str): Config path, relative to the parent of the
            directory containing this test file.
    """
    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    config_file = os.path.join(tmp_dir, cfg_file)
    cfg = Config.fromfile(config_file)
    # Precondition: make sure there is actually something to replace, so
    # the assertions below cannot pass for an unrelated reason.
    assert cfg.data.test.pipeline[1]['transforms'][3][
        'type'] == 'ImageToTensor'
    test = cfg.data.test.datasets[0]

    # cfg.data.test.pipeline is list[dict]
    # and cfg.data.test.datasets is list[dict]
    cfg1 = copy.deepcopy(cfg)
    test1 = copy.deepcopy(test)
    test1.pipeline = copy.deepcopy(cfg.data.test.pipeline)
    cfg1.data.test.datasets = [test1]
    cfg1 = replace_image_to_tensor(cfg1, set_types=['test'])
    assert cfg1.data.test.pipeline[1]['transforms'][3][
        'type'] == 'DefaultFormatBundle'
    assert cfg1.data.test.datasets[0].pipeline[1]['transforms'][3][
        'type'] == 'DefaultFormatBundle'

    # cfg.data.test.pipeline is list[list[dict]]
    # and cfg.data.test.datasets is list[list[dict]]
    cfg2 = copy.deepcopy(cfg)
    test2 = copy.deepcopy(test)
    test2.pipeline = copy.deepcopy(cfg.data.test.pipeline)
    # Give each group its own object. Listing the same dataset/pipeline
    # object twice means the second group is never handled independently
    # of the first.
    cfg2.data.test.datasets = [[test2], [copy.deepcopy(test2)]]
    cfg2.data.test.pipeline = [
        cfg2.data.test.pipeline,
        copy.deepcopy(cfg2.data.test.pipeline),
    ]
    cfg2 = replace_image_to_tensor(cfg2, set_types=['test'])
    for group_pipeline in cfg2.data.test.pipeline:
        assert group_pipeline[1]['transforms'][3][
            'type'] == 'DefaultFormatBundle'
    for group in cfg2.data.test.datasets:
        assert group[0].pipeline[1]['transforms'][3][
            'type'] == 'DefaultFormatBundle'
|