| | |
| | import io |
| | import json |
| | import os |
| | import platform |
| | import random |
| | import sys |
| | import tempfile |
| | from pathlib import Path |
| | from unittest import mock |
| |
|
| | import mmcv |
| | import numpy as np |
| | import pytest |
| | import torch |
| |
|
| | from mmocr.apis import init_detector |
| | from mmocr.datasets.kie_dataset import KIEDataset |
| | from mmocr.utils.ocr import MMOCR |
| |
|
| |
|
def test_ocr_init_errors():
    """MMOCR must reject unknown model names and unsupported combinations."""
    # An unrecognised model name for any of the three stages is a ValueError.
    for bad_name_kwargs in ({'det': 'test'}, {'recog': 'test'},
                            {'kie': 'test'}):
        with pytest.raises(ValueError):
            MMOCR(**bad_name_kwargs)

    # Requesting the SDMGR KIE model while recog is None is expected to be
    # NotImplementedError, whether or not a detector is configured.
    unsupported_combinations = (
        {'det': None, 'recog': None, 'kie': 'SDMGR'},
        {'det': 'DB_r18', 'recog': None, 'kie': 'SDMGR'},
    )
    for combo_kwargs in unsupported_combinations:
        with pytest.raises(NotImplementedError):
            MMOCR(**combo_kwargs)
| |
|
| |
|
# Absolute "<cwd>/configs/" prefix.  The parametrized cases below expect
# MMOCR to resolve model configs against this prefix when no ``config_dir``
# is passed, so the trailing "/" is appended exactly as a literal.
cfg_default_prefix = os.path.join(str(Path.cwd()), 'configs') + '/'
| |
|
| |
|
@pytest.mark.parametrize(
    'det, recog, kie, config_dir, gt_cfg, gt_ckpt',
    [('DB_r18', None, '', '',
      cfg_default_prefix + 'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
      'https://download.openmmlab.com/mmocr/textdet/'
      'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'),
     (None, 'CRNN', '', '',
      cfg_default_prefix + 'textrecog/crnn/crnn_academic_dataset.py',
      'https://download.openmmlab.com/mmocr/textrecog/'
      'crnn/crnn_academic-a723a1c5.pth'),
     ('DB_r18', 'CRNN', 'SDMGR', '', [
         cfg_default_prefix +
         'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
         cfg_default_prefix + 'textrecog/crnn/crnn_academic_dataset.py',
         cfg_default_prefix + 'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
     ], [
         'https://download.openmmlab.com/mmocr/textdet/'
         'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
         'https://download.openmmlab.com/mmocr/textrecog/'
         'crnn/crnn_academic-a723a1c5.pth',
         'https://download.openmmlab.com/mmocr/kie/'
         'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
     ]),
     ('DB_r18', 'CRNN', 'SDMGR', 'test/', [
         'test/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
         'test/textrecog/crnn/crnn_academic_dataset.py',
         'test/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
     ], [
         'https://download.openmmlab.com/mmocr/textdet/'
         'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
         'https://download.openmmlab.com/mmocr/textrecog/'
         'crnn/crnn_academic-a723a1c5.pth',
         'https://download.openmmlab.com/mmocr/kie/'
         'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
     ])],
)
@mock.patch('mmocr.utils.ocr.init_detector')
@mock.patch('mmocr.utils.ocr.build_detector')
@mock.patch('mmocr.utils.ocr.Config.fromfile')
@mock.patch('mmocr.utils.ocr.load_checkpoint')
def test_ocr_init(mock_loading, mock_config, mock_build_detector,
                  mock_init_detector, det, recog, kie, config_dir, gt_cfg,
                  gt_ckpt):
    """Check which configs/checkpoints MMOCR resolves for built-in models.

    ``gt_cfg``/``gt_ckpt`` hold one path (str) or a list of paths in
    det, recog, kie order.  The detection/recognition entries are expected
    to reach ``init_detector``.  When ``kie`` is set, the last entry is
    expected to go through ``Config.fromfile``, ``build_detector`` and
    ``load_checkpoint`` instead.
    """
    # Normalise expectations *before* constructing MMOCR.  The
    # load_checkpoint side effect below reads gt_ckpt[-1] during
    # construction, and on a bare string that would be a single character
    # rather than the checkpoint URL.
    if isinstance(gt_cfg, str):
        gt_cfg = [gt_cfg]
    if isinstance(gt_ckpt, str):
        gt_ckpt = [gt_ckpt]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def loadcheckpoint_assert(*args, **kwargs):
        # The KIE checkpoint is load_checkpoint's second positional argument.
        assert args[1] == gt_ckpt[-1]
        assert kwargs['map_location'] == device

    mock_loading.side_effect = loadcheckpoint_assert

    # Only forward kie/config_dir when the case sets them, so MMOCR's own
    # defaults are exercised otherwise (e.g. cfg_default_prefix).
    ocr_kwargs = {'det': det, 'recog': recog}
    if kie != '':
        ocr_kwargs['kie'] = kie
    if config_dir != '':
        ocr_kwargs['config_dir'] = config_dir
    with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
        _ = MMOCR(**ocr_kwargs)

    i_range = range(len(gt_cfg))
    if kie:
        # The KIE model is excluded from the init_detector expectations.
        i_range = i_range[:-1]
        mock_config.assert_called_with(gt_cfg[-1])
        mock_build_detector.assert_called_once()
        mock_loading.assert_called_once()
    calls = [
        mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range
    ]
    mock_init_detector.assert_has_calls(calls)
| |
|
| |
|
@pytest.mark.parametrize(
    'det, det_config, det_ckpt, recog, recog_config, recog_ckpt,'
    'kie, kie_config, kie_ckpt, config_dir, gt_cfg, gt_ckpt',
    [
        # Custom configs, model-zoo checkpoints.
        ('DB_r18', 'test.py', '', 'CRNN', 'test.py', '', 'SDMGR', 'test.py',
         '', 'configs/', ['test.py', 'test.py', 'test.py'], [
             'https://download.openmmlab.com/mmocr/textdet/'
             'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
             'https://download.openmmlab.com/mmocr/textrecog/'
             'crnn/crnn_academic-a723a1c5.pth',
             'https://download.openmmlab.com/mmocr/kie/'
             'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
         ]),
        # Built-in configs resolved against config_dir='', custom checkpoints.
        ('DB_r18', '', 'test.ckpt', 'CRNN', '', 'test.ckpt', 'SDMGR', '',
         'test.ckpt', '', [
             'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
             'textrecog/crnn/crnn_academic_dataset.py',
             'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
         ], ['test.ckpt', 'test.ckpt', 'test.ckpt']),
        # Custom configs and custom checkpoints.
        ('DB_r18', 'test.py', 'test.ckpt', 'CRNN', 'test.py', 'test.ckpt',
         'SDMGR', 'test.py', 'test.ckpt', '', ['test.py', 'test.py', 'test.py'],
         ['test.ckpt', 'test.ckpt', 'test.ckpt']),
    ])
@mock.patch('mmocr.utils.ocr.init_detector')
@mock.patch('mmocr.utils.ocr.build_detector')
@mock.patch('mmocr.utils.ocr.Config.fromfile')
@mock.patch('mmocr.utils.ocr.load_checkpoint')
def test_ocr_init_customize_config(mock_loading, mock_config,
                                   mock_build_detector, mock_init_detector,
                                   det, det_config, det_ckpt, recog,
                                   recog_config, recog_ckpt, kie, kie_config,
                                   kie_ckpt, config_dir, gt_cfg, gt_ckpt):
    """Explicit ``*_config`` / ``*_ckpt`` arguments take precedence.

    Expectations are listed in det, recog, kie order.  The det/recog entries
    are expected to reach ``init_detector``.  The KIE entry (last) is
    expected to go through ``Config.fromfile``, ``build_detector`` and
    ``load_checkpoint``.
    """

    def check_kie_checkpoint(*args, **kwargs):
        # Second positional argument of load_checkpoint is the checkpoint.
        assert args[1] == gt_ckpt[-1]

    mock_loading.side_effect = check_kie_checkpoint

    constructor_kwargs = dict(
        det=det,
        det_config=det_config,
        det_ckpt=det_ckpt,
        recog=recog,
        recog_config=recog_config,
        recog_ckpt=recog_ckpt,
        kie=kie,
        kie_config=kie_config,
        kie_ckpt=kie_ckpt,
        config_dir=config_dir,
    )
    with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
        _ = MMOCR(**constructor_kwargs)

    # Every expectation except the KIE one (when present) is an
    # init_detector call.
    detector_count = len(gt_cfg) - 1 if kie else len(gt_cfg)
    if kie:
        mock_config.assert_called_with(gt_cfg[-1])
        mock_build_detector.assert_called_once()
        mock_loading.assert_called_once()
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    expected_calls = [
        mock.call(gt_cfg[idx], gt_ckpt[idx], device=run_device)
        for idx in range(detector_count)
    ]
    mock_init_detector.assert_has_calls(expected_calls)
| |
|
| |
|
@mock.patch('mmocr.utils.ocr.init_detector')
@mock.patch('mmocr.utils.ocr.build_detector')
@mock.patch('mmocr.utils.ocr.Config.fromfile')
@mock.patch('mmocr.utils.ocr.load_checkpoint')
@mock.patch('mmocr.utils.ocr.model_inference')
def test_single_inference(mock_model_inference, mock_loading, mock_config,
                          mock_build_detector, mock_init_detector):
    """``single_inference`` through an identity model returns its input.

    ``model_inference`` is replaced by an identity function, so every
    batching mode must reproduce ``data`` exactly.  The two unchunked batch
    modes (no ``batch_size``, and a ``batch_size`` of 100 for 20 items) are
    also expected to issue exactly one inference call.
    """

    def dummy_inference(model, arr, batch_mode):
        return arr

    mock_model_inference.side_effect = dummy_inference
    mmocr = MMOCR()

    data = list(range(20))
    model = 'dummy'
    # (keyword arguments, expect exactly one model_inference call?)
    scenarios = (
        ({'batch_mode': False}, False),
        ({'batch_mode': True}, True),
        ({'batch_mode': True, 'batch_size': 100}, True),
        ({'batch_mode': True, 'batch_size': 3}, False),
    )
    for call_kwargs, expect_single_call in scenarios:
        output = mmocr.single_inference(model, data, **call_kwargs)
        assert data == output
        if expect_single_call:
            mock_model_inference.assert_called_once()
        # Start each scenario's call count from zero.
        mock_model_inference.reset_mock()
| |
|
| |
|
@mock.patch('mmocr.utils.ocr.init_detector')
@mock.patch('mmocr.utils.ocr.load_checkpoint')
def MMOCR_testobj(mock_loading, mock_init_detector, **kwargs):
    """Build a real ``MMOCR`` instance whose models skip checkpoint loading.

    While the object is constructed:

    * ``mmocr.utils.ocr.init_detector`` is redirected to this module's
      un-patched ``init_detector``, called with the config only (the
      checkpoint argument is dropped).
    * ``mmocr.utils.ocr.load_checkpoint`` is replaced by a stub that only
      points ``model.class_list`` at the toy KIE class list.

    Args:
        mock_loading: Mock injected by ``mock.patch`` for ``load_checkpoint``.
        mock_init_detector: Mock injected by ``mock.patch`` for
            ``init_detector``.
        **kwargs: Forwarded to ``MMOCR``.  ``det``, ``recog`` and ``kie``
            default to ``'DB_r18'``, ``'CRNN'`` and ``'SDMGR'``.  ``device``
            defaults to ``'cuda:0'`` when CUDA is available, else ``'cpu'``.
            Callers may now override ``device``; previously that raised
            ``TypeError`` for a duplicate keyword argument.

    Returns:
        The constructed ``MMOCR`` object.
    """

    def init_detector_skip_ckpt(config, ckpt, device):
        # Ignore ckpt: build the model from its config alone.
        return init_detector(config, device=device)

    def modify_kie_class(model, ckpt, map_location):
        model.class_list = 'tests/data/kie_toy_dataset/class_list.txt'

    mock_init_detector.side_effect = init_detector_skip_ckpt
    mock_loading.side_effect = modify_kie_class
    kwargs.setdefault('det', 'DB_r18')
    kwargs.setdefault('recog', 'CRNN')
    kwargs.setdefault('kie', 'SDMGR')
    kwargs.setdefault('device',
                      'cuda:0' if torch.cuda.is_available() else 'cpu')
    return MMOCR(**kwargs)
| |
|
| |
|
@pytest.mark.skipif(
    platform.system() == 'Windows',
    reason='Win container on Github Action does not have enough RAM to run')
@mock.patch('mmocr.utils.ocr.KIEDataset')
def test_readtext(mock_kiedataset):
    """End-to-end checks of ``MMOCR.readtext`` on the toy images.

    Four pipelines are built with ``MMOCR_testobj``:

    * its defaults;
    * det-only (``kie=''``, ``recog=''``);
    * recog-only (``CRNN_TPS``);
    * det + recog (``kie=''``).

    The test checks that equivalent inputs and batching options give
    equivalent results.  It also checks that the ``export``, ``output``,
    ``imshow``, ``print_result`` and ``merge`` options have the expected
    side effects.
    """
    # Seed the RNGs so the models built below behave reproducibly.
    torch.manual_seed(4)
    random.seed(4)
    mmocr = MMOCR_testobj()
    mmocr_det = MMOCR_testobj(kie='', recog='')
    mmocr_recog = MMOCR_testobj(kie='', det='', recog='CRNN_TPS')
    mmocr_det_recog = MMOCR_testobj(kie='')

    def readtext(imgs, ocr_obj=mmocr, **kwargs):
        # Drop 'filename' so results for file paths and for in-memory arrays
        # can be compared directly.
        e2e_res = ocr_obj.readtext(imgs, **kwargs)
        for res in e2e_res:
            res.pop('filename')
        return e2e_res

    def kiedataset_with_test_dict(**kwargs):
        # Point the (patched) KIE dataset at the toy dictionary file.
        kwargs['dict_file'] = 'tests/data/kie_toy_dataset/dict.txt'
        return KIEDataset(**kwargs)

    def readtext_printed(imgs, **kwargs):
        """Run ``mmocr.readtext(..., print_result=True)``.

        Returns a tuple ``(returned results, printed output parsed as
        JSON)``.
        """
        # mock.patch restores whatever sys.stdout was before (e.g. pytest's
        # capture stream), even if readtext raises.  Reassigning
        # sys.__stdout__ by hand would do neither.
        with mock.patch('sys.stdout', new_callable=io.StringIO) as captured:
            results = mmocr.readtext(imgs, print_result=True, **kwargs)
        # Printed dicts are separated by blank lines and single-quoted;
        # rewrite them into a JSON array before parsing.
        text = captured.getvalue().strip()
        printed = json.loads(
            '[%s]' % text.replace('\n\n', ',').replace("'", '"'))
        return results, printed

    mock_kiedataset.side_effect = kiedataset_with_test_dict

    # Single image: a path and the decoded array give the same result.
    toy_dir = 'tests/data/toy_dataset/imgs/test/'
    toy_img1_path = toy_dir + 'img_1.jpg'
    str_e2e_res = readtext(toy_img1_path)
    toy_img1 = mmcv.imread(toy_img1_path)
    np_e2e_res = readtext(toy_img1)
    assert str_e2e_res == np_e2e_res

    # Multiple images: arrays, a list of paths and a tuple of paths agree.
    toy_img2_path = toy_dir + 'img_2.jpg'
    toy_img2 = mmcv.imread(toy_img2_path)
    toy_imgs = [toy_img1, toy_img2]
    toy_img_paths = [toy_img1_path, toy_img2_path]
    np_e2e_results = readtext(toy_imgs)
    str_e2e_results = readtext(toy_img_paths)
    str_tuple_e2e_results = readtext(tuple(toy_img_paths))
    assert np_e2e_results == str_e2e_results
    assert str_e2e_results == str_tuple_e2e_results

    # Mixed array/path input: batch mode (full and chunked) matches
    # unbatched inference.
    toy_imgs.append(toy_dir + 'img_3.jpg')
    e2e_res = readtext(toy_imgs)
    full_batch_e2e_res = readtext(toy_imgs, batch_mode=True)
    assert full_batch_e2e_res == e2e_res
    batch_e2e_res = readtext(
        toy_imgs, batch_mode=True, recog_batch_size=2, det_batch_size=2)
    assert batch_e2e_res == full_batch_e2e_res

    # Det-only pipeline: batched boundaries match unbatched ones.
    full_batch_det_res = mmocr_det.readtext(toy_imgs, batch_mode=True)
    det_res = mmocr_det.readtext(toy_imgs)
    batch_det_res = mmocr_det.readtext(
        toy_imgs, batch_mode=True, single_batch_size=2)
    assert len(full_batch_det_res) == len(det_res)
    assert len(batch_det_res) == len(det_res)
    assert all(
        np.allclose(batched['boundary_result'], single['boundary_result'])
        for batched, single in zip(full_batch_det_res, det_res))
    assert all(
        np.allclose(batched['boundary_result'], single['boundary_result'])
        for batched, single in zip(batch_det_res, det_res))

    # Recog-only pipeline: compare scores after sorting by text, since
    # result order may differ between modes.  Lengths are checked first so
    # the pairwise comparison covers every result.
    full_batch_recog_res = mmocr_recog.readtext(toy_imgs, batch_mode=True)
    recog_res = mmocr_recog.readtext(toy_imgs)
    batch_recog_res = mmocr_recog.readtext(
        toy_imgs, batch_mode=True, single_batch_size=2)
    for recog_results in (full_batch_recog_res, batch_recog_res, recog_res):
        recog_results.sort(key=lambda x: x['text'])
    assert len(full_batch_recog_res) == len(recog_res)
    assert len(batch_recog_res) == len(recog_res)
    assert all(
        np.allclose(batched['score'], single['score'])
        for batched, single in zip(full_batch_recog_res, recog_res))
    assert all(
        np.allclose(batched['score'], single['score'])
        for batched, single in zip(batch_recog_res, recog_res))

    # export: one exported file per input image, for every pipeline.
    for ocr_obj in (mmocr, mmocr_det, mmocr_recog):
        with tempfile.TemporaryDirectory() as tmpdirname:
            ocr_obj.readtext(toy_imgs, export=tmpdirname)
            assert len(os.listdir(tmpdirname)) == len(toy_imgs)

    # output: a file path for a single image, a directory for several.
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_output = os.path.join(tmpdirname, '1.jpg')
        mmocr.readtext(toy_imgs[0], output=tmp_output)
        assert os.path.exists(tmp_output)
    with tempfile.TemporaryDirectory() as tmpdirname:
        mmocr.readtext(toy_imgs, output=tmpdirname)
        assert len(os.listdir(tmpdirname)) == len(toy_imgs)

    # imshow: one display call per image.
    with mock.patch('mmocr.utils.ocr.mmcv.imshow') as mock_imshow:
        mmocr.readtext(toy_img1_path, imshow=True)
        mock_imshow.assert_called_once()
        mock_imshow.reset_mock()
        mmocr.readtext(toy_imgs, imshow=True)
        assert mock_imshow.call_count == len(toy_imgs)

    # print_result: the printed output round-trips to the returned results.
    res, printed = readtext_printed(toy_imgs)
    assert printed == res
    res, printed = readtext_printed(toy_imgs, details=True)
    assert printed == res

    # merge: box stitching runs once per image.
    with mock.patch('mmocr.utils.ocr.stitch_boxes_into_lines') as mock_merge:
        mmocr_det_recog.readtext(toy_imgs, merge=True)
        assert mock_merge.call_count == len(toy_imgs)
| |
|