| | |
| | import os.path as osp |
| | import tempfile |
| | from functools import partial |
| |
|
| | import numpy as np |
| | import pytest |
| | import torch |
| | from mmdet.core import BitmapMasks |
| |
|
| | from mmocr.models.textrecog.recognizer import (EncodeDecodeRecognizer, |
| | SegRecognizer) |
| |
|
| |
|
| | def _create_dummy_dict_file(dict_file): |
| | chars = list('helowrd') |
| | with open(dict_file, 'w') as fw: |
| | for char in chars: |
| | fw.write(char + '\n') |
| |
|
| |
|
def test_base_recognizer():
    """Smoke-test ``EncodeDecodeRecognizer`` built from a CRNN-style config.

    Covers constructor argument validation, feature extraction, the training
    forward pass, ``simple_test``, ONNX export and ``aug_test``.
    """
    # The context manager removes the temporary directory even when an
    # assertion below fails; an explicit ``cleanup()`` at the end of the test
    # would be skipped in that case.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        label_convertor = dict(
            type='CTCConvertor', dict_file=dict_file, with_unknown=False)

        preprocessor = None
        backbone = dict(type='VeryDeepVgg', leaky_relu=False)
        encoder = None
        decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
        loss = dict(type='CTCLoss')

        # Missing required components must be rejected at construction.
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(backbone=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(decoder=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(loss=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(label_convertor=None)

        recognizer = EncodeDecodeRecognizer(
            preprocessor=preprocessor,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            loss=loss,
            label_convertor=label_convertor)

        recognizer.init_weights()
        recognizer.train()

        imgs = torch.rand(1, 3, 32, 160)

        # Feature extraction: a 32x160 input yields a 512-channel 1x41 map.
        feat = recognizer.extract_feat(imgs)
        assert feat.shape == torch.Size([1, 512, 1, 41])

        # Training forward pass returns a dict of losses including CTC.
        img_metas = [{
            'text': 'hello',
            'resize_shape': (32, 120, 3),
            'valid_ratio': 1.0
        }]
        losses = recognizer.forward_train(imgs, img_metas)
        assert isinstance(losses, dict)
        assert 'loss_ctc' in losses

        # Single-image inference returns one result dict per image.
        results = recognizer.simple_test(imgs, img_metas)
        assert isinstance(results, list)
        assert isinstance(results[0], dict)
        assert 'text' in results[0]
        assert 'score' in results[0]

        # ONNX export: bind the non-tensor arguments so that the exporter
        # only has to trace the image tensor through ``forward``.
        recognizer.forward = partial(
            recognizer.simple_test,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)
        onnx_path = osp.join(tmp_dir, 'tmp.onnx')
        torch.onnx.export(
            recognizer, (imgs, ),
            onnx_path,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False)

        # Test-time augmentation over two copies of the same input.
        aug_results = recognizer.aug_test([imgs, imgs],
                                          [img_metas, img_metas])
        assert isinstance(aug_results, list)
        assert isinstance(aug_results[0], dict)
        assert 'text' in aug_results[0]
        assert 'score' in aug_results[0]
| |
|
| |
|
def test_seg_recognizer():
    """Smoke-test ``SegRecognizer`` built from a ResNet31 + FPN config.

    Covers constructor argument validation, multi-level feature extraction,
    the training forward pass with segmentation targets, ``simple_test`` and
    ``aug_test``.
    """
    # The context manager removes the temporary directory even when an
    # assertion below fails; an explicit ``cleanup()`` at the end of the test
    # would be skipped in that case.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        label_convertor = dict(
            type='SegConvertor', dict_file=dict_file, with_unknown=False)

        preprocessor = None
        backbone = dict(
            type='ResNet31OCR',
            layers=[1, 2, 5, 3],
            channels=[32, 64, 128, 256, 512, 512],
            out_indices=[0, 1, 2, 3],
            stage4_pool_cfg=dict(kernel_size=2, stride=2),
            last_stage_pool=True)
        neck = dict(
            type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256)
        head = dict(
            type='SegHead',
            in_channels=256,
            upsample_param=dict(scale_factor=2.0, mode='nearest'))
        loss = dict(type='SegLoss', seg_downsample_ratio=1.0)

        # Missing required components must be rejected at construction.
        with pytest.raises(AssertionError):
            SegRecognizer(backbone=None)
        with pytest.raises(AssertionError):
            SegRecognizer(neck=None)
        with pytest.raises(AssertionError):
            SegRecognizer(head=None)
        with pytest.raises(AssertionError):
            SegRecognizer(loss=None)
        with pytest.raises(AssertionError):
            SegRecognizer(label_convertor=None)

        recognizer = SegRecognizer(
            preprocessor=preprocessor,
            backbone=backbone,
            neck=neck,
            head=head,
            loss=loss,
            label_convertor=label_convertor)

        recognizer.init_weights()
        recognizer.train()

        imgs = torch.rand(1, 3, 64, 256)

        # Feature extraction yields four levels, each halving the spatial
        # size of the previous one for a 64x256 input.
        feats = recognizer.extract_feat(imgs)
        assert len(feats) == 4

        assert feats[0].shape == torch.Size([1, 128, 32, 128])
        assert feats[1].shape == torch.Size([1, 256, 16, 64])
        assert feats[2].shape == torch.Size([1, 512, 8, 32])
        assert feats[3].shape == torch.Size([1, 512, 4, 16])

        # All-zero targets at full input resolution (64x256), packed as the
        # three masks passed via ``gt_kernels``.
        attn_tgt = np.zeros((64, 256), dtype=np.float32)
        segm_tgt = np.zeros((64, 256), dtype=np.float32)
        mask = np.zeros((64, 256), dtype=np.float32)
        gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], 64, 256)

        # Training forward pass returns a dict of losses.
        img_metas = [{
            'text': 'hello',
            'resize_shape': (64, 256, 3),
            'valid_ratio': 1.0
        }]
        losses = recognizer.forward_train(
            imgs, img_metas, gt_kernels=[gt_kernels])
        assert isinstance(losses, dict)

        # Single-image inference returns one result dict per image.
        results = recognizer.simple_test(imgs, img_metas)
        assert isinstance(results, list)
        assert isinstance(results[0], dict)
        assert 'text' in results[0]
        assert 'score' in results[0]

        # Test-time augmentation over two copies of the same input.
        aug_results = recognizer.aug_test([imgs, imgs],
                                          [img_metas, img_metas])
        assert isinstance(aug_results, list)
        assert isinstance(aug_results[0], dict)
        assert 'text' in aug_results[0]
        assert 'score' in aug_results[0]
| |
|