File size: 5,677 Bytes
174ad5e
bf6e441
100bb76
bf6e441
c0ba2f4
bf6e441
53e582c
d2a6d56
83d07db
585f861
 
 
 
 
 
 
 
 
174ad5e
 
 
 
 
 
 
6838280
174ad5e
 
 
 
 
79d74f9
174ad5e
 
 
 
 
6838280
174ad5e
 
 
 
79d74f9
174ad5e
 
 
 
96d7f77
174ad5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d93f21
 
 
 
174ad5e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

# Install dependencies at start-up. This must run before the third-party
# imports that follow (cv2, gradio, mmocr), which rely on these packages.
import os
import subprocess
import sys

# Each command runs through the interpreter executing this script
# (sys.executable) so packages land in the same environment, and is passed
# as an argument list, so no shell quoting is involved.
_INSTALL_COMMANDS = [
    [sys.executable, '-m', 'mim', 'install', 'mmocr'],
    [sys.executable, '-m', 'pip', 'install', 'gradio_client==0.2.7'],
    [sys.executable, '-m', 'mim', 'install', 'mmcv==2.0.0rc4'],
    [sys.executable, '-m', 'mim', 'install', 'mmengine==0.7.1'],
    [sys.executable, '-m', 'mim', 'install', 'mmdet==3.0.0rc5'],
    [sys.executable, '-m', 'pip', 'install', '-v', '-e', '.'],
]
for _cmd in _INSTALL_COMMANDS:
    # Best-effort, like the original os.system calls: a failing install is
    # reported on stderr but does not abort start-up.
    _returncode = subprocess.run(_cmd, check=False).returncode
    if _returncode != 0:
        _cmd_str = ' '.join(_cmd)
        print(f'WARNING: "{_cmd_str}" exited with status {_returncode}',
              file=sys.stderr)
import cv2
import argparse
import gradio as gr
import numpy as np

# MMOCR
from mmocr.apis.inferencers import MMOCRInferencer




def arg_parse(argv=None):
    """Parse the command-line options of the MMOCR gradio demo.

    Args:
        argv (list[str], optional): Arguments to parse. Defaults to None,
            in which case argparse reads ``sys.argv[1:]`` (the behaviour
            when launched as a script).

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``rec_config``, ``rec_weight``, ``det_config``, ``det_weight`` and
        ``device``.
    """
    parser = argparse.ArgumentParser(description='MMOCR demo for gradio app')
    parser.add_argument(
        '--rec_config',
        type=str,
        default='configs/textrecog/maerec/maerec_b_union14m.py',
        help='The recognition config file.')
    parser.add_argument(
        '--rec_weight',
        type=str,
        default='maerec_b_union14m.pth',
        help='The recognition weight file.')
    parser.add_argument(
        '--det_config',
        type=str,
        default=
        'configs/textdet/dbnetpp/dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015.py',  # noqa: E501
        help='The detection config file.')
    parser.add_argument(
        '--det_weight',
        type=str,
        default='dbnetpp.pth',
        help='The detection weight file.')
    parser.add_argument(
        '--device',
        type=str,
        default='cpu',
        help='The device used for inference.')
    return parser.parse_args(argv)


def run_mmocr(img: np.ndarray, use_detector: bool = True):
    """Run MMOCR text recognition (optionally with detection) on one image.

    Uses the module-level ``mmocr_inferencer`` created in the ``__main__``
    section and switches its ``mode`` attribute on every call.

    Args:
        img (np.ndarray): Input image from the ``gr.Image`` component. May
            be None when the button is clicked before an image is uploaded.
        use_detector (bool, optional): Whether to use detector. If True the
            inferencer runs in 'det_rec' mode, otherwise in 'rec' mode.
            Defaults to True.

    Returns:
        tuple: ``(visualization, out_results)``.

        - In 'det_rec' mode, ``visualization`` is the rendered result image
          with its colour channels swapped. ``out_results`` has one
          ``"<text>: <polygon>"`` line per detected instance.
        - In 'rec' mode, ``visualization`` is None and ``out_results`` holds
          the first predicted text and its score.
        - If ``img`` is None, returns ``(None, message)`` without running
          inference.
    """
    if img is None:
        # Report a missing upload in the results box instead of letting the
        # inferencer fail on a None input.
        return None, 'Please upload an image first.'

    mode = 'det_rec' if use_detector else 'rec'
    # Switch the shared inferencer between end-to-end and recognition-only.
    mmocr_inferencer.mode = mode
    result = mmocr_inferencer(img, return_vis=True)
    visualization = result['visualization'][0]
    prediction = result['predictions'][0]

    if mode == 'det_rec':
        det_results = []
        for rec_text, det_polygon in zip(prediction['rec_texts'],
                                         prediction['det_polygons']):
            # Truncate polygon coordinates to int32 for a compact listing.
            det_polygon = np.array(det_polygon).astype(np.int32).tolist()
            det_results.append(f'{rec_text}: {det_polygon}')
        out_results = '\n'.join(det_results)
        # NOTE(review): this channel swap presumably undoes a BGR/RGB flip
        # MMOCR applies to ndarray inputs when visualizing — confirm.
        visualization = cv2.cvtColor(
            np.array(visualization), cv2.COLOR_RGB2BGR)
    else:
        rec_text = prediction['rec_texts'][0]
        rec_score = prediction['rec_scores'][0]
        out_results = f'pred: {rec_text} \n score: {rec_score:.2f}'
        visualization = None
    return visualization, out_results


if __name__ == '__main__':
    args = arg_parse()
    # Bound at module level on purpose: run_mmocr() reads this global.
    mmocr_inferencer = MMOCRInferencer(
        args.det_config,
        args.det_weight,
        args.rec_config,
        args.rec_weight,
        device=args.device)

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: title, description and model figure.
            with gr.Column(scale=1):
                gr.HTML("""
                    <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
                    <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
                        MAERec: A MAE-pretrained Scene Text Recognizer
                    </h1>
                    <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem"> 
                    [<a href="https://github.com/Mountchicken/Union14M" style="color:green;">Code</a>]
                    </h3>
                    <h2 style="text-align: left; font-weight: 600; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
                    MAERec is a scene text recognition model composed of a ViT backbone and a Transformer decoder in auto-regressive
                    style. It shows an outstanding performance in scene text recognition, especially when pre-trained on the
                    Union14M-U through MAE.
                    </h2>
                    <h2 style="text-align: left; font-weight: 600; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
                    In this demo, we combine MAERec with DBNet++ to build an
                    end-to-end scene text recognition model.
                    </h2>
                    </div>
                    """)
                gr.Image('github/maerec.png')
            # Right column: input/output widgets and the run button.
            with gr.Column(scale=1):
                input_image = gr.Image(label='Input Image')
                output_image = gr.Image(label='Output Image')
                # Gradio 3.x sets the initial state through ``value``. The
                # Gradio 2 ``default`` keyword is not applied, which left the
                # detector checkbox unticked.
                use_detector = gr.Checkbox(
                    label=
                    'Use Scene Text Detector or Not (Disabled for Recognition Only)',
                    value=True)
                det_results = gr.Textbox(label='Detection Results')
                mmocr = gr.Button('Run MMOCR')
                gr.Markdown("## Image Examples")
        with gr.Row():
            gr.Examples(
                examples=[
                    'github/author.jpg', 'github/gradio1.jpeg',
                    'github/add1.jpg', 'github/add2.jpg', 'github/add3.jpg',
                    'github/Art_Curve_178.jpg', 'github/add4.jpg',
                    'github/add5.jpg', 'github/add6.jpg', 'github/add7.jpg',
                    'github/add8.jpg', 'github/add9.jpg', 'github/add10.jpg',
                    'github/add11.jpg', 'github/add12.jpg',
                    'github/cute_168.jpg', 'github/hiercurve_2229.jpg',
                    'github/ic15_52.jpg', 'github/ic15_698.jpg',
                    'github/Art_Curve_352.jpg',
                ],
                inputs=input_image,
            )
        # run_mmocr returns (visualization, text) -> (output_image, det_results).
        mmocr.click(
            fn=run_mmocr,
            inputs=[input_image, use_detector],
            outputs=[output_image, det_results])
    demo.launch(debug=True)