breezedeus commited on
Commit
0464f00
1 Parent(s): 4f2cc1d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ import os
21
+ from collections import OrderedDict
22
+
23
+ import cv2
24
+ import numpy as np
25
+ from PIL import Image
26
+ import streamlit as st
27
+ from cnstd.utils import pil_to_numpy, imsave
28
+
29
+ from cnocr import CnOcr, DET_AVAILABLE_MODELS, REC_AVAILABLE_MODELS
30
+ from cnocr.utils import set_logger, draw_ocr_results, download
31
+
32
+
33
+ logger = set_logger()
34
+ st.set_page_config(layout="wide")
35
+
36
+
37
+ def plot_for_debugging(rotated_img, one_out, box_score_thresh, crop_ncols, prefix_fp):
38
+ import matplotlib.pyplot as plt
39
+ import math
40
+
41
+ rotated_img = rotated_img.copy()
42
+ crops = [info['cropped_img'] for info in one_out]
43
+ print('%d boxes are found' % len(crops))
44
+ ncols = crop_ncols
45
+ nrows = math.ceil(len(crops) / ncols)
46
+ fig, ax = plt.subplots(nrows=nrows, ncols=ncols)
47
+ for i, axi in enumerate(ax.flat):
48
+ if i >= len(crops):
49
+ break
50
+ axi.imshow(crops[i])
51
+ crop_fp = '%s-crops.png' % prefix_fp
52
+ plt.savefig(crop_fp)
53
+ print('cropped results are save to file %s' % crop_fp)
54
+
55
+ for info in one_out:
56
+ box, score = info.get('position'), info['score']
57
+ if score < box_score_thresh: # score < 0.5
58
+ continue
59
+ if box is not None:
60
+ box = box.astype(int).reshape(-1, 2)
61
+ cv2.polylines(rotated_img, [box], True, color=(255, 0, 0), thickness=2)
62
+ result_fp = '%s-result.png' % prefix_fp
63
+ imsave(rotated_img, result_fp, normalized=False)
64
+ print('boxes results are save to file %s' % result_fp)
65
+
66
+
67
+ @st.cache(allow_output_mutation=True)
68
+ def get_ocr_model(det_model_name, rec_model_name, det_more_configs):
69
+ det_model_name, det_model_backend = det_model_name
70
+ rec_model_name, rec_model_backend = rec_model_name
71
+ return CnOcr(
72
+ det_model_name=det_model_name,
73
+ det_model_backend=det_model_backend,
74
+ rec_model_name=rec_model_name,
75
+ rec_model_backend=rec_model_backend,
76
+ det_more_configs=det_more_configs,
77
+ )
78
+
79
+
80
+ def visualize_naive_result(img, det_model_name, std_out, box_score_thresh):
81
+ img = pil_to_numpy(img).transpose((1, 2, 0)).astype(np.uint8)
82
+
83
+ plot_for_debugging(img, std_out, box_score_thresh, 2, './streamlit-app')
84
+ st.subheader('Detection Result')
85
+ if det_model_name == 'default_det':
86
+ st.warning('⚠️ Warning: "default_det" 检测模型不返回文本框位置!')
87
+ cols = st.columns([1, 7, 1])
88
+ cols[1].image('./streamlit-app-result.png')
89
+
90
+ st.subheader('Recognition Result')
91
+ cols = st.columns([1, 7, 1])
92
+ cols[1].image('./streamlit-app-crops.png')
93
+
94
+ _visualize_ocr(std_out)
95
+
96
+
97
+ def _visualize_ocr(ocr_outs):
98
+ st.empty()
99
+ ocr_res = OrderedDict({'文本': []})
100
+ ocr_res['得分'] = []
101
+ for out in ocr_outs:
102
+ # cropped_img = out['cropped_img'] # 检测出的文本框
103
+ ocr_res['得分'].append(out['score'])
104
+ ocr_res['文本'].append(out['text'])
105
+ st.table(ocr_res)
106
+
107
+
108
+ def visualize_result(img, ocr_outs):
109
+ out_draw_fp = './streamlit-app-det-result.png'
110
+ font_path = 'docs/fonts/simfang.ttf'
111
+ if not os.path.exists(font_path):
112
+ url = 'https://huggingface.co/datasets/breezedeus/cnocr-wx-qr-code/resolve/main/fonts/simfang.ttf'
113
+ os.makedirs(os.path.dirname(font_path), exist_ok=True)
114
+ download(url, path=font_path, overwrite=True)
115
+ draw_ocr_results(img, ocr_outs, out_draw_fp, font_path)
116
+ st.image(out_draw_fp)
117
+
118
+
119
+ def main():
120
+ st.sidebar.header('模型设置')
121
+ det_models = list(DET_AVAILABLE_MODELS.all_models())
122
+ det_models.append(('naive_det', 'onnx'))
123
+ det_models.sort()
124
+ det_model_name = st.sidebar.selectbox(
125
+ '选择检测模型', det_models, index=det_models.index(('ch_PP-OCRv3_det', 'onnx'))
126
+ )
127
+
128
+ all_models = list(REC_AVAILABLE_MODELS.all_models())
129
+ all_models.sort()
130
+ idx = all_models.index(('densenet_lite_136-fc', 'onnx'))
131
+ rec_model_name = st.sidebar.selectbox('选择识别模型', all_models, index=idx)
132
+
133
+ st.sidebar.subheader('检测参数')
134
+ rotated_bbox = st.sidebar.checkbox('是否检测带角度文本框', value=True)
135
+ use_angle_clf = st.sidebar.checkbox('是否使用角度预测模型校正文��框', value=False)
136
+ new_size = st.sidebar.slider(
137
+ 'resize 后图片(长边)大小', min_value=124, max_value=4096, value=768
138
+ )
139
+ box_score_thresh = st.sidebar.slider(
140
+ '得分阈值(低于阈值的结果会被过滤掉)', min_value=0.05, max_value=0.95, value=0.3
141
+ )
142
+ min_box_size = st.sidebar.slider(
143
+ '框大小阈值(更小的文本框会被过滤掉)', min_value=4, max_value=50, value=10
144
+ )
145
+ # std = get_std_model(det_model_name, rotated_bbox, use_angle_clf)
146
+
147
+ # st.sidebar.markdown("""---""")
148
+ # st.sidebar.header('CnOcr 设置')
149
+ det_more_configs = dict(rotated_bbox=rotated_bbox, use_angle_clf=use_angle_clf)
150
+ ocr = get_ocr_model(det_model_name, rec_model_name, det_more_configs)
151
+
152
+ st.markdown('# 开源Python OCR工具 ' '[CnOCR](https://github.com/breezedeus/cnocr)')
153
+ st.markdown('> 详细说明参见:[CnOCR 文档](https://cnocr.readthedocs.io/) ;'
154
+ '欢迎加入 [交流群](https://cnocr.readthedocs.io/zh/latest/contact/) ;'
155
+ '作者:[breezedeus](https://github.com/breezedeus) 。')
156
+ st.markdown('')
157
+ st.subheader('选择待检测图片')
158
+ content_file = st.file_uploader('', type=["png", "jpg", "jpeg", "webp"])
159
+ if content_file is None:
160
+ st.stop()
161
+
162
+ try:
163
+ img = Image.open(content_file).convert('RGB')
164
+
165
+ ocr_out = ocr.ocr(
166
+ img,
167
+ return_cropped_image=True,
168
+ resized_shape=new_size,
169
+ preserve_aspect_ratio=True,
170
+ box_score_thresh=box_score_thresh,
171
+ min_box_size=min_box_size,
172
+ )
173
+ if det_model_name[0] == 'naive_det':
174
+ visualize_naive_result(img, det_model_name[0], ocr_out, box_score_thresh)
175
+ else:
176
+ visualize_result(img, ocr_out)
177
+
178
+ except Exception as e:
179
+ st.error(e)
180
+
181
+
182
+ if __name__ == '__main__':
183
+ main()