File size: 2,516 Bytes
24c4def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np
from mmengine.structures import LabelData

from mmocr.structures import TextRecogDataSample
from mmocr.visualization import TextRecogLocalVisualizer


class TestTextDetLocalVisualizer(unittest.TestCase):
    """Tests for ``TextRecogLocalVisualizer.add_datasample``.

    NOTE(review): despite the ``TextDet`` in its name, this class exercises
    the text *recognition* visualizer. The name is kept unchanged so any
    existing test selectors that reference it keep working.
    """

    def test_add_datasample(self):
        h, w = 64, 128
        image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')

        # Only gt_text is set; smoke-test drawing without writing a file.
        data_sample = TextRecogDataSample()
        img_meta = dict(img_shape=(12, 10, 3))
        gt_text = LabelData(metainfo=img_meta)
        gt_text.item = 'mmocr'
        data_sample.gt_text = gt_text

        recog_local_visualizer = TextRecogLocalVisualizer()
        recog_local_visualizer.add_datasample('image', image, data_sample)

        # Add pred_text so both the gt and pred panels can be drawn.
        pred_text = LabelData(metainfo=img_meta)
        pred_text.item = 'MMOCR'
        data_sample.pred_text = pred_text

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Each case writes to its own file. Reusing one path would let
            # the existence check in _assert_image_and_shape pass on a file
            # left behind by an earlier call.

            # draw_gt=True, draw_pred=False: expected height is 2*h
            # (the image plus one text panel stacked vertically).
            out_file = osp.join(tmp_dir, 'gt_only.jpg')
            recog_local_visualizer.add_datasample(
                'image',
                image,
                data_sample,
                out_file=out_file,
                draw_gt=True,
                draw_pred=False)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

            # draw_gt=True, draw_pred=True: expected height is 3*h
            # (the image plus gt and pred panels).
            out_file = osp.join(tmp_dir, 'gt_and_pred.jpg')
            recog_local_visualizer.add_datasample(
                'image',
                image,
                data_sample,
                out_file=out_file,
                draw_gt=True,
                draw_pred=True)
            self._assert_image_and_shape(out_file, (h * 3, w, 3))

            # draw_gt=False with draw_pred left at its default: the test
            # expects the pred panel to be drawn, giving height 2*h.
            out_file = osp.join(tmp_dir, 'pred_only.jpg')
            recog_local_visualizer.add_datasample(
                'image', image, data_sample, draw_gt=False, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

            # Single-channel (gray) input must be accepted as well.
            gray_image = np.random.randint(0, 256, size=(h, w)).astype('uint8')
            out_file = osp.join(tmp_dir, 'gray_pred_only.jpg')
            recog_local_visualizer.add_datasample(
                'image',
                gray_image,
                data_sample,
                draw_gt=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

    def _assert_image_and_shape(self, out_file, out_shape):
        """Assert ``out_file`` exists, decodes as an image, and has ``out_shape``.

        Args:
            out_file (str): Path of the image written by the visualizer.
            out_shape (tuple): Expected ``(height, width, channels)``. By
                default ``cv2.imread`` (``IMREAD_COLOR``) always loads a
                3-channel image, so channels is 3 even for gray input.
        """
        self.assertTrue(osp.exists(out_file), f'{out_file} was not written')
        drawn_img = cv2.imread(out_file)
        # cv2.imread signals unreadable files by returning None, not raising.
        self.assertIsNotNone(drawn_img, f'{out_file} could not be decoded')
        self.assertEqual(drawn_img.shape, out_shape)