# Gradio demo for MMOCR text detection and recognition.
# Pin the Gradio version at runtime before it is imported below.
import os
os.system('pip install gradio==3.9.1')
import torch

print(torch.__version__)
# Earlier dependency-bootstrapping commands, kept for reference:
# torch_ver, cuda_ver = torch.__version__.split('+')
# os.system('pip list')
# os.system(f'pip install opencv-contrib-python==4.5.5.62 --no-cache-dir')
# os.system('pip list')
# os.system(f'pip install pycocotools==2.0.0 mmdet mmcv-full==1.5.0 -f https://download.openmmlab.com/mmcv/dist/{cuda_ver}/torch1.10.0/index.html --no-cache-dir')
# os.system('wget -nv -c https://download.openmmlab.com/mmocr/data/wildreceipt.tar; mkdir -p data; tar -xf wildreceipt.tar --directory data; rm -f wildreceipt.tar')

import datetime

import gradio as gr
import pandas as pd
from mmocr.utils.ocr import MMOCR

def inference(img, det, recog):
    print(datetime.datetime.now(), 'start')
    # ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')
    # ocr = MMOCR(det=det, recog=recog, kie='SDMGR')
    # The dropdowns pass the literal string 'None' to disable a stage;
    # map it to Python None so MMOCR skips that model.
    if det == 'None':
        det = None
    if recog == 'None':
        recog = None
    ocr = MMOCR(det=det, recog=recog)
    print(datetime.datetime.now(), 'start read:', img.name)
    # Run OCR and write the visualisation to /tmp; MMOCR names the output
    # image 'out_<input basename>.png'.
    results = ocr.readtext(img.name, details=True, output='/tmp')
    result_file = '/tmp/out_{}.png'.format(os.path.splitext(os.path.basename(img.name))[0])
    print(datetime.datetime.now(), results)
    # return result_file, pd.DataFrame(results[0]['result']).iloc[: , 2:]
    return result_file, results

description = 'Gradio demo for MMOCR. MMOCR is an open-source toolbox based on PyTorch and mmdetection for text detection, text recognition, and the corresponding downstream tasks including key information extraction. To use it, simply upload your image or click one of the examples to load it. Read more at the links below.'
article = "<p style='text-align: center'><a href='https://mmocr.readthedocs.io/en/latest/'>MMOCR is an open-source toolbox based on PyTorch and mmdetection for text detection, text recognition, and the corresponding downstream tasks including key information extraction.</a> | <a href='https://github.com/open-mmlab/mmocr'>Github Repo</a></p>"


# Build the example gallery from every image in ./images, using the
# default PS_CTW detector and SAR recogniser.
examples = []
path = './images'

files = os.listdir(path)
files.sort()
for f in files:
    file = os.path.join(path, f)
    if os.path.isfile(file):
        examples.append([file, 'PS_CTW', 'SAR'])

# Text detection models exposed in the UI ('None' disables detection).
det = gr.inputs.Dropdown(choices=[
            'DB_r18',
            'DB_r50',
            'DBPP_r50',
            'DRRG',
            'FCE_IC15',
            'FCE_CTW_DCNv2',
            'MaskRCNN_CTW',
            'MaskRCNN_IC15',
            'MaskRCNN_IC17',
            'PANet_CTW',
            'PANet_IC15',
            'PS_CTW',
            'PS_IC15',
            'TextSnake',
            'None'
], type="value", default='PS_CTW', label='det')

# Text recognition models exposed in the UI ('None' disables recognition).
recog = gr.inputs.Dropdown(choices=[
            'CRNN',
            'SAR',
            'SAR_CN',
            'NRTR_1/16-1/8',
            'NRTR_1/8-1/4',
            'RobustScanner',
            'SATRN',
            'SATRN_sm',
            'ABINet',
            'ABINet_Vision',
            'SEG',
            'CRNN_TPS',
            'MASTER',
            'None'
], type="value", default='SAR', label='recog')

gr.Interface(inference,
    [gr.inputs.Image(type='file', label='Input'), det, recog],
    # [gr.outputs.Image(type='pil', label='Output'), gr.outputs.Dataframe(headers=['text', 'text_score', 'label', 'label_score'])],
    [gr.outputs.Image(type='pil', label='Output'), gr.outputs.Textbox(type='str', label='Prediction')],
    title='MMOCR',
    description=description,
    article=article,
    examples=examples,
    css=".output_image, .input_image {height: 40rem !important; width: 100% !important;}",
    enable_queue=True
    ).launch(debug=True)
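
# A minimal sketch of running the same pipeline without the Gradio UI,
# assuming an MMOCR 0.x environment (mmcv-full / mmdet installed as in the
# commented pip command above) and a hypothetical local image path:
#
#   from mmocr.utils.ocr import MMOCR
#
#   ocr = MMOCR(det='PS_CTW', recog='SAR')      # same defaults as the demo
#   results = ocr.readtext('images/demo.jpg',   # hypothetical input image
#                          details=True, output='/tmp')
#   print(results)                              # visualisation saved to /tmp/out_demo.png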