DeepLearning101 commited on
Commit
fa6fa48
1 Parent(s): c93900f

Upload 39 files

Browse files
Files changed (39) hide show
  1. tools/__init__.py +14 -0
  2. tools/__pycache__/__init__.cpython-37.pyc +0 -0
  3. tools/__pycache__/__init__.cpython-38.pyc +0 -0
  4. tools/end2end/convert_ppocr_label.py +100 -0
  5. tools/end2end/draw_html.py +73 -0
  6. tools/end2end/eval_end2end.py +193 -0
  7. tools/end2end/readme.md +63 -0
  8. tools/eval.py +137 -0
  9. tools/export_center.py +77 -0
  10. tools/export_model.py +269 -0
  11. tools/infer/__pycache__/predict_cls.cpython-37.pyc +0 -0
  12. tools/infer/__pycache__/predict_cls.cpython-38.pyc +0 -0
  13. tools/infer/__pycache__/predict_det.cpython-37.pyc +0 -0
  14. tools/infer/__pycache__/predict_det.cpython-38.pyc +0 -0
  15. tools/infer/__pycache__/predict_rec.cpython-37.pyc +0 -0
  16. tools/infer/__pycache__/predict_rec.cpython-38.pyc +0 -0
  17. tools/infer/__pycache__/predict_system.cpython-37.pyc +0 -0
  18. tools/infer/__pycache__/predict_system.cpython-38.pyc +0 -0
  19. tools/infer/__pycache__/utility.cpython-37.pyc +0 -0
  20. tools/infer/__pycache__/utility.cpython-38.pyc +0 -0
  21. tools/infer/predict_cls.py +151 -0
  22. tools/infer/predict_det.py +353 -0
  23. tools/infer/predict_e2e.py +169 -0
  24. tools/infer/predict_rec.py +667 -0
  25. tools/infer/predict_sr.py +155 -0
  26. tools/infer/predict_system.py +262 -0
  27. tools/infer/utility.py +663 -0
  28. tools/infer_cls.py +85 -0
  29. tools/infer_det.py +134 -0
  30. tools/infer_e2e.py +174 -0
  31. tools/infer_kie.py +176 -0
  32. tools/infer_kie_token_ser.py +157 -0
  33. tools/infer_kie_token_ser_re.py +225 -0
  34. tools/infer_rec.py +188 -0
  35. tools/infer_sr.py +100 -0
  36. tools/infer_table.py +121 -0
  37. tools/program.py +702 -0
  38. tools/test_hubserving.py +157 -0
  39. tools/train.py +209 -0
tools/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
tools/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (107 Bytes). View file
 
tools/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (144 Bytes). View file
 
tools/end2end/convert_ppocr_label.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import json
16
+ import os
17
+
18
+
19
def poly_to_string(poly):
    """Serialize polygon coordinates into one tab-separated string.

    Args:
        poly: array-like of coordinates (ndarray of any rank, or a plain
            list/tuple); it is flattened before joining.

    Returns:
        str: the coordinates joined by tab characters.
    """
    # np.ravel accepts plain lists as well as ndarrays of any rank; the
    # previous `len(poly.shape) > 1` check raised AttributeError on lists.
    return "\t".join(str(v) for v in np.ravel(poly))
25
+
26
+
27
def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
    """Convert a PPOCR-format label file into per-image end-to-end files.

    Each input line is "<img_path>\t<json annotation>" (falling back to a
    space separator when no single tab split is possible).  One output file
    per image is written to ``save_dir``, named "<basename>.txt", with one
    text instance per line (fields tab-separated):

        gt mode:   x1 y1 ... x4 y4 <ignore_tag> <transcription>
        pred mode: x1 y1 ... x4 y4 <transcription>

    In gt mode a transcription of "###" is tagged as ignored (tag 1).
    Instances carrying a 'score' below 0.5 are dropped.

    Args:
        label_dir: path to the PPOCR label file.
        mode: "gt" for ground-truth output (with ignore tag); any other
            value produces prediction-style output.
        save_dir: directory the per-image files are written to (created if
            missing).

    Raises:
        ValueError: if ``label_dir`` does not exist.
    """
    if not os.path.exists(label_dir):
        raise ValueError(f"The file {label_dir} does not exist!")

    assert label_dir != save_dir, \
        "label_dir and save_dir must be different paths"

    # Close the label file deterministically instead of leaking the handle.
    with open(label_dir, 'r') as label_file:
        data = label_file.readlines()

    gt_dict = {}

    for line in data:
        tmp = line.split('\t')
        if len(tmp) != 2:
            # Fall back to space-separated "<img_path> <annotation>" lines
            # (replaces the previous bare try/except-assert control flow).
            tmp = line.strip().split(' ')

        gt_lists = []

        # Skip lines with an empty image path (e.g. trailing blank lines).
        if not tmp[0]:
            continue
        img_path = tmp[0]
        anno = json.loads(tmp[1])
        for dic in anno:
            txt = dic['transcription']
            # Drop low-confidence predictions.
            if 'score' in dic and float(dic['score']) < 0.5:
                continue
            # Normalize the ideographic space (U+3000) to an ASCII space.
            if u'\u3000' in txt:
                txt = txt.replace(u'\u3000', u' ')
            poly = np.array(dic['points']).flatten()
            # "###" marks an unreadable/ignored region in the ground truth.
            txt_tag = 1 if txt == "###" else 0
            if mode == "gt":
                gt_label = poly_to_string(poly) + "\t" + str(
                    txt_tag) + "\t" + txt + "\n"
            else:
                gt_label = poly_to_string(poly) + "\t" + txt + "\n"

            gt_lists.append(gt_label)

        gt_dict[img_path] = gt_lists

    os.makedirs(save_dir, exist_ok=True)

    for img_name in gt_dict.keys():
        save_name = img_name.split("/")[-1]
        save_file = os.path.join(save_dir, save_name + ".txt")
        with open(save_file, "w") as f:
            f.writelines(gt_dict[img_name])

    print("The convert label saved in {}".format(save_dir))
86
+
87
+
88
def parse_args():
    """Parse command-line arguments for the label-conversion script.

    Returns:
        argparse.Namespace with ``label_path``, ``save_folder`` and ``mode``.
    """
    import argparse
    parser = argparse.ArgumentParser(description="args")
    parser.add_argument("--label_path", type=str, required=True)
    parser.add_argument("--save_folder", type=str, required=True)
    # The option is typed str, so the default must be a string.  "pred"
    # preserves the previous effective behavior (the old `default=False`
    # compared unequal to "gt" and so selected the prediction branch).
    parser.add_argument("--mode", type=str, default="pred")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    convert_label(args.label_path, args.mode, args.save_folder)
tools/end2end/draw_html.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import argparse
17
+
18
+
19
def str2bool(v):
    """Interpret a string as a boolean flag ("true"/"t"/"1", any case)."""
    normalized = v.lower()
    return normalized in {"true", "t", "1"}
21
+
22
+
23
def init_args():
    """Build the argument parser for the debug-HTML generator."""
    parser = argparse.ArgumentParser()
    # (flag, type, default) triples for every supported option.
    for flag, kind, default in (("--image_dir", str, ""),
                                ("--save_html_path", str, "./default.html"),
                                ("--width", int, 640)):
        parser.add_argument(flag, type=kind, default=default)
    return parser
29
+
30
+
31
def parse_args():
    """Parse command-line arguments using the shared parser."""
    return init_args().parse_args()
34
+
35
+
36
def draw_debug_img(args):
    """Write an HTML page showing every image in ``args.image_dir``.

    Every file in the directory that does not end in "txt" becomes one
    table row containing an <img> tag of width ``args.width``.  The page
    is written to ``args.save_html_path``.

    Args:
        args: namespace with ``image_dir``, ``save_html_path`` and
            ``width`` attributes (see ``init_args``).
    """
    html_path = args.save_html_path

    with open(html_path, 'w') as html:
        html.write('<html>\n<body>\n')
        html.write('<table border="1">\n')
        html.write(
            "<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
        )
        path = args.image_dir
        for filename in sorted(os.listdir(path)):
            if filename.endswith("txt"): continue
            # The image path
            base = "{}/{}".format(path, filename)
            html.write("<tr>\n")
            # NOTE(review): "(unknown)" looks like a placeholder where the
            # original wrote per-image metadata — confirm against upstream.
            html.write(f'<td> (unknown)\n GT')
            html.write(f'<td>GT\n<img src="{base}" width={args.width}></td>')

            html.write("</tr>\n")
        html.write('<style>\n')
        html.write('span {\n')
        html.write(' color: red;\n')
        html.write('}\n')
        html.write('</style>\n')
        html.write('</table>\n')
        # Close tags in proper nesting order; the original emitted
        # '</html>\n</body>\n', which is invalid HTML.
        html.write('</body>\n</html>\n')
    # Unused locals (err_cnt, image_list, enumerate index) were removed.
    print(f"The html file saved in {html_path}")
    return
67
+
68
+
69
+ if __name__ == "__main__":
70
+
71
+ args = parse_args()
72
+
73
+ draw_debug_img(args)
tools/end2end/eval_end2end.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ import sys
18
+ import shapely
19
+ from shapely.geometry import Polygon
20
+ import numpy as np
21
+ from collections import defaultdict
22
+ import operator
23
+ import editdistance
24
+
25
+
26
def strQ2B(ustring):
    """Convert full-width (CJK) characters in *ustring* to half-width.

    The ideographic space (U+3000) maps to an ASCII space; the full-width
    forms U+FF01..U+FF5E map to their ASCII counterparts by subtracting
    0xFEE0.  All other characters pass through unchanged.
    """
    out = []
    for ch in ustring:
        code = ord(ch)
        if code == 0x3000:          # ideographic space -> ASCII space
            code = 0x20
        elif 0xFF01 <= code <= 0xFF5E:  # full-width ASCII block
            code -= 0xFEE0
        out.append(chr(code))
    return "".join(out)
36
+
37
+
38
def polygon_from_str(polygon_points):
    """
    Create a shapely polygon object from gt or dt line.
    """
    # Eight flat coordinates -> four (x, y) vertices; convex_hull repairs
    # any self-intersecting vertex ordering.
    vertices = np.array(polygon_points).reshape(4, 2)
    return Polygon(vertices).convex_hull
45
+
46
+
47
def polygon_iou(poly1, poly2):
    """
    Intersection over union between two shapely polygons.

    Returns 0 when the polygons do not intersect or when shapely reports a
    topological error while computing the intersection.
    """
    if not poly1.intersects(
            poly2):  # this test is fast and can accelerate calculation
        iou = 0
    else:
        try:
            inter_area = poly1.intersection(poly2).area
            union_area = poly1.area + poly2.area - inter_area
            iou = float(inter_area) / union_area
        except shapely.geos.TopologicalError:
            # except Exception as e:
            #     print(e)
            # NOTE(review): shapely >= 2.0 removed the shapely.geos module,
            # so this except clause would itself raise AttributeError there
            # — confirm the pinned shapely version (<2.0).
            print('shapely.geos.TopologicalError occurred, iou set to 0')
            iou = 0
    return iou
65
+
66
+
67
def ed(str1, str2):
    """Return the Levenshtein edit distance between *str1* and *str2*."""
    return editdistance.eval(str1, str2)
69
+
70
+
71
def e2e_eval(gt_dir, res_dir, ignore_blank=False):
    """Compute end-to-end detection+recognition metrics and print them.

    Expects one gt file and one (optional) result file per image, both
    tab-separated: 8 polygon coordinates, then (gt only) an ignore flag,
    then an optional transcription.  Detections are matched to ground
    truths greedily by descending IoU (threshold 0.5); precision, recall,
    f-measure, character accuracy and average edit distances are printed.

    Args:
        gt_dir: directory of per-image ground-truth files.
        res_dir: directory of per-image prediction files (missing files
            count as zero detections for that image).
        ignore_blank: if True, strip spaces before comparing strings.
    """
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for i, val_name in enumerate(val_names):
        with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
            gt_lines = [o.strip() for o in f.readlines()]
        gts = []
        ignore_masks = []
        for line in gt_lines:
            parts = line.strip().split('\t')
            # ignore illegal data
            if len(parts) < 9:
                continue
            # NOTE(review): validation by assert disappears under `python -O`.
            assert (len(parts) < 11)
            # Normalize every gt entry to 8 coords + transcription (may be '').
            if len(parts) == 9:
                gts.append(parts[:8] + [''])
            else:
                gts.append(parts[:8] + [parts[-1]])

            ignore_masks.append(parts[8])

        val_path = os.path.join(res_dir, val_name)
        # A missing prediction file means no detections for this image.
        if not os.path.exists(val_path):
            dt_lines = []
        else:
            with open(val_path, encoding='utf-8') as f:
                dt_lines = [o.strip() for o in f.readlines()]
        dts = []
        for line in dt_lines:
            # print(line)
            parts = line.strip().split("\t")
            assert (len(parts) < 10), "line error: {}".format(line)
            # Normalize detections to 8 coords + transcription (may be '').
            if len(parts) == 8:
                dts.append(parts + [''])
            else:
                dts.append(parts)

        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)
        # Collect every (gt, dt) pair whose IoU clears the threshold.
        all_ious = defaultdict(tuple)
        for index_gt, gt in enumerate(gts):
            gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
            gt_poly = polygon_from_str(gt_coors)
            for index_dt, dt in enumerate(dts):
                dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
                dt_poly = polygon_from_str(dt_coors)
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        # Greedy one-to-one matching: highest IoU pairs claim first.
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)
        sorted_gt_dt_pairs = [item[0] for item in sorted_ious]

        # matched gt and dt
        for gt_dt_pair in sorted_gt_dt_pairs:
            index_gt, index_dt = gt_dt_pair
            if gt_match[index_gt] == False and dt_match[index_dt] == False:
                gt_match[index_gt] = True
                dt_match[index_dt] = True
                if ignore_blank:
                    gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
                    dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
                else:
                    gt_str = strQ2B(gts[index_gt][8])
                    dt_str = strQ2B(dts[index_dt][8])
                # Only non-ignored gts ('0') contribute to the metrics; a dt
                # matched to an ignored gt is neither rewarded nor penalized.
                if ignore_masks[index_gt] == '0':
                    ed_sum += ed(gt_str, dt_str)
                    num_gt_chars += len(gt_str)
                    if gt_str == dt_str:
                        hit += 1
                    gt_count += 1
                    dt_count += 1

        # unmatched dt
        for tindex, dt_match_flag in enumerate(dt_match):
            if dt_match_flag == False:
                # Unmatched detection: full edit distance against empty gt.
                dt_str = dts[tindex][8]
                gt_str = ''
                ed_sum += ed(dt_str, gt_str)
                dt_count += 1

        # unmatched gt
        for tindex, gt_match_flag in enumerate(gt_match):
            if gt_match_flag == False and ignore_masks[tindex] == '0':
                # Missed non-ignored gt: full edit distance against empty dt.
                dt_str = ''
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, dt_str)
                num_gt_chars += len(gt_str)
                gt_count += 1

    # eps guards the ratio denominators against zero counts.
    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    # NOTE(review): this divides by len(val_names) without eps and raises
    # ZeroDivisionError when gt_dir is empty — confirm intended.
    avg_edit_dist_img = ed_sum / len(val_names)
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")
183
+
184
+
185
if __name__ == '__main__':
    # Restore the usage check that was left commented out: exit with a
    # message instead of raising IndexError on missing arguments.
    if len(sys.argv) != 3:
        print("python3 ocr_e2e_eval.py gt_dir res_dir")
        sys.exit(-1)
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)
tools/end2end/readme.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # 简介
3
+
4
+ `tools/end2end`目录下存放了文本检测+文本识别pipeline串联预测的指标评测代码以及可视化工具。本节介绍文本检测+文本识别的端对端指标评估方式。
5
+
6
+
7
+ ## 端对端评测步骤
8
+
9
+ **步骤一:**
10
+
11
+ 运行`tools/infer/predict_system.py`,得到保存的结果:
12
+
13
+ ```
14
+ python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True
15
+ ```
16
+
17
+ 文本检测识别可视化图默认保存在`./ch_PP-OCRv2_results/`目录下,预测结果默认保存在`./ch_PP-OCRv2_results/system_results.txt`中,格式如下:
18
+ ```
19
+ all-sum-510/00224225.jpg [{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}]
20
+ ```
21
+
22
+
23
+ **步骤二:**
24
+
25
+ 将步骤一保存的数据转换为端对端评测需要的数据格式:
26
+
27
+ 修改 `tools/end2end/convert_ppocr_label.py`中的代码,convert_label函数中设置输入标签路径,Mode,保存标签路径等,对预测数据的GTlabel和预测结果的label格式进行转换。
28
+
29
+ ```
30
+ python3 tools/end2end/convert_ppocr_label.py --mode=gt --label_path=path/to/label_txt --save_folder=save_gt_label
31
+
32
+ python3 tools/end2end/convert_ppocr_label.py --mode=pred --label_path=path/to/pred_txt --save_folder=save_PPOCRV2_infer
33
+ ```
34
+
35
+ 得到如下结果:
36
+ ```
37
+ ├── ./save_gt_label/
38
+ ├── ./save_PPOCRV2_infer/
39
+ ```
40
+
41
+ **步骤三:**
42
+
43
+ 执行端对端评测,运行`tools/end2end/eval_end2end.py`计算端对端指标,运行方式如下:
44
+
45
+ ```
46
+ python3 tools/end2end/eval_end2end.py "gt_label_dir" "predict_label_dir"
47
+ ```
48
+
49
+ 比如:
50
+
51
+ ```
52
+ python3 tools/end2end/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/
53
+ ```
54
+ 将得到如下结果,fmeasure为主要关注的指标:
55
+ ```
56
+ hit, dt_count, gt_count 1557 2693 3283
57
+ character_acc: 61.77%
58
+ avg_edit_dist_field: 3.08
59
+ avg_edit_dist_img: 51.82
60
+ precision: 57.82%
61
+ recall: 47.43%
62
+ fmeasure: 52.11%
63
+ ```
tools/eval.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import os
20
+ import sys
21
+
22
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
23
+ sys.path.insert(0, __dir__)
24
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
25
+
26
+ import paddle
27
+ from ppocr.data import build_dataloader
28
+ from ppocr.modeling.architectures import build_model
29
+ from ppocr.postprocess import build_post_process
30
+ from ppocr.metrics import build_metric
31
+ from ppocr.utils.save_load import load_model
32
+ import tools.program as program
33
+
34
+
35
def main():
    """Evaluate the model described by the global ``config``.

    Relies on module-level globals (``config``, ``device``, ``logger``)
    assigned by ``program.preprocess()`` in the __main__ block.
    """
    global_config = config['Global']
    # build dataloader
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm: size the head's output to the charset length.
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    out_channels_list = {}
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        # SAR decode reserves 2 special tokens.
                        char_num = char_num - 2
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
            out_channels_list = {}
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    # Algorithms whose forward pass needs extra inputs besides the image.
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
    ]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
        for key in config['Architecture']["Models"]:
            extra_input = extra_input or config['Architecture']['Models'][key][
                'algorithm'] in extra_input_models
    else:
        extra_input = config['Architecture']['algorithm'] in extra_input_models
    # model_type defaults to None when the config omits the key.
    if "model_type" in config['Architecture'].keys():
        if config['Architecture']['algorithm'] == 'CAN':
            model_type = 'can'
        else:
            model_type = config['Architecture']['model_type']
    else:
        model_type = None

    # build metric
    eval_class = build_metric(config['Metric'])
    # amp (automatic mixed precision) setup, mirroring train-time options.
    use_amp = config["Global"].get("use_amp", False)
    amp_level = config["Global"].get("amp_level", 'O2')
    amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
        if amp_level == "O2":
            model = paddle.amp.decorate(
                models=model, level=amp_level, master_weight=True)
    else:
        scaler = None

    # NOTE(review): this re-reads config['Architecture']["model_type"]
    # directly and raises KeyError when the key is absent, even though the
    # guarded ``model_type`` local above handles that case — confirm
    # whether ``model_type=model_type`` was intended.
    best_model_dict = load_model(
        config, model, model_type=config['Architecture']["model_type"])
    if len(best_model_dict):
        logger.info('metric in ckpt ***************')
        for k, v in best_model_dict.items():
            logger.info('{}:{}'.format(k, v))

    # start eval
    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class, model_type, extra_input, scaler,
                          amp_level, amp_custom_black_list)
    logger.info('metric eval ***************')
    for k, v in metric.items():
        logger.info('{}:{}'.format(k, v))
133
+
134
+
135
if __name__ == '__main__':
    # preprocess() parses the CLI options and provides the module-level
    # globals (config, device, logger, vdl_writer) that main() reads.
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/export_center.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import os
20
+ import sys
21
+ import pickle
22
+
23
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append(__dir__)
25
+ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
26
+
27
+ from ppocr.data import build_dataloader
28
+ from ppocr.modeling.architectures import build_model
29
+ from ppocr.postprocess import build_post_process
30
+ from ppocr.utils.save_load import load_model
31
+ from ppocr.utils.utility import print_dict
32
+ import tools.program as program
33
+
34
+
35
def main():
    """Export per-character feature centers computed from training data.

    Routes the training dataset through the Eval dataloader, enables
    feature output on the recognition head, and pickles the resulting
    centers to ``train_center.pkl``.
    """
    global_config = config['Global']

    # Point the Eval dataloader at the *training* data so the centers are
    # computed from training samples.
    train_dataset = config['Train']['dataset']
    eval_dataset = config['Eval']['dataset']
    for field in ('name', 'data_dir', 'label_file_list'):
        eval_dataset[field] = train_dataset[field]
    eval_dataloader = build_dataloader(config, 'Eval', device, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # For rec models the head width must match the character-set size.
    if hasattr(post_process_class, 'character'):
        config['Architecture']["Head"]['out_channels'] = len(
            getattr(post_process_class, 'character'))

    # Ask the head to return features in addition to logits.
    config['Architecture']["Head"]["return_feats"] = True

    model = build_model(config['Architecture'])

    best_model_dict = load_model(config, model)
    if len(best_model_dict):
        logger.info('metric in ckpt ***************')
        for k, v in best_model_dict.items():
            logger.info('{}:{}'.format(k, v))

    # get features from train data
    char_center = program.get_center(model, eval_dataloader, post_process_class)

    # serialize to disk
    with open("train_center.pkl", 'wb') as f:
        pickle.dump(char_center, f)
    return
73
+
74
+
75
if __name__ == '__main__':
    # preprocess() parses the CLI options and provides the module-level
    # globals (config, device, logger, vdl_writer) that main() reads.
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/export_model.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+
18
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
19
+ sys.path.append(__dir__)
20
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
21
+
22
+ import argparse
23
+
24
+ import paddle
25
+ from paddle.jit import to_static
26
+
27
+ from ppocr.modeling.architectures import build_model
28
+ from ppocr.postprocess import build_post_process
29
+ from ppocr.utils.save_load import load_model
30
+ from ppocr.utils.logging import get_logger
31
+ from tools.program import load_config, merge_config, ArgsParser
32
+
33
+
34
+ def export_single_model(model,
35
+ arch_config,
36
+ save_path,
37
+ logger,
38
+ input_shape=None,
39
+ quanter=None):
40
+ if arch_config["algorithm"] == "SRN":
41
+ max_text_length = arch_config["Head"]["max_text_length"]
42
+ other_shape = [
43
+ paddle.static.InputSpec(
44
+ shape=[None, 1, 64, 256], dtype="float32"), [
45
+ paddle.static.InputSpec(
46
+ shape=[None, 256, 1],
47
+ dtype="int64"), paddle.static.InputSpec(
48
+ shape=[None, max_text_length, 1], dtype="int64"),
49
+ paddle.static.InputSpec(
50
+ shape=[None, 8, max_text_length, max_text_length],
51
+ dtype="int64"), paddle.static.InputSpec(
52
+ shape=[None, 8, max_text_length, max_text_length],
53
+ dtype="int64")
54
+ ]
55
+ ]
56
+ model = to_static(model, input_spec=other_shape)
57
+ elif arch_config["algorithm"] == "SAR":
58
+ other_shape = [
59
+ paddle.static.InputSpec(
60
+ shape=[None, 3, 48, 160], dtype="float32"),
61
+ [paddle.static.InputSpec(
62
+ shape=[None], dtype="float32")]
63
+ ]
64
+ model = to_static(model, input_spec=other_shape)
65
+ elif arch_config["algorithm"] == "SVTR":
66
+ if arch_config["Head"]["name"] == 'MultiHead':
67
+ other_shape = [
68
+ paddle.static.InputSpec(
69
+ shape=[None, 3, 48, -1], dtype="float32"),
70
+ ]
71
+ else:
72
+ other_shape = [
73
+ paddle.static.InputSpec(
74
+ shape=[None] + input_shape, dtype="float32"),
75
+ ]
76
+ model = to_static(model, input_spec=other_shape)
77
+ elif arch_config["algorithm"] == "PREN":
78
+ other_shape = [
79
+ paddle.static.InputSpec(
80
+ shape=[None, 3, 64, 256], dtype="float32"),
81
+ ]
82
+ model = to_static(model, input_spec=other_shape)
83
+ elif arch_config["model_type"] == "sr":
84
+ other_shape = [
85
+ paddle.static.InputSpec(
86
+ shape=[None, 3, 16, 64], dtype="float32")
87
+ ]
88
+ model = to_static(model, input_spec=other_shape)
89
+ elif arch_config["algorithm"] == "ViTSTR":
90
+ other_shape = [
91
+ paddle.static.InputSpec(
92
+ shape=[None, 1, 224, 224], dtype="float32"),
93
+ ]
94
+ model = to_static(model, input_spec=other_shape)
95
+ elif arch_config["algorithm"] == "ABINet":
96
+ other_shape = [
97
+ paddle.static.InputSpec(
98
+ shape=[None, 3, 32, 128], dtype="float32"),
99
+ ]
100
+ # print([None, 3, 32, 128])
101
+ model = to_static(model, input_spec=other_shape)
102
+ elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
103
+ other_shape = [
104
+ paddle.static.InputSpec(
105
+ shape=[None, 1, 32, 100], dtype="float32"),
106
+ ]
107
+ model = to_static(model, input_spec=other_shape)
108
+ elif arch_config["algorithm"] == "VisionLAN":
109
+ other_shape = [
110
+ paddle.static.InputSpec(
111
+ shape=[None, 3, 64, 256], dtype="float32"),
112
+ ]
113
+ model = to_static(model, input_spec=other_shape)
114
+ elif arch_config["algorithm"] == "RobustScanner":
115
+ max_text_length = arch_config["Head"]["max_text_length"]
116
+ other_shape = [
117
+ paddle.static.InputSpec(
118
+ shape=[None, 3, 48, 160], dtype="float32"), [
119
+ paddle.static.InputSpec(
120
+ shape=[None, ], dtype="float32"),
121
+ paddle.static.InputSpec(
122
+ shape=[None, max_text_length], dtype="int64")
123
+ ]
124
+ ]
125
+ model = to_static(model, input_spec=other_shape)
126
+ elif arch_config["algorithm"] == "CAN":
127
+ other_shape = [[
128
+ paddle.static.InputSpec(
129
+ shape=[None, 1, None, None],
130
+ dtype="float32"), paddle.static.InputSpec(
131
+ shape=[None, 1, None, None], dtype="float32"),
132
+ paddle.static.InputSpec(
133
+ shape=[None, arch_config['Head']['max_text_length']],
134
+ dtype="int64")
135
+ ]]
136
+ model = to_static(model, input_spec=other_shape)
137
+ elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
138
+ input_spec = [
139
+ paddle.static.InputSpec(
140
+ shape=[None, 512], dtype="int64"), # input_ids
141
+ paddle.static.InputSpec(
142
+ shape=[None, 512, 4], dtype="int64"), # bbox
143
+ paddle.static.InputSpec(
144
+ shape=[None, 512], dtype="int64"), # attention_mask
145
+ paddle.static.InputSpec(
146
+ shape=[None, 512], dtype="int64"), # token_type_ids
147
+ paddle.static.InputSpec(
148
+ shape=[None, 3, 224, 224], dtype="int64"), # image
149
+ ]
150
+ if 'Re' in arch_config['Backbone']['name']:
151
+ input_spec.extend([
152
+ paddle.static.InputSpec(
153
+ shape=[None, 512, 3], dtype="int64"), # entities
154
+ paddle.static.InputSpec(
155
+ shape=[None, None, 2], dtype="int64"), # relations
156
+ ])
157
+ if model.backbone.use_visual_backbone is False:
158
+ input_spec.pop(4)
159
+ model = to_static(model, input_spec=[input_spec])
160
+ else:
161
+ infer_shape = [3, -1, -1]
162
+ if arch_config["model_type"] == "rec":
163
+ infer_shape = [3, 32, -1] # for rec model, H must be 32
164
+ if "Transform" in arch_config and arch_config[
165
+ "Transform"] is not None and arch_config["Transform"][
166
+ "name"] == "TPS":
167
+ logger.info(
168
+ "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
169
+ )
170
+ infer_shape[-1] = 100
171
+ elif arch_config["model_type"] == "table":
172
+ infer_shape = [3, 488, 488]
173
+ if arch_config["algorithm"] == "TableMaster":
174
+ infer_shape = [3, 480, 480]
175
+ if arch_config["algorithm"] == "SLANet":
176
+ infer_shape = [3, -1, -1]
177
+ model = to_static(
178
+ model,
179
+ input_spec=[
180
+ paddle.static.InputSpec(
181
+ shape=[None] + infer_shape, dtype="float32")
182
+ ])
183
+
184
+ if quanter is None:
185
+ paddle.jit.save(model, save_path)
186
+ else:
187
+ quanter.save_quantized_model(model, save_path)
188
+ logger.info("inference model is saved to {}".format(save_path))
189
+ return
190
+
191
+
192
def main():
    """Export a trained PaddleOCR model to an inference model.

    Loads the YAML config given on the command line, patches the head
    output channels from the post-process character dictionary (including
    distillation sub-models and MultiHead variants), builds and loads the
    trained weights, then calls ``export_single_model`` for each model.
    """
    FLAGS = ArgsParser().parse_args()
    config = merge_config(load_config(FLAGS.config), FLAGS.opt)
    logger = get_logger()

    # Build the post process first: recognition heads size their output
    # channels from its character dictionary.
    post_process_class = build_post_process(config["PostProcess"],
                                            config["Global"])

    arch = config["Architecture"]
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if arch["algorithm"] in ["Distillation", ]:  # distillation model
            for key in arch["Models"]:
                sub_head = arch["Models"][key]["Head"]
                if sub_head["name"] == 'MultiHead':  # multi head
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    sub_head['out_channels_list'] = {
                        'CTCLabelDecode': char_num,
                        'SARLabelDecode': char_num + 2,
                    }
                else:
                    sub_head["out_channels"] = char_num
                # just one final tensor needs to be exported for inference
                arch["Models"][key]["return_all_feats"] = False
        elif arch['Head']['name'] == 'MultiHead':  # multi head
            char_num = len(getattr(post_process_class, 'character'))
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            arch['Head']['out_channels_list'] = {
                'CTCLabelDecode': char_num,
                'SARLabelDecode': char_num + 2,
            }
        else:  # base rec model
            arch["Head"]["out_channels"] = char_num

    # for sr algorithm: switch the transform into inference mode
    if arch["model_type"] == "sr":
        arch["Transform"]['infer_mode'] = True

    model = build_model(arch)
    load_model(config, model, model_type=arch["model_type"])
    model.eval()

    save_path = config["Global"]["save_inference_dir"]

    # SVTR (non-MultiHead) needs the eval resize shape as a fixed input spec
    if arch["algorithm"] == "SVTR" and arch["Head"]["name"] != 'MultiHead':
        input_shape = config["Eval"]["dataset"]["transforms"][-2][
            'SVTRRecResizeImg']['image_shape']
    else:
        input_shape = None

    if arch["algorithm"] in ["Distillation", ]:  # export every sub model
        sub_arch_configs = list(arch["Models"].values())
        for idx, sub_name in enumerate(model.model_name_list):
            sub_save_path = os.path.join(save_path, sub_name, "inference")
            export_single_model(model.model_list[idx], sub_arch_configs[idx],
                                sub_save_path, logger)
    else:
        export_single_model(
            model,
            arch,
            os.path.join(save_path, "inference"),
            logger,
            input_shape=input_shape)


if __name__ == "__main__":
    main()
tools/infer/__pycache__/predict_cls.cpython-37.pyc ADDED
Binary file (4.07 kB). View file
 
tools/infer/__pycache__/predict_cls.cpython-38.pyc ADDED
Binary file (4.11 kB). View file
 
tools/infer/__pycache__/predict_det.cpython-37.pyc ADDED
Binary file (8.48 kB). View file
 
tools/infer/__pycache__/predict_det.cpython-38.pyc ADDED
Binary file (8.61 kB). View file
 
tools/infer/__pycache__/predict_rec.cpython-37.pyc ADDED
Binary file (13.8 kB). View file
 
tools/infer/__pycache__/predict_rec.cpython-38.pyc ADDED
Binary file (13.8 kB). View file
 
tools/infer/__pycache__/predict_system.cpython-37.pyc ADDED
Binary file (7.04 kB). View file
 
tools/infer/__pycache__/predict_system.cpython-38.pyc ADDED
Binary file (7.13 kB). View file
 
tools/infer/__pycache__/utility.cpython-37.pyc ADDED
Binary file (17.4 kB). View file
 
tools/infer/__pycache__/utility.cpython-38.pyc ADDED
Binary file (17.4 kB). View file
 
tools/infer/predict_cls.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+
17
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
18
+ sys.path.append(__dir__)
19
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
20
+
21
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
22
+
23
+ import cv2
24
+ import copy
25
+ import numpy as np
26
+ import math
27
+ import time
28
+ import traceback
29
+
30
+ import tools.infer.utility as utility
31
+ from ppocr.postprocess import build_post_process
32
+ from ppocr.utils.logging import get_logger
33
+ from ppocr.utils.utility import get_image_file_list, check_and_read
34
+
35
+ logger = get_logger()
36
+
37
+
38
class TextClassifier(object):
    """Text-direction classifier.

    Predicts whether each cropped text image is upright or rotated by 180
    degrees and rotates the 180-degree crops back before recognition.
    """

    def __init__(self, args):
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
        self.postprocess_op = build_post_process({
            'name': 'ClsPostProcess',
            "label_list": args.label_list,
        })
        self.predictor, self.input_tensor, self.output_tensors, _ = \
            utility.create_predictor(args, 'cls', logger)
        self.use_onnx = args.use_onnx

    def resize_norm_img(self, img):
        """Resize to the classifier input shape (keeping aspect ratio),
        normalize to [-1, 1] and right-pad with zeros to the full width."""
        imgC, imgH, imgW = self.cls_image_shape
        h, w = img.shape[0], img.shape[1]
        ratio = w / float(h)
        # width after scaling height to imgH, capped at the model width
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized = cv2.resize(img, (resized_w, imgH)).astype('float32')
        if imgC == 1:
            resized = resized / 255
            resized = resized[np.newaxis, :]
        else:
            resized = resized.transpose((2, 0, 1)) / 255
        resized -= 0.5
        resized /= 0.5
        padded = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padded[:, :, 0:resized_w] = resized
        return padded

    def __call__(self, img_list):
        """Classify a list of crops; returns (possibly rotated images,
        [label, score] per image in input order, total inference seconds)."""
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Sort by aspect ratio so each batch holds similarly shaped crops,
        # which speeds up the classification pass.
        ratios = [im.shape[1] / float(im.shape[0]) for im in img_list]
        indices = np.argsort(np.array(ratios))

        cls_res = [['', 0.0]] * img_num
        batch_num = self.cls_batch_num
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            max_wh_ratio = 0  # computed for parity; not consumed below
            starttime = time.time()
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)
            batch = [
                self.resize_norm_img(img_list[indices[ino]])[np.newaxis, :]
                for ino in range(beg_img_no, end_img_no)
            ]
            norm_img_batch = np.concatenate(batch).copy()

            if self.use_onnx:
                input_dict = {self.input_tensor.name: norm_img_batch}
                outputs = self.predictor.run(self.output_tensors, input_dict)
                prob_out = outputs[0]
            else:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.run()
                prob_out = self.output_tensors[0].copy_to_cpu()
                self.predictor.try_shrink_memory()
            cls_result = self.postprocess_op(prob_out)
            elapse += time.time() - starttime
            for rno, (label, score) in enumerate(cls_result):
                orig_idx = indices[beg_img_no + rno]
                cls_res[orig_idx] = [label, score]
                # rotate confidently-detected upside-down crops back upright
                # (rotate code 1 — presumably cv2.ROTATE_180; confirm)
                if '180' in label and score > self.cls_thresh:
                    img_list[orig_idx] = cv2.rotate(img_list[orig_idx], 1)
        return img_list, cls_res, elapse
123
+
124
+
125
def main(args):
    """Run the direction classifier over every image found under
    ``args.image_dir`` and log one [label, score] result per image."""
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag, _ = check_and_read(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    for image_file, res in zip(valid_image_file_list, cls_res):
        logger.info("Predicts of {}:{}".format(image_file, res))


if __name__ == "__main__":
    main(utility.parse_args())
tools/infer/predict_det.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+
17
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
18
+ sys.path.append(__dir__)
19
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
20
+
21
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
22
+
23
+ import cv2
24
+ import numpy as np
25
+ import time
26
+ import sys
27
+
28
+ import tools.infer.utility as utility
29
+ from ppocr.utils.logging import get_logger
30
+ from ppocr.utils.utility import get_image_file_list, check_and_read
31
+ from ppocr.data import create_operators, transform
32
+ from ppocr.postprocess import build_post_process
33
+ import json
34
+ logger = get_logger()
35
+
36
+
37
class TextDetector(object):
    """Text detection predictor.

    Wraps preprocessing, inference and post-processing for the supported
    detection algorithms (DB, DB++, EAST, SAST, PSE, FCE, CT).
    """

    def __init__(self, args):
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.use_onnx = args.use_onnx
        # Default pipeline; individual algorithms below override the resize
        # and/or normalization steps.
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]

        if self.det_algorithm in ("DB", "DB++"):
            postprocess_params = {
                'name': 'DBPostProcess',
                'thresh': args.det_db_thresh,
                'box_thresh': args.det_db_box_thresh,
                'max_candidates': 1000,
                'unclip_ratio': args.det_db_unclip_ratio,
                'use_dilation': args.use_dilation,
                'score_mode': args.det_db_score_mode,
                'box_type': args.det_box_type,
            }
            if self.det_algorithm == "DB++":
                # DB++ was trained with its own normalization statistics
                pre_process_list[1] = {
                    'NormalizeImage': {
                        'std': [1.0, 1.0, 1.0],
                        'mean':
                        [0.48109378172549, 0.45752457890196, 0.40787054090196],
                        'scale': '1./255.',
                        'order': 'hwc'
                    }
                }
        elif self.det_algorithm == "EAST":
            postprocess_params = {
                'name': 'EASTPostProcess',
                'score_thresh': args.det_east_score_thresh,
                'cover_thresh': args.det_east_cover_thresh,
                'nms_thresh': args.det_east_nms_thresh,
            }
        elif self.det_algorithm == "SAST":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params = {
                'name': 'SASTPostProcess',
                'score_thresh': args.det_sast_score_thresh,
                'nms_thresh': args.det_sast_nms_thresh,
            }
            if args.det_box_type == 'poly':
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        elif self.det_algorithm == "PSE":
            postprocess_params = {
                'name': 'PSEPostProcess',
                'thresh': args.det_pse_thresh,
                'box_thresh': args.det_pse_box_thresh,
                'min_area': args.det_pse_min_area,
                'box_type': args.det_box_type,
                'scale': args.det_pse_scale,
            }
        elif self.det_algorithm == "FCE":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'rescale_img': [1080, 736]
                }
            }
            postprocess_params = {
                'name': 'FCEPostProcess',
                'scales': args.scales,
                'alpha': args.alpha,
                'beta': args.beta,
                'fourier_degree': args.fourier_degree,
                'box_type': args.det_box_type,
            }
        elif self.det_algorithm == "CT":
            pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
            postprocess_params = {'name': 'CTPostProcess'}
        else:
            logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(0)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'det', logger)

        if self.use_onnx:
            # A fixed-shape ONNX graph forces a fixed-shape resize.
            img_h, img_w = self.input_tensor.shape[2:]
            if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
                pre_process_list[0] = {
                    'DetResizeForTest': {
                        'image_shape': [img_h, img_w]
                    }
                }
                self.preprocess_op = create_operators(pre_process_list)

        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    def order_points_clockwise(self, pts):
        """Reorder 4 points to top-left, top-right, bottom-right, bottom-left.

        The top-left corner has the smallest x+y sum and the bottom-right the
        largest; the remaining two are split by the sign of y-x.
        """
        rect = np.zeros((4, 2), dtype="float32")
        sums = pts.sum(axis=1)
        rect[0] = pts[np.argmin(sums)]
        rect[2] = pts[np.argmax(sums)]
        remaining = np.delete(pts, (np.argmin(sums), np.argmax(sums)), axis=0)
        diffs = np.diff(np.array(remaining), axis=1)
        rect[1] = remaining[np.argmin(diffs)]
        rect[3] = remaining[np.argmax(diffs)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every vertex (in place) into the image bounds."""
        for idx in range(points.shape[0]):
            points[idx, 0] = int(min(max(points[idx, 0], 0), img_width - 1))
            points[idx, 1] = int(min(max(points[idx, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and keep only quads larger than 3x3 pixels."""
        img_height, img_width = image_shape[0:2]
        kept = []
        for box in dt_boxes:
            if type(box) is list:
                box = np.array(box)
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            kept.append(box)
        return np.array(kept)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip polygon vertices to the image without any size filtering."""
        img_height, img_width = image_shape[0:2]
        kept = []
        for box in dt_boxes:
            if type(box) is list:
                box = np.array(box)
            kept.append(self.clip_det_res(box, img_height, img_width))
        return np.array(kept)

    def __call__(self, img):
        """Detect text in one BGR image; returns (boxes, elapsed_seconds)."""
        ori_im = img.copy()
        data = {'image': img}

        st = time.time()

        if self.args.benchmark:
            self.autolog.times.start()

        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        if self.args.benchmark:
            self.autolog.times.stamp()
        if self.use_onnx:
            input_dict = {self.input_tensor.name: img}
            outputs = self.predictor.run(self.output_tensors, input_dict)
        else:
            self.input_tensor.copy_from_cpu(img)
            self.predictor.run()
            outputs = [t.copy_to_cpu() for t in self.output_tensors]
        if self.args.benchmark:
            self.autolog.times.stamp()

        # Map the raw network outputs onto the keys the post-processor expects.
        if self.det_algorithm == "EAST":
            preds = {'f_geo': outputs[0], 'f_score': outputs[1]}
        elif self.det_algorithm == 'SAST':
            preds = {
                'f_border': outputs[0],
                'f_score': outputs[1],
                'f_tco': outputs[2],
                'f_tvo': outputs[3],
            }
        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
            preds = {'maps': outputs[0]}
        elif self.det_algorithm == 'FCE':
            preds = {
                'level_{}'.format(i): out
                for i, out in enumerate(outputs)
            }
        elif self.det_algorithm == "CT":
            preds = {'maps': outputs[0], 'score': outputs[1]}
        else:
            raise NotImplementedError

        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']

        if self.args.det_box_type == 'poly':
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        if self.args.benchmark:
            self.autolog.times.end(stamp=True)
        et = time.time()
        return dt_boxes, et - st
283
+
284
+
285
if __name__ == "__main__":
    # CLI entry point: run text detection on every image (or PDF page) under
    # --image_dir, save visualizations and a det_results.txt of all boxes.
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextDetector(args)
    total_time = 0
    draw_img_save_dir = args.draw_img_save_dir
    os.makedirs(draw_img_save_dir, exist_ok=True)

    if args.warmup:
        # warm the predictor up with a dummy image so timings are stable
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(2):
            res = text_detector(img)

    save_results = []
    for idx, image_file in enumerate(image_file_list):
        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if not flag_pdf:
            if img is None:
                logger.debug("error in loading image:{}".format(image_file))
                continue
            imgs = [img]
        else:
            # PDFs yield a list of page images; 0 means "all pages"
            page_num = args.page_num
            if page_num > len(img) or page_num == 0:
                page_num = len(img)
            imgs = img[:page_num]
        for index, img in enumerate(imgs):
            st = time.time()
            dt_boxes, _ = text_detector(img)
            elapse = time.time() - st
            total_time += elapse
            # multi-page inputs get the page index appended to the key
            if len(imgs) > 1:
                save_pred = os.path.basename(image_file) + '_' + str(
                    index) + "\t" + str(
                        json.dumps([x.tolist() for x in dt_boxes])) + "\n"
            else:
                save_pred = os.path.basename(image_file) + "\t" + str(
                    json.dumps([x.tolist() for x in dt_boxes])) + "\n"
            save_results.append(save_pred)
            logger.info(save_pred)
            if len(imgs) > 1:
                logger.info("{}_{} The predict time of {}: {}".format(
                    idx, index, image_file, elapse))
            else:
                logger.info("{} The predict time of {}: {}".format(
                    idx, image_file, elapse))

            src_im = utility.draw_text_det_res(dt_boxes, img)

            if flag_gif:
                save_file = image_file[:-3] + "png"
            elif flag_pdf:
                save_file = image_file.replace('.pdf',
                                               '_' + str(index) + '.png')
            else:
                save_file = image_file
            img_path = os.path.join(
                draw_img_save_dir,
                "det_res_{}".format(os.path.basename(save_file)))
            cv2.imwrite(img_path, src_im)
            logger.info("The visualized image saved in {}".format(img_path))

    # Fix: the context manager closes the file; the original also called
    # f.close() inside the with-block, which was redundant.
    with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
        f.writelines(save_results)
    if args.benchmark:
        text_detector.autolog.report()
tools/infer/predict_e2e.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+
17
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
18
+ sys.path.append(__dir__)
19
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
20
+
21
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
22
+
23
+ import cv2
24
+ import numpy as np
25
+ import time
26
+ import sys
27
+
28
+ import tools.infer.utility as utility
29
+ from ppocr.utils.logging import get_logger
30
+ from ppocr.utils.utility import get_image_file_list, check_and_read
31
+ from ppocr.data import create_operators, transform
32
+ from ppocr.postprocess import build_post_process
33
+
34
+ logger = get_logger()
35
+
36
+
37
class TextE2E(object):
    """End-to-end text spotting predictor.

    Runs a single network (PGNet) that both localizes text polygons and
    recognizes their content.
    """

    def __init__(self, args):
        self.args = args
        self.e2e_algorithm = args.e2e_algorithm
        self.use_onnx = args.use_onnx
        pre_process_list = [{
            'E2EResizeForTest': {}
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        if self.e2e_algorithm == "PGNet":
            pre_process_list[0] = {
                'E2EResizeForTest': {
                    'max_side_len': args.e2e_limit_side_len,
                    'valid_set': 'totaltext'
                }
            }
            postprocess_params = {
                'name': 'PGPostProcess',
                'score_thresh': args.e2e_pgnet_score_thresh,
                'character_dict_path': args.e2e_char_dict_path,
                'valid_set': args.e2e_pgnet_valid_set,
                'mode': args.e2e_pgnet_mode,
            }
        else:
            logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
            sys.exit(0)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, _ = \
            utility.create_predictor(args, 'e2e', logger)

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every polygon vertex (in place) into the image bounds."""
        for idx in range(points.shape[0]):
            points[idx, 0] = int(min(max(points[idx, 0], 0), img_width - 1))
            points[idx, 1] = int(min(max(points[idx, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip all detected polygons to the source image; no filtering."""
        img_height, img_width = image_shape[0:2]
        clipped = [
            self.clip_det_res(box, img_height, img_width) for box in dt_boxes
        ]
        return np.array(clipped)

    def __call__(self, img):
        """Spot text in one image; returns (boxes, texts, elapsed_seconds)."""
        ori_im = img.copy()
        data = transform({'image': img}, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        starttime = time.time()

        if self.use_onnx:
            input_dict = {self.input_tensor.name: img}
            outputs = self.predictor.run(self.output_tensors, input_dict)
            preds = {
                'f_border': outputs[0],
                'f_char': outputs[1],
                'f_direction': outputs[2],
                'f_score': outputs[3],
            }
        else:
            self.input_tensor.copy_from_cpu(img)
            self.predictor.run()
            outputs = [t.copy_to_cpu() for t in self.output_tensors]
            if self.e2e_algorithm == 'PGNet':
                preds = {
                    'f_border': outputs[0],
                    'f_char': outputs[1],
                    'f_direction': outputs[2],
                    'f_score': outputs[3],
                }
            else:
                raise NotImplementedError
        post_result = self.postprocess_op(preds, shape_list)
        points, strs = post_result['points'], post_result['texts']
        dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
        elapse = time.time() - starttime
        return dt_boxes, strs, elapse
139
+
140
+
141
if __name__ == "__main__":
    # CLI entry point: run end-to-end text spotting on each image under
    # --image_dir, draw the results, and report an average prediction time.
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextE2E(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"
    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    for image_file in image_file_list:
        img, flag, _ = check_and_read(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        points, strs, elapse = text_detector(img)
        # first image is excluded from the average (predictor warm-up)
        if count > 0:
            total_time += elapse
        count += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_e2e_res(points, strs, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "e2e_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))
    if count > 1:
        logger.info("Avg Time: {}".format(total_time / (count - 1)))
tools/infer/predict_rec.py ADDED
@@ -0,0 +1,667 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ from PIL import Image
17
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
18
+ sys.path.append(__dir__)
19
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
20
+
21
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
22
+
23
+ import cv2
24
+ import numpy as np
25
+ import math
26
+ import time
27
+ import traceback
28
+ import paddle
29
+
30
+ import tools.infer.utility as utility
31
+ from ppocr.postprocess import build_post_process
32
+ from ppocr.utils.logging import get_logger
33
+ from ppocr.utils.utility import get_image_file_list, check_and_read
34
+
35
+ logger = get_logger()
36
+
37
+
38
class TextRecognizer(object):
    """Paddle/ONNX inference wrapper for text-line recognition models."""

    def __init__(self, args):
        """Build the algorithm-specific post-process op and the predictor."""
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm

        char_dict = args.rec_char_dict_path
        use_space = args.use_space_char

        # Algorithms whose decoder differs only by name (plus optional extra
        # keys); anything not handled below falls back to the CTC decoder.
        decoder_overrides = {
            "SRN": {'name': 'SRNLabelDecode'},
            "RARE": {'name': 'AttnLabelDecode'},
            "NRTR": {'name': 'NRTRLabelDecode'},
            "SAR": {'name': 'SARLabelDecode'},
            "VisionLAN": {'name': 'VLLabelDecode'},
            "ViTSTR": {'name': 'ViTSTRLabelDecode'},
            "ABINet": {'name': 'ABINetLabelDecode'},
            "SPIN": {'name': 'SPINLabelDecode'},
            "RobustScanner": {'name': 'SARLabelDecode', "rm_symbol": True},
        }
        if self.rec_algorithm in decoder_overrides:
            postprocess_params = dict(decoder_overrides[self.rec_algorithm])
            postprocess_params["character_dict_path"] = char_dict
            postprocess_params["use_space_char"] = use_space
        elif self.rec_algorithm == 'RFL':
            # RFL decodes without an external character dictionary.
            postprocess_params = {
                'name': 'RFLLabelDecode',
                "character_dict_path": None,
                "use_space_char": use_space
            }
        elif self.rec_algorithm == "PREN":
            postprocess_params = {'name': 'PRENLabelDecode'}
        elif self.rec_algorithm == "CAN":
            # CAN (formula recognition) may invert the gray input image.
            self.inverse = args.rec_image_inverse
            postprocess_params = {
                'name': 'CANLabelDecode',
                "character_dict_path": char_dict,
                "use_space_char": use_space
            }
        else:
            postprocess_params = {
                'name': 'CTCLabelDecode',
                "character_dict_path": char_dict,
                "use_space_char": use_space
            }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'rec', logger)
        self.benchmark = args.benchmark
        self.use_onnx = args.use_onnx
        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="rec",
                model_precision=args.precision,
                batch_size=args.rec_batch_num,
                data_shape="dynamic",
                save_path=None,  # args.save_log_path
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=0,
                logger=logger)
142
+
143
+ def resize_norm_img(self, img, max_wh_ratio):
144
+ imgC, imgH, imgW = self.rec_image_shape
145
+ if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
146
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
147
+ # return padding_im
148
+ image_pil = Image.fromarray(np.uint8(img))
149
+ if self.rec_algorithm == 'ViTSTR':
150
+ img = image_pil.resize([imgW, imgH], Image.BICUBIC)
151
+ else:
152
+ img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
153
+ img = np.array(img)
154
+ norm_img = np.expand_dims(img, -1)
155
+ norm_img = norm_img.transpose((2, 0, 1))
156
+ if self.rec_algorithm == 'ViTSTR':
157
+ norm_img = norm_img.astype(np.float32) / 255.
158
+ else:
159
+ norm_img = norm_img.astype(np.float32) / 128. - 1.
160
+ return norm_img
161
+ elif self.rec_algorithm == 'RFL':
162
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
163
+ resized_image = cv2.resize(
164
+ img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
165
+ resized_image = resized_image.astype('float32')
166
+ resized_image = resized_image / 255
167
+ resized_image = resized_image[np.newaxis, :]
168
+ resized_image -= 0.5
169
+ resized_image /= 0.5
170
+ return resized_image
171
+
172
+ assert imgC == img.shape[2]
173
+ imgW = int((imgH * max_wh_ratio))
174
+ if self.use_onnx:
175
+ w = self.input_tensor.shape[3:][0]
176
+ if w is not None and w > 0:
177
+ imgW = w
178
+
179
+ h, w = img.shape[:2]
180
+ ratio = w / float(h)
181
+ if math.ceil(imgH * ratio) > imgW:
182
+ resized_w = imgW
183
+ else:
184
+ resized_w = int(math.ceil(imgH * ratio))
185
+ if self.rec_algorithm == 'RARE':
186
+ if resized_w > self.rec_image_shape[2]:
187
+ resized_w = self.rec_image_shape[2]
188
+ imgW = self.rec_image_shape[2]
189
+ resized_image = cv2.resize(img, (resized_w, imgH))
190
+ resized_image = resized_image.astype('float32')
191
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
192
+ resized_image -= 0.5
193
+ resized_image /= 0.5
194
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
195
+ padding_im[:, :, 0:resized_w] = resized_image
196
+ return padding_im
197
+
198
+ def resize_norm_img_vl(self, img, image_shape):
199
+
200
+ imgC, imgH, imgW = image_shape
201
+ img = img[:, :, ::-1] # bgr2rgb
202
+ resized_image = cv2.resize(
203
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
204
+ resized_image = resized_image.astype('float32')
205
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
206
+ return resized_image
207
+
208
+ def resize_norm_img_srn(self, img, image_shape):
209
+ imgC, imgH, imgW = image_shape
210
+
211
+ img_black = np.zeros((imgH, imgW))
212
+ im_hei = img.shape[0]
213
+ im_wid = img.shape[1]
214
+
215
+ if im_wid <= im_hei * 1:
216
+ img_new = cv2.resize(img, (imgH * 1, imgH))
217
+ elif im_wid <= im_hei * 2:
218
+ img_new = cv2.resize(img, (imgH * 2, imgH))
219
+ elif im_wid <= im_hei * 3:
220
+ img_new = cv2.resize(img, (imgH * 3, imgH))
221
+ else:
222
+ img_new = cv2.resize(img, (imgW, imgH))
223
+
224
+ img_np = np.asarray(img_new)
225
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
226
+ img_black[:, 0:img_np.shape[1]] = img_np
227
+ img_black = img_black[:, :, np.newaxis]
228
+
229
+ row, col, c = img_black.shape
230
+ c = 1
231
+
232
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
233
+
234
+ def srn_other_inputs(self, image_shape, num_heads, max_text_length):
235
+
236
+ imgC, imgH, imgW = image_shape
237
+ feature_dim = int((imgH / 8) * (imgW / 8))
238
+
239
+ encoder_word_pos = np.array(range(0, feature_dim)).reshape(
240
+ (feature_dim, 1)).astype('int64')
241
+ gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
242
+ (max_text_length, 1)).astype('int64')
243
+
244
+ gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
245
+ gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
246
+ [-1, 1, max_text_length, max_text_length])
247
+ gsrm_slf_attn_bias1 = np.tile(
248
+ gsrm_slf_attn_bias1,
249
+ [1, num_heads, 1, 1]).astype('float32') * [-1e9]
250
+
251
+ gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
252
+ [-1, 1, max_text_length, max_text_length])
253
+ gsrm_slf_attn_bias2 = np.tile(
254
+ gsrm_slf_attn_bias2,
255
+ [1, num_heads, 1, 1]).astype('float32') * [-1e9]
256
+
257
+ encoder_word_pos = encoder_word_pos[np.newaxis, :]
258
+ gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
259
+
260
+ return [
261
+ encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
262
+ gsrm_slf_attn_bias2
263
+ ]
264
+
265
+ def process_image_srn(self, img, image_shape, num_heads, max_text_length):
266
+ norm_img = self.resize_norm_img_srn(img, image_shape)
267
+ norm_img = norm_img[np.newaxis, :]
268
+
269
+ [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
270
+ self.srn_other_inputs(image_shape, num_heads, max_text_length)
271
+
272
+ gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
273
+ gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
274
+ encoder_word_pos = encoder_word_pos.astype(np.int64)
275
+ gsrm_word_pos = gsrm_word_pos.astype(np.int64)
276
+
277
+ return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
278
+ gsrm_slf_attn_bias2)
279
+
280
    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        # Preprocess for SAR/RobustScanner: keep aspect ratio, snap the width
        # to a multiple of 1/width_downsample_ratio, clamp to [imgW_min,
        # imgW_max], pad with -1, and report the valid-width fraction.
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            # Fraction of the padded width that carries real image content.
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # norm
        if image_shape[0] == 1:
            # Single-channel model: scale to [0, 1] and add the channel axis.
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        # Pad on the right with -1 (the post-normalization value of black).
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio
314
+
315
+ def resize_norm_img_spin(self, img):
316
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
317
+ # return padding_im
318
+ img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
319
+ img = np.array(img, np.float32)
320
+ img = np.expand_dims(img, -1)
321
+ img = img.transpose((2, 0, 1))
322
+ mean = [127.5]
323
+ std = [127.5]
324
+ mean = np.array(mean, dtype=np.float32)
325
+ std = np.array(std, dtype=np.float32)
326
+ mean = np.float32(mean.reshape(1, -1))
327
+ stdinv = 1 / np.float32(std.reshape(1, -1))
328
+ img -= mean
329
+ img *= stdinv
330
+ return img
331
+
332
+ def resize_norm_img_svtr(self, img, image_shape):
333
+
334
+ imgC, imgH, imgW = image_shape
335
+ resized_image = cv2.resize(
336
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
337
+ resized_image = resized_image.astype('float32')
338
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
339
+ resized_image -= 0.5
340
+ resized_image /= 0.5
341
+ return resized_image
342
+
343
+ def resize_norm_img_abinet(self, img, image_shape):
344
+
345
+ imgC, imgH, imgW = image_shape
346
+
347
+ resized_image = cv2.resize(
348
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
349
+ resized_image = resized_image.astype('float32')
350
+ resized_image = resized_image / 255.
351
+
352
+ mean = np.array([0.485, 0.456, 0.406])
353
+ std = np.array([0.229, 0.224, 0.225])
354
+ resized_image = (
355
+ resized_image - mean[None, None, ...]) / std[None, None, ...]
356
+ resized_image = resized_image.transpose((2, 0, 1))
357
+ resized_image = resized_image.astype('float32')
358
+
359
+ return resized_image
360
+
361
    def norm_img_can(self, img, image_shape):
        # Normalize one crop for the CAN formula-recognition model.
        # NOTE(review): `image_shape` is unused here — the target shape comes
        # from self.rec_image_shape; confirm whether the parameter is vestigial.

        img = cv2.cvtColor(
            img, cv2.COLOR_BGR2GRAY)  # CAN only predict gray scale image

        # Optionally invert so formulas are dark-on-light (set via
        # --rec_image_inverse; see __init__).
        if self.inverse:
            img = 255 - img

        if self.rec_image_shape[0] == 1:
            h, w = img.shape
            _, imgH, imgW = self.rec_image_shape
            # Pad small images with white (255) up to the model's minimum size.
            if h < imgH or w < imgW:
                padding_h = max(imgH - h, 0)
                padding_w = max(imgW - w, 0)
                img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
                                    'constant',
                                    constant_values=(255))
                img = img_padded

        img = np.expand_dims(img, 0) / 255.0  # h,w,c -> c,h,w
        img = img.astype('float32')

        return img
384
+
385
    def __call__(self, img_list):
        """Recognize a list of cropped text images.

        Crops are sorted by aspect ratio so each batch pads to a similar
        width, preprocessed per self.rec_algorithm, run through the Paddle
        (or ONNX) predictor, and decoded by self.postprocess_op.

        Returns:
            (rec_res, elapse): rec_res is a list of results aligned with the
            original input order; elapse is wall-clock seconds for the call.
        """
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        st = time.time()
        if self.benchmark:
            self.autolog.times.start()
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            # Algorithm-specific auxiliary inputs gathered next to the images.
            if self.rec_algorithm == "SRN":
                encoder_word_pos_list = []
                gsrm_word_pos_list = []
                gsrm_slf_attn_bias1_list = []
                gsrm_slf_attn_bias2_list = []
            if self.rec_algorithm == "SAR":
                valid_ratios = []
            imgC, imgH, imgW = self.rec_image_shape[:3]
            max_wh_ratio = imgW / imgH
            # The widest crop in the batch fixes the padded width for all.
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                if self.rec_algorithm == "SAR":
                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                        img_list[indices[ino]], self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
                    valid_ratios.append(valid_ratio)
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == "SRN":
                    norm_img = self.process_image_srn(
                        img_list[indices[ino]], self.rec_image_shape, 8, 25)
                    encoder_word_pos_list.append(norm_img[1])
                    gsrm_word_pos_list.append(norm_img[2])
                    gsrm_slf_attn_bias1_list.append(norm_img[3])
                    gsrm_slf_attn_bias2_list.append(norm_img[4])
                    norm_img_batch.append(norm_img[0])
                elif self.rec_algorithm == "SVTR":
                    norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
                                                         self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm in ["VisionLAN", "PREN"]:
                    norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
                                                       self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == 'SPIN':
                    norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == "ABINet":
                    norm_img = self.resize_norm_img_abinet(
                        img_list[indices[ino]], self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == "RobustScanner":
                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                        img_list[indices[ino]],
                        self.rec_image_shape,
                        width_downsample_ratio=0.25)
                    norm_img = norm_img[np.newaxis, :]
                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
                    # NOTE(review): valid_ratios and word_positions_list are
                    # re-created on every iteration, so only the last crop's
                    # values survive when the batch holds more than one image
                    # — confirm intended behavior for rec_batch_num > 1.
                    valid_ratios = []
                    valid_ratios.append(valid_ratio)
                    norm_img_batch.append(norm_img)
                    word_positions_list = []
                    word_positions = np.array(range(0, 40)).astype('int64')
                    word_positions = np.expand_dims(word_positions, axis=0)
                    word_positions_list.append(word_positions)
                elif self.rec_algorithm == "CAN":
                    # NOTE(review): norm_img_can does not use its second
                    # argument; max_wh_ratio is passed but ignored there.
                    norm_img = self.norm_img_can(img_list[indices[ino]],
                                                 max_wh_ratio)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                    norm_image_mask = np.ones(norm_img.shape, dtype='float32')
                    word_label = np.ones([1, 36], dtype='int64')
                    # NOTE(review): these lists are also re-created per image,
                    # keeping only the last mask/label for batches > 1.
                    norm_img_mask_batch = []
                    word_label_list = []
                    norm_img_mask_batch.append(norm_image_mask)
                    word_label_list.append(word_label)
                else:
                    # Default CTC-style preprocessing, padded to max_wh_ratio.
                    norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                    max_wh_ratio)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            if self.benchmark:
                self.autolog.times.stamp()

            if self.rec_algorithm == "SRN":
                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
                gsrm_slf_attn_bias1_list = np.concatenate(
                    gsrm_slf_attn_bias1_list)
                gsrm_slf_attn_bias2_list = np.concatenate(
                    gsrm_slf_attn_bias2_list)

                inputs = [
                    norm_img_batch,
                    encoder_word_pos_list,
                    gsrm_word_pos_list,
                    gsrm_slf_attn_bias1_list,
                    gsrm_slf_attn_bias2_list,
                ]
                if self.use_onnx:
                    # NOTE(review): only the image tensor is fed on the ONNX
                    # path; the auxiliary SRN inputs above are not passed in.
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = {"predict": outputs[2]}
                else:
                    input_names = self.predictor.get_input_names()
                    for i in range(len(input_names)):
                        input_tensor = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor.copy_from_cpu(inputs[i])
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = {"predict": outputs[2]}
            elif self.rec_algorithm == "SAR":
                valid_ratios = np.concatenate(valid_ratios)
                inputs = [
                    norm_img_batch,
                    np.array(
                        [valid_ratios], dtype=np.float32),
                ]
                if self.use_onnx:
                    # NOTE(review): valid_ratios are not fed on the ONNX path.
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs[0]
                else:
                    input_names = self.predictor.get_input_names()
                    for i in range(len(input_names)):
                        input_tensor = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor.copy_from_cpu(inputs[i])
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = outputs[0]
            elif self.rec_algorithm == "RobustScanner":
                valid_ratios = np.concatenate(valid_ratios)
                word_positions_list = np.concatenate(word_positions_list)
                inputs = [norm_img_batch, valid_ratios, word_positions_list]

                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs[0]
                else:
                    input_names = self.predictor.get_input_names()
                    for i in range(len(input_names)):
                        input_tensor = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor.copy_from_cpu(inputs[i])
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = outputs[0]
            elif self.rec_algorithm == "CAN":
                norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
                word_label_list = np.concatenate(word_label_list)
                inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs
                else:
                    input_names = self.predictor.get_input_names()
                    input_tensor = []
                    for i in range(len(input_names)):
                        input_tensor_i = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor_i.copy_from_cpu(inputs[i])
                        input_tensor.append(input_tensor_i)
                    self.input_tensor = input_tensor
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = outputs
            else:
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs[0]
                else:
                    self.input_tensor.copy_from_cpu(norm_img_batch)
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    if len(outputs) != 1:
                        preds = outputs
                    else:
                        preds = outputs[0]
            # Decode network output and scatter results back to the original
            # (pre-sort) positions.
            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
            if self.benchmark:
                self.autolog.times.end(stamp=True)
        return rec_res, time.time() - st
625
+
626
+
627
def main(args):
    """CLI entry point: recognize every image under --image_dir and log results."""
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []

    logger.info(
        "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
        "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
    )
    # warmup 2 times
    if args.warmup:
        warm_img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)
        for _ in range(2):
            text_recognizer([warm_img] * int(args.rec_batch_num))

    for image_file in image_file_list:
        img, flag, _ = check_and_read(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)

    try:
        rec_res, _ = text_recognizer(img_list)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()

    for ino, file_name in enumerate(valid_image_file_list):
        logger.info("Predicts of {}:{}".format(file_name, rec_res[ino]))
    if args.benchmark:
        text_recognizer.autolog.report()
664
+
665
+
666
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run recognition.
    main(utility.parse_args())
tools/infer/predict_sr.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ from PIL import Image
17
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
18
+ sys.path.insert(0, __dir__)
19
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
20
+
21
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
22
+
23
+ import cv2
24
+ import numpy as np
25
+ import math
26
+ import time
27
+ import traceback
28
+ import paddle
29
+
30
+ import tools.infer.utility as utility
31
+ from ppocr.postprocess import build_post_process
32
+ from ppocr.utils.logging import get_logger
33
+ from ppocr.utils.utility import get_image_file_list, check_and_read
34
+
35
+ logger = get_logger()
36
+
37
+
38
class TextSR(object):
    """Inference wrapper for the text super-resolution model."""

    def __init__(self, args):
        # (C, H, W) of the model output, parsed from --sr_image_shape.
        self.sr_image_shape = [int(v) for v in args.sr_image_shape.split(",")]
        self.sr_batch_num = args.sr_batch_num

        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'sr', logger)
        self.benchmark = args.benchmark
        # Optional auto_log benchmarking harness (imported lazily so the
        # dependency is only needed when --benchmark is set).
        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="sr",
                model_precision=args.precision,
                batch_size=args.sr_batch_num,
                data_shape="dynamic",
                save_path=None,  #args.save_log_path,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=0,
                logger=logger)
65
+
66
+ def resize_norm_img(self, img):
67
+ imgC, imgH, imgW = self.sr_image_shape
68
+ img = img.resize((imgW // 2, imgH // 2), Image.BICUBIC)
69
+ img_numpy = np.array(img).astype("float32")
70
+ img_numpy = img_numpy.transpose((2, 0, 1)) / 255
71
+ return img_numpy
72
+
73
+ def __call__(self, img_list):
74
+ img_num = len(img_list)
75
+ batch_num = self.sr_batch_num
76
+ st = time.time()
77
+ st = time.time()
78
+ all_result = [] * img_num
79
+ if self.benchmark:
80
+ self.autolog.times.start()
81
+ for beg_img_no in range(0, img_num, batch_num):
82
+ end_img_no = min(img_num, beg_img_no + batch_num)
83
+ norm_img_batch = []
84
+ imgC, imgH, imgW = self.sr_image_shape
85
+ for ino in range(beg_img_no, end_img_no):
86
+ norm_img = self.resize_norm_img(img_list[ino])
87
+ norm_img = norm_img[np.newaxis, :]
88
+ norm_img_batch.append(norm_img)
89
+
90
+ norm_img_batch = np.concatenate(norm_img_batch)
91
+ norm_img_batch = norm_img_batch.copy()
92
+ if self.benchmark:
93
+ self.autolog.times.stamp()
94
+ self.input_tensor.copy_from_cpu(norm_img_batch)
95
+ self.predictor.run()
96
+ outputs = []
97
+ for output_tensor in self.output_tensors:
98
+ output = output_tensor.copy_to_cpu()
99
+ outputs.append(output)
100
+ if len(outputs) != 1:
101
+ preds = outputs
102
+ else:
103
+ preds = outputs[0]
104
+ all_result.append(outputs)
105
+ if self.benchmark:
106
+ self.autolog.times.end(stamp=True)
107
+ return all_result, time.time() - st
108
+
109
+
110
def main(args):
    """CLI entry point: super-resolve every image under --image_dir and save
    the results as infer_result/sr_<name>."""
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextSR(args)
    valid_image_file_list = []
    img_list = []

    # warmup 2 times
    if args.warmup:
        # TextSR.resize_norm_img calls PIL's Image.resize, so the random
        # warm-up array must be wrapped in a PIL image (a bare numpy array
        # would crash here).
        img = Image.fromarray(
            np.random.uniform(0, 255, [16, 64, 3]).astype(np.uint8))
        for i in range(2):
            res = text_recognizer([img] * int(args.sr_batch_num))

    for image_file in image_file_list:
        img, flag, _ = check_and_read(image_file)
        if not flag:
            img = Image.open(image_file).convert("RGB")
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        preds, _ = text_recognizer(img_list)
        # cv2.imwrite does not create directories; make sure the output
        # folder exists before writing.
        os.makedirs("infer_result", exist_ok=True)
        for beg_no in range(len(preds)):
            sr_img = preds[beg_no][1]
            lr_img = preds[beg_no][0]
            for i in (range(sr_img.shape[0])):
                fm_sr = (sr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
                fm_lr = (lr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
                img_name_pure = os.path.split(valid_image_file_list[
                    beg_no * args.sr_batch_num + i])[-1]
                # Model output is RGB; flip to BGR for cv2.imwrite.
                cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
                            fm_sr[:, :, ::-1])
                logger.info("The visualized image saved in infer_result/sr_{}".
                            format(img_name_pure))

    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    if args.benchmark:
        text_recognizer.autolog.report()
152
+
153
+
154
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run super-resolution.
    main(utility.parse_args())
tools/infer/predict_system.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import subprocess
17
+
18
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
19
+ sys.path.append(__dir__)
20
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
21
+
22
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
23
+
24
+ import cv2
25
+ import copy
26
+ import numpy as np
27
+ import json
28
+ import time
29
+ import logging
30
+ from PIL import Image
31
+ import tools.infer.utility as utility
32
+ import tools.infer.predict_rec as predict_rec
33
+ import tools.infer.predict_det as predict_det
34
+ import tools.infer.predict_cls as predict_cls
35
+ from ppocr.utils.utility import get_image_file_list, check_and_read
36
+ from ppocr.utils.logging import get_logger
37
+ from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
38
+ logger = get_logger()
39
+
40
+
41
class TextSystem(object):
    """End-to-end OCR pipeline: detection -> (optional) angle
    classification -> recognition, built from the CLI ``args`` produced
    by ``tools.infer.utility.parse_args``.
    """

    def __init__(self, args):
        if not args.show_log:
            logger.setLevel(logging.INFO)

        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

        self.args = args
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        """Save every cropped text region to *output_dir* for debugging.

        File names carry a running index across calls so repeated
        invocations do not overwrite earlier crops.
        """
        os.makedirs(output_dir, exist_ok=True)
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
            cv2.imwrite(
                os.path.join(output_dir,
                             f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
                img_crop_list[bno])
            logger.debug(f"{bno}, {rec_res[bno]}")
        self.crop_image_res_index += bbox_num

    def __call__(self, img, cls=True):
        """Run the full pipeline on one image.

        Args:
            img: input image (numpy array, BGR as read by cv2).
            cls: whether to run the angle classifier (only effective when
                the system was created with ``use_angle_cls``).

        Returns:
            (boxes, rec_results, time_dict): detected boxes and
            (text, score) pairs whose score reaches ``drop_score``, plus
            per-stage timing information.
        """
        # BUGFIX: this dict used to be initialised with a misspelled 'csl'
        # key while the classifier branch below writes 'cls'.
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
        start = time.time()
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse
        logger.debug("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))
        if dt_boxes is None:
            # BUGFIX: return three values so callers that unpack
            # (boxes, rec_res, time_dict) — e.g. main() below — do not
            # crash when detection yields nothing.
            time_dict['all'] = time.time() - start
            return None, None, time_dict
        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
                img_crop_list)
            time_dict['cls'] = elapse
            logger.debug("cls num : {}, elapse : {}".format(
                len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict['rec'] = elapse
        logger.debug("rec_res num : {}, elapse : {}".format(
            len(rec_res), elapse))
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
                                   rec_res)
        # keep only results whose recognition confidence passes drop_score
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict['all'] = end - start
        return filter_boxes, filter_rec_res, time_dict
111
+
112
+
113
def sorted_boxes(dt_boxes):
    """
    Sort text boxes in reading order: top to bottom, then left to right.

    Boxes whose first corners are within 10 px vertically are treated as
    the same line and ordered by x.
    args:
        dt_boxes(array): detected text boxes with shape [N, 4, 2]
    return:
        list of boxes (each [4, 2]) in reading order
    """
    box_count = dt_boxes.shape[0]
    boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    # bubble each box leftward while it sits on the same line as its
    # predecessor but starts further left
    for i in range(box_count - 1):
        j = i
        while j >= 0:
            same_line = abs(boxes[j + 1][0][1] - boxes[j][0][1]) < 10
            if same_line and boxes[j + 1][0][0] < boxes[j][0][0]:
                boxes[j], boxes[j + 1] = boxes[j + 1], boxes[j]
                j -= 1
            else:
                break
    return boxes
135
+
136
+
137
def main(args):
    """Run the detection+classification+recognition pipeline over every
    image (or PDF page) under ``args.image_dir``, save visualizations and
    a ``system_results.txt`` file of per-image JSON predictions.
    """
    image_file_list = get_image_file_list(args.image_dir)
    # shard the file list when running as one of several worker processes
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    draw_img_save_dir = args.draw_img_save_dir
    os.makedirs(draw_img_save_dir, exist_ok=True)
    save_results = []

    logger.info(
        "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
        "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
    )

    # warm up 10 times
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(10):
            res = text_sys(img)

    total_time = 0
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
    _st = time.time()
    count = 0
    for idx, image_file in enumerate(image_file_list):

        # check_and_read handles gif/pdf; anything else is read with cv2
        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if not flag_pdf:
            if img is None:
                logger.debug("error in loading image:{}".format(image_file))
                continue
            imgs = [img]
        else:
            # a pdf yields a list of page images; page_num == 0 means all
            page_num = args.page_num
            if page_num > len(img) or page_num == 0:
                page_num = len(img)
            imgs = img[:page_num]
        for index, img in enumerate(imgs):
            starttime = time.time()
            dt_boxes, rec_res, time_dict = text_sys(img)
            elapse = time.time() - starttime
            total_time += elapse
            if len(imgs) > 1:
                logger.debug(
                    str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
                    % (image_file, elapse))
            else:
                logger.debug(
                    str(idx) + " Predict time of %s: %.3fs" % (image_file,
                                                               elapse))
            for text, score in rec_res:
                logger.debug("{}, {:.3f}".format(text, score))

            # serialize predictions as one JSON record per image/page
            res = [{
                "transcription": rec_res[i][0],
                "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
            } for i in range(len(dt_boxes))]
            if len(imgs) > 1:
                save_pred = os.path.basename(image_file) + '_' + str(
                    index) + "\t" + json.dumps(
                        res, ensure_ascii=False) + "\n"
            else:
                save_pred = os.path.basename(image_file) + "\t" + json.dumps(
                    res, ensure_ascii=False) + "\n"
            save_results.append(save_pred)

            if is_visualize:
                # cv2 gives BGR; PIL drawing helpers expect RGB
                image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                boxes = dt_boxes
                txts = [rec_res[i][0] for i in range(len(rec_res))]
                scores = [rec_res[i][1] for i in range(len(rec_res))]

                draw_img = draw_ocr_box_txt(
                    image,
                    boxes,
                    txts,
                    scores,
                    drop_score=drop_score,
                    font_path=font_path)
                if flag_gif:
                    save_file = image_file[:-3] + "png"
                elif flag_pdf:
                    save_file = image_file.replace('.pdf',
                                                   '_' + str(index) + '.png')
                else:
                    save_file = image_file
                # flip RGB back to BGR for cv2.imwrite
                cv2.imwrite(
                    os.path.join(draw_img_save_dir,
                                 os.path.basename(save_file)),
                    draw_img[:, :, ::-1])
                logger.debug("The visualized image saved in {}".format(
                    os.path.join(draw_img_save_dir, os.path.basename(
                        save_file))))

    logger.info("The predict total time is {}".format(time.time() - _st))
    if args.benchmark:
        text_sys.text_detector.autolog.report()
        text_sys.text_recognizer.autolog.report()

    with open(
            os.path.join(draw_img_save_dir, "system_results.txt"),
            'w',
            encoding='utf-8') as f:
        f.writelines(save_results)
245
+
246
+
247
# Script entry point. With --use_mp the work is fanned out to
# total_process_num child processes, each re-invoking this script with its
# own --process_id (and --use_mp=False so children do not recurse); each
# child then processes its shard of the file list. Otherwise run in-process.
if __name__ == "__main__":
    args = utility.parse_args()
    if args.use_mp:
        p_list = []
        total_process_num = args.total_process_num
        for process_id in range(total_process_num):
            cmd = [sys.executable, "-u"] + sys.argv + [
                "--process_id={}".format(process_id),
                "--use_mp={}".format(False)
            ]
            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
            p_list.append(p)
        for p in p_list:
            p.wait()
    else:
        main(args)
tools/infer/utility.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+ import sys
18
+ import platform
19
+ import cv2
20
+ import numpy as np
21
+ import paddle
22
+ from PIL import Image, ImageDraw, ImageFont
23
+ import math
24
+ from paddle import inference
25
+ import time
26
+ import random
27
+ from ppocr.utils.logging import get_logger
28
+
29
+
30
def str2bool(v):
    """Interpret a command-line string as a boolean flag.

    Only "true", "t" and "1" (case-insensitive) map to True; any other
    value maps to False.
    """
    normalized = v.lower()
    return normalized == "true" or normalized == "t" or normalized == "1"
32
+
33
+
34
def init_args():
    """Build the argparse parser shared by every tools/infer entry point.

    Groups: inference engine flags, detector (DB/EAST/SAST/PSE/FCE),
    recognizer, end-to-end (PGNet), angle classifier, super-resolution,
    output/visualization, and multi-process options.

    Returns:
        argparse.ArgumentParser: the fully configured (unparsed) parser.
    """
    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--use_xpu", type=str2bool, default=False)
    parser.add_argument("--use_npu", type=str2bool, default=False)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=15)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--gpu_id", type=int, default=0)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--page_num", type=int, default=0)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')
    parser.add_argument("--det_box_type", type=str, default='quad')

    # DB parmas
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST parmas
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST parmas
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)

    # PSE parmas
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE parmas
    # NOTE(review): type=list makes argparse split a command-line value into
    # individual characters; only the default works as intended — confirm
    # before relying on these from the CLI.
    parser.add_argument("--scales", type=list, default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet parmas
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    # NOTE(review): same type=list caveat as --scales above.
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--cpu_threads", type=int, default=10)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # SR parmas
    parser.add_argument("--sr_model_dir", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    # output / visualization
    parser.add_argument(
        "--draw_img_save_dir", type=str, default="./inference_results")
    parser.add_argument("--save_crop_res", type=str2bool, default=False)
    parser.add_argument("--crop_res_save_dir", type=str, default="./output")

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")

    parser.add_argument("--show_log", type=str2bool, default=True)
    parser.add_argument("--use_onnx", type=str2bool, default=False)
    return parser
149
+
150
+
151
def parse_args():
    """Parse sys.argv with the shared inference parser from init_args()."""
    return init_args().parse_args()
154
+
155
+
156
def create_predictor(args, mode, logger):
    """Create an inference session for one pipeline stage.

    Args:
        args: parsed CLI namespace (see init_args); the *mode*-specific
            ``*_model_dir`` attribute selects the model directory.
        mode (str): one of 'det', 'cls', 'rec', 'table', 'ser', 're',
            'sr', 'layout'; anything else falls back to the e2e model.
        logger: logger for status/warning messages.

    Returns:
        With ``args.use_onnx``: (onnx_session, first_input, None, None).
        Otherwise: (paddle_predictor, input_tensor(s), output_tensors,
        config) — input is a list of handles for 'ser'/'re', a single
        handle for every other mode.
    """
    if mode == "det":
        model_dir = args.det_model_dir
    elif mode == 'cls':
        model_dir = args.cls_model_dir
    elif mode == 'rec':
        model_dir = args.rec_model_dir
    elif mode == 'table':
        model_dir = args.table_model_dir
    elif mode == 'ser':
        model_dir = args.ser_model_dir
    elif mode == 're':
        model_dir = args.re_model_dir
    elif mode == "sr":
        model_dir = args.sr_model_dir
    elif mode == 'layout':
        model_dir = args.layout_model_dir
    else:
        model_dir = args.e2e_model_dir

    if model_dir is None:
        logger.info("not find {} model file path {}".format(mode, model_dir))
        sys.exit(0)
    if args.use_onnx:
        import onnxruntime as ort
        # for onnx, model_dir is the .onnx file itself
        model_file_path = model_dir
        if not os.path.exists(model_file_path):
            raise ValueError("not find model file path {}".format(
                model_file_path))
        sess = ort.InferenceSession(model_file_path)
        return sess, sess.get_inputs()[0], None, None

    else:
        # accept either model.pdmodel or inference.pdmodel naming
        file_names = ['model', 'inference']
        for file_name in file_names:
            model_file_path = '{}/{}.pdmodel'.format(model_dir, file_name)
            params_file_path = '{}/{}.pdiparams'.format(model_dir, file_name)
            if os.path.exists(model_file_path) and os.path.exists(
                    params_file_path):
                break
        if not os.path.exists(model_file_path):
            raise ValueError(
                "not find model.pdmodel or inference.pdmodel in {}".format(
                    model_dir))
        if not os.path.exists(params_file_path):
            raise ValueError(
                "not find model.pdiparams or inference.pdiparams in {}".format(
                    model_dir))

        config = inference.Config(model_file_path, params_file_path)

        # fp16 is only honored together with TensorRT; everything else
        # falls back to fp32
        if hasattr(args, 'precision'):
            if args.precision == "fp16" and args.use_tensorrt:
                precision = inference.PrecisionType.Half
            elif args.precision == "int8":
                precision = inference.PrecisionType.Int8
            else:
                precision = inference.PrecisionType.Float32
        else:
            precision = inference.PrecisionType.Float32

        if args.use_gpu:
            gpu_id = get_infer_gpuid()
            if gpu_id is None:
                logger.warning(
                    "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
                )
            config.enable_use_gpu(args.gpu_mem, args.gpu_id)
            if args.use_tensorrt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 30,
                    precision_mode=precision,
                    max_batch_size=args.max_batch_size,
                    min_subgraph_size=args.
                    min_subgraph_size,  # skip the minmum trt subgraph
                    use_calib_mode=False)

                # collect shape
                trt_shape_f = os.path.join(model_dir,
                                           f"{mode}_trt_dynamic_shape.txt")

                if not os.path.exists(trt_shape_f):
                    config.collect_shape_range_info(trt_shape_f)
                    logger.info(
                        f"collect dynamic shape info into : {trt_shape_f}")
                try:
                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f,
                                                               True)
                except Exception as E:
                    logger.info(E)
                    logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")

        elif args.use_npu:
            config.enable_custom_device("npu")
        elif args.use_xpu:
            config.enable_xpu(10 * 1024 * 1024)
        else:
            config.disable_gpu()
            if args.enable_mkldnn:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if args.precision == "fp16":
                    config.enable_mkldnn_bfloat16()
            if hasattr(args, "cpu_threads"):
                config.set_cpu_math_library_num_threads(args.cpu_threads)
            else:
                # default cpu threads as 10
                config.set_cpu_math_library_num_threads(10)
        # enable memory optim
        config.enable_memory_optim()
        config.disable_glog_info()
        # these fusion passes are known to misbehave for these models
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
        config.delete_pass("matmul_transpose_reshape_fuse_pass")
        if mode == 're':
            config.delete_pass("simplify_with_basic_ops_pass")
        if mode == 'table':
            config.delete_pass("fc_fuse_pass")  # not supported for table
        config.switch_use_feed_fetch_ops(False)
        config.switch_ir_optim(True)

        # create predictor
        predictor = inference.create_predictor(config)
        input_names = predictor.get_input_names()
        if mode in ['ser', 're']:
            input_tensor = []
            for name in input_names:
                input_tensor.append(predictor.get_input_handle(name))
        else:
            # single-input models: keep the (last) input handle
            for name in input_names:
                input_tensor = predictor.get_input_handle(name)
        output_tensors = get_output_tensors(args, mode, predictor)
        return predictor, input_tensor, output_tensors, config
289
+
290
+
291
def get_output_tensors(args, mode, predictor):
    """Collect output handles for a predictor.

    For CTC-style recognizers (CRNN / SVTR_LCNet) the single softmax
    output is preferred when the model exposes it; in every other case
    all outputs are returned in declaration order.
    """
    names = predictor.get_output_names()
    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
        preferred = 'softmax_0.tmp_0'
        if preferred in names:
            return [predictor.get_output_handle(preferred)]
    return [predictor.get_output_handle(name) for name in names]
307
+
308
+
309
def get_infer_gpuid():
    """Return the first visible GPU/ROCm device id, or 0 when unspecified.

    On Windows, and when the relevant *_VISIBLE_DEVICES variable is unset
    or empty, falls back to device 0.
    """
    if platform.system() == "Windows":
        return 0

    if not paddle.fluid.core.is_compiled_with_rocm():
        env_name = "CUDA_VISIBLE_DEVICES"
    else:
        env_name = "HIP_VISIBLE_DEVICES"
    # Read the environment directly instead of shelling out to
    # ``env | grep``, which is slow and not portable.
    visible = os.environ.get(env_name)
    if not visible:
        return 0
    # BUGFIX: the original did int(gpu_id[0]) on the raw string, which
    # truncated multi-digit ids (e.g. "10" -> 1); parse the first
    # comma-separated id in full instead.
    return int(visible.split(",")[0])
324
+
325
+
326
def draw_e2e_res(dt_boxes, strs, img_path):
    """Draw end-to-end detection polygons and their transcriptions.

    Args:
        dt_boxes: iterable of polygon point arrays, each of shape (N, 2).
        strs: recognized text for each polygon (parallel to dt_boxes).
        img_path (str): path of the image to load and draw on.

    Returns:
        The annotated BGR image (numpy array).
    """
    src_im = cv2.imread(img_path)
    # loop variable renamed from `str`, which shadowed the builtin
    for box, txt in zip(dt_boxes, strs):
        box = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        # label each polygon at its first vertex
        cv2.putText(
            src_im,
            txt,
            org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)
    return src_im
340
+
341
+
342
def draw_text_det_res(dt_boxes, img):
    """Outline every detected text box on *img* in cyan and return it."""
    for quad in dt_boxes:
        pts = np.array(quad).astype(np.int32).reshape(-1, 2)
        cv2.polylines(img, [pts], True, color=(255, 255, 0), thickness=2)
    return img
347
+
348
+
349
def resize_img(img, input_size=600):
    """
    Resize *img* so that its longest side equals *input_size*,
    preserving the aspect ratio.
    """
    arr = np.array(img)
    longest_side = np.max(arr.shape[0:2])
    scale = float(input_size) / float(longest_side)
    return cv2.resize(arr, None, None, fx=scale, fy=scale)
359
+
360
+
361
def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/fonts/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): txts corresponding scores
        drop_score(float): only scores greater than drop_threshold will be visualized
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img (with a text panel concatenated on the right
        when txts is given)
    """
    if scores is None:
        scores = [1] * len(boxes)
    box_num = len(boxes)
    for i in range(box_num):
        # skip low-confidence and NaN-scored boxes
        if scores is not None and (scores[i] < drop_score or
                                   math.isnan(scores[i])):
            continue
        box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    if txts is not None:
        # shrink the annotated image and append the rendered text panel
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path)
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image
400
+
401
+
402
def draw_ocr_box_txt(image,
                     boxes,
                     txts=None,
                     scores=None,
                     drop_score=0.5,
                     font_path="./doc/fonts/simfang.ttf"):
    """Render a side-by-side visualization: the input image with filled
    box overlays on the left, and the recognized text drawn inside the
    corresponding box shapes on a white canvas on the right.

    Args:
        image (PIL.Image): input image.
        boxes: text boxes, each a quadrilateral of 4 (x, y) points.
        txts: recognized text per box; when missing or mismatched in
            length, boxes are drawn without text.
        scores: per-box confidence; boxes below *drop_score* are skipped.
        drop_score (float): visualization threshold.
        font_path (str): TTF font used to render the text.

    Returns:
        numpy array of shape (h, 2*w, 3) with both halves concatenated.
    """
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
    # fixed seed so box colors are reproducible across runs
    random.seed(0)

    draw_left = ImageDraw.Draw(img_left)
    if txts is None or len(txts) != len(boxes):
        txts = [None] * len(boxes)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255))
        draw_left.polygon(box, fill=color)
        img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
        pts = np.array(box, np.int32).reshape((-1, 1, 2))
        cv2.polylines(img_right_text, [pts], True, color, 1)
        # AND with the accumulated canvas so earlier text stays visible
        img_right = cv2.bitwise_and(img_right, img_right_text)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
    return np.array(img_show)
431
+
432
+
433
def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
    """Draw *txt* on an axis-aligned canvas sized like *box*, then
    perspective-warp it into the box position on a white image.

    Args:
        img_size: output image size as (width, height).
        box: quadrilateral of 4 (x, y) points, ordered starting from the
            top-left corner.
        txt: text to draw; may be empty/None for an empty box.
        font_path (str): TTF font used to render the text.

    Returns:
        numpy uint8 image of *img_size* with the warped text.
    """
    # edge lengths of the quad give the upright canvas dimensions
    box_height = int(
        math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2))
    box_width = int(
        math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2))

    if box_height > 2 * box_width and box_height > 30:
        # markedly vertical box: draw horizontally, then rotate into place
        img_text = Image.new('RGB', (box_height, box_width), (255, 255, 255))
        draw_text = ImageDraw.Draw(img_text)
        if txt:
            font = create_font(txt, (box_height, box_width), font_path)
            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
        img_text = img_text.transpose(Image.ROTATE_270)
    else:
        img_text = Image.new('RGB', (box_width, box_height), (255, 255, 255))
        draw_text = ImageDraw.Draw(img_text)
        if txt:
            font = create_font(txt, (box_width, box_height), font_path)
            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)

    # map the upright text canvas corners onto the box corners
    pts1 = np.float32(
        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
    pts2 = np.array(box, dtype=np.float32)
    M = cv2.getPerspectiveTransform(pts1, pts2)

    img_text = np.array(img_text, dtype=np.uint8)
    img_right_text = cv2.warpPerspective(
        img_text,
        M,
        img_size,
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255))
    return img_right_text
467
+
468
+
469
def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
    """Create a truetype font sized so *txt* fits in a box of size *sz*.

    Args:
        txt (str): text that must fit.
        sz: target box as (width, height); the font starts at ~99% of the
            box height and is shrunk if the rendered text is too wide.
        font_path (str): TTF font file.

    Returns:
        PIL.ImageFont.FreeTypeFont sized to fit.
    """
    font_size = int(sz[1] * 0.99)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    # COMPAT: ImageFont.getsize was removed in Pillow 10; prefer
    # getlength and fall back for older Pillow versions.
    if hasattr(font, "getlength"):
        length = font.getlength(txt)
    else:
        length = font.getsize(txt)[0]
    if length > sz[0]:
        font_size = int(font_size * sz[0] / length)
        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    return font
477
+
478
+
479
def str_count(s):
    """
    Count the display width of *s* in units of full-width characters.

    ASCII letters, digits and whitespace together count as half a
    character each (total rounded up); every other character — CJK,
    punctuation, etc. — counts as one.
    args:
        s(string): the input of string
    return(int):
        the width in full-width character units
    """
    import string
    # the original also tracked separate zh/punctuation counters, but
    # they were never used — only the halfwidth total matters
    halfwidth = sum(1 for c in s
                    if c in string.ascii_letters or c.isdigit() or c.isspace())
    return len(s) - math.ceil(halfwidth / 2)
501
+
502
+
503
def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    create new blank img and draw txt on it
    args:
        texts(list): the text will be draw
        scores(list|None): corresponding score of each txt; entries below
            *threshold* (or NaN) are skipped
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        font_path: the path of font which is used to draw text
    return(array):
        one page, or several pages concatenated horizontally when the
        text overflows a single page
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        # white page with a black 1-px column at the right edge as separator
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    # vertical distance between consecutive text lines
    gap = font_size + 5
    txt_img_list = []
    # count: current line on the page; index: 1-based number of the next
    # visible (non-skipped) text
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        index += 1
        if scores[idx] < threshold or math.isnan(scores[idx]):
            index -= 1
            continue
        first_line = True
        # wrap lines that are wider than the page; only the first wrapped
        # line carries the "N:" prefix
        while str_count(txt) >= img_w // font_size - 4:
            tmp = txt
            txt = tmp[:img_w // font_size - 4]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = ' ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[img_w // font_size - 4:]
            if count >= img_h // gap - 1:
                # page full: flush it and start a fresh blank page
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        # last (or only) line also shows the score
        if first_line:
            new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
        else:
            new_txt = " " + txt + " " + '%.3f' % (scores[idx])
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= img_h // gap - 1 and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)
577
+
578
+
579
def base64_to_cv2(b64str):
    """Decode a base64-encoded image string into a BGR numpy image."""
    import base64
    raw = base64.b64decode(b64str.encode('utf8'))
    buf = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)
585
+
586
+
587
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    """Draw each box whose score reaches *drop_score* in red on *image*."""
    if scores is None:
        scores = [1] * len(boxes)
    for box, score in zip(boxes, scores):
        if score < drop_score:
            continue
        pts = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [pts], True, (255, 0, 0), 2)
    return image
596
+
597
+
598
def get_rotate_crop_image(img, points):
    """Crop a quadrilateral text region via perspective transform.

    The four *points* (ordered from the top-left, 4x2 float array) are
    warped onto an axis-aligned rectangle whose sides match the quad's
    edge lengths; crops that come out much taller than wide
    (h/w >= 1.5) are rotated 90 degrees so the text lies horizontally.
    """
    # Legacy axis-aligned cropping implementation, kept for reference:
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    assert len(points) == 4, "shape of points must be 4*2"
    # output size from the longer of each pair of opposite edges
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
631
+
632
+
633
def get_minarea_rect_crop(img, points):
    """Crop the minimum-area bounding rectangle of a polygon.

    The rectangle corners from cv2.boxPoints are reordered to start at
    the top-left and proceed clockwise, then handed to
    get_rotate_crop_image for the perspective crop.
    """
    bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
    # sort corners by x: points[0..1] are the two left-most,
    # points[2..3] the two right-most
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

    index_a, index_b, index_c, index_d = 0, 1, 2, 3
    # of the two left-most corners, the upper one is top-left (a),
    # the lower one bottom-left (d)
    if points[1][1] > points[0][1]:
        index_a = 0
        index_d = 1
    else:
        index_a = 1
        index_d = 0
    # of the two right-most corners, the upper one is top-right (b),
    # the lower one bottom-right (c)
    if points[3][1] > points[2][1]:
        index_b = 2
        index_c = 3
    else:
        index_b = 3
        index_c = 2

    box = [points[index_a], points[index_b], points[index_c], points[index_d]]
    crop_img = get_rotate_crop_image(img, np.array(box))
    return crop_img
654
+
655
+
656
def check_gpu(use_gpu):
    """Return *use_gpu*, forced to False when this paddle build lacks CUDA."""
    if not use_gpu:
        return use_gpu
    if paddle.is_compiled_with_cuda():
        return use_gpu
    return False
660
+
661
+
662
# This module is import-only; nothing runs when executed directly.
if __name__ == '__main__':
    pass
tools/infer_cls.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+
21
+ import os
22
+ import sys
23
+
24
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
25
+ sys.path.append(__dir__)
26
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
27
+
28
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
29
+
30
+ import paddle
31
+
32
+ from ppocr.data import create_operators, transform
33
+ from ppocr.modeling.architectures import build_model
34
+ from ppocr.postprocess import build_post_process
35
+ from ppocr.utils.save_load import load_model
36
+ from ppocr.utils.utility import get_image_file_list
37
+ import tools.program as program
38
+
39
+
40
def main():
    """Run text-direction classification on every image in Global.infer_img."""
    global_config = config['Global']

    # Post-processing turns raw logits into (label, score) results.
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # Build the network and restore its trained weights.
    model = build_model(config['Architecture'])
    load_model(config, model)

    # Reuse the Eval pipeline: drop label ops, keep only the image tensor,
    # and switch SSLRotateResize into test mode.
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        elif op_name == "SSLRotateResize":
            op[op_name]["mode"] = "test"
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            img = f.read()
            data = {'image': img}
        batch = transform(data, ops)

        images = paddle.to_tensor(np.expand_dims(batch[0], axis=0))
        post_result = post_process_class(model(images))
        for rec_result in post_result:
            logger.info('\t result: {}'.format(rec_result))
    logger.info("success!")


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/infer_det.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+
21
+ import os
22
+ import sys
23
+
24
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
25
+ sys.path.append(__dir__)
26
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
27
+
28
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
29
+
30
+ import cv2
31
+ import json
32
+ import paddle
33
+
34
+ from ppocr.data import create_operators, transform
35
+ from ppocr.modeling.architectures import build_model
36
+ from ppocr.postprocess import build_post_process
37
+ from ppocr.utils.save_load import load_model
38
+ from ppocr.utils.utility import get_image_file_list
39
+ import tools.program as program
40
+
41
+
42
def draw_det_res(dt_boxes, config, img, img_name, save_path):
    """Draw detected text polygons on `img` and save the visualization.

    Args:
        dt_boxes: iterable of polygons, each an (N, 2) array-like of points.
        config: global config dict (unused; kept for interface compatibility).
        img: BGR image (numpy array); drawn on in place.
        img_name: source image path, used for the output file name.
        save_path: output directory, created if missing.
    """
    if len(dt_boxes) > 0:
        # cv2 is already imported at module level; the former function-local
        # `import cv2` was redundant and has been removed.
        src_im = img
        for box in dt_boxes:
            box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        # exist_ok avoids the race between the existence check and creation.
        os.makedirs(save_path, exist_ok=True)
        save_path = os.path.join(save_path, os.path.basename(img_name))
        cv2.imwrite(save_path, src_im)
        logger.info("The detected Image saved in {}".format(save_path))
54
+
55
+
56
@paddle.no_grad()
def main():
    """Run text detection on Global.infer_img, dump boxes and save drawings."""
    global_config = config['Global']

    # Build the detector and restore its trained weights.
    model = build_model(config['Architecture'])
    load_model(config, model)

    # Post-processing turns the raw probability maps into polygon boxes.
    post_process_class = build_post_process(config['PostProcess'])

    # Reuse the Eval transforms, dropping label ops; keep image + shape info
    # so boxes can be mapped back to the original image scale.
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)

            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)

            src_img = cv2.imread(file)

            dt_boxes_json = []
            # Multi-head algorithms return one result set per head, keyed by
            # head name; single-head ones return a plain list.
            if isinstance(post_result, dict):
                det_box_json = {}
                for k in post_result.keys():
                    boxes = post_result[k][0]['points']
                    dt_boxes_list = []
                    for box in boxes:
                        tmp_json = {"transcription": ""}
                        tmp_json['points'] = np.array(box).tolist()
                        dt_boxes_list.append(tmp_json)
                    det_box_json[k] = dt_boxes_list
                    save_det_path = os.path.dirname(config['Global'][
                        'save_res_path']) + "/det_results_{}/".format(k)
                    draw_det_res(boxes, config, src_img, file, save_det_path)
            else:
                boxes = post_result[0]['points']
                # write result (the duplicate `dt_boxes_json = []` that used
                # to sit here was redundant and has been removed)
                for box in boxes:
                    tmp_json = {"transcription": ""}
                    tmp_json['points'] = np.array(box).tolist()
                    dt_boxes_json.append(tmp_json)
                save_det_path = os.path.dirname(config['Global'][
                    'save_res_path']) + "/det_results/"
                draw_det_res(boxes, config, src_img, file, save_det_path)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())

    logger.info("success!")


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/infer_e2e.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+
21
+ import os
22
+ import sys
23
+
24
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
25
+ sys.path.append(__dir__)
26
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
27
+
28
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
29
+
30
+ import cv2
31
+ import json
32
+ import paddle
33
+
34
+ from ppocr.data import create_operators, transform
35
+ from ppocr.modeling.architectures import build_model
36
+ from ppocr.postprocess import build_post_process
37
+ from ppocr.utils.save_load import load_model
38
+ from ppocr.utils.utility import get_image_file_list
39
+ import tools.program as program
40
+ from PIL import Image, ImageDraw, ImageFont
41
+ import math
42
+
43
+
44
def draw_e2e_res_for_chinese(image,
                             boxes,
                             txts,
                             config,
                             img_name,
                             font_path="./doc/simfang.ttf"):
    """Visualize Chinese e2e results: blended boxes left, text layout right.

    Args:
        image: PIL RGB image.
        boxes: detected polygons, each a sequence of (x, y) points.
        txts: recognized text for each polygon.
        config: global config dict; Global.save_res_path fixes the output dir.
        img_name: source image path, used for the output file name.
        font_path: TrueType font able to render CJK glyphs.
    """
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = Image.new('RGB', (w, h), (255, 255, 255))

    import random

    # Fixed seed so each polygon gets the same color on every run.
    random.seed(0)
    draw_left = ImageDraw.Draw(img_left)
    draw_right = ImageDraw.Draw(img_right)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        box = np.array(box)
        box = [tuple(x) for x in box]
        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255))
        draw_left.polygon(box, fill=color)
        draw_right.polygon(box, outline=color)
        font = ImageFont.truetype(font_path, 15, encoding="utf-8")
        draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
    # Left half: source image blended with filled boxes; right half: text-only.
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(img_right, (w, 0, w * 2, h))

    save_e2e_path = os.path.dirname(config['Global'][
        'save_res_path']) + "/e2e_results/"
    if not os.path.exists(save_e2e_path):
        os.makedirs(save_e2e_path)
    save_path = os.path.join(save_e2e_path, os.path.basename(img_name))
    # PIL is RGB; cv2.imwrite expects BGR, hence the channel reversal.
    cv2.imwrite(save_path, np.array(img_show)[:, :, ::-1])
    logger.info("The e2e Image saved in {}".format(save_path))
80
+
81
+
82
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
    """Draw end-to-end results (boxes plus recognized text) on `img` and save.

    Args:
        dt_boxes: detected polygons, each an (N, 2) numpy array.
        strs: recognized text per polygon, same length as `dt_boxes`.
        config: global config dict; Global.save_res_path fixes the output dir.
        img: BGR image (numpy array); drawn on in place.
        img_name: source image path, used for the output file name.
    """
    if len(dt_boxes) > 0:
        src_im = img
        # Loop variable renamed from `str` to `text`: it shadowed the builtin.
        for box, text in zip(dt_boxes, strs):
            box = box.astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
            cv2.putText(
                src_im,
                text,
                org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
                fontFace=cv2.FONT_HERSHEY_COMPLEX,
                fontScale=0.7,
                color=(0, 255, 0),
                thickness=1)
        save_det_path = os.path.dirname(config['Global'][
            'save_res_path']) + "/e2e_results/"
        # exist_ok avoids the race between the existence check and creation.
        os.makedirs(save_det_path, exist_ok=True)
        save_path = os.path.join(save_det_path, os.path.basename(img_name))
        cv2.imwrite(save_path, src_im)
        logger.info("The e2e Image saved in {}".format(save_path))
103
+
104
+
105
def main():
    """Run end-to-end (detection + recognition) inference on Global.infer_img,
    dump one JSON record per image and save visualizations."""
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])

    load_model(config, model)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # create data ops: reuse the Eval transforms, dropping label ops and
    # keeping only the image tensor plus its shape record.
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)
            points, strs = post_result['points'], post_result['texts']
            # write result
            # NOTE(review): the loop variable `str` shadows the builtin.
            dt_boxes_json = []
            for poly, str in zip(points, strs):
                tmp_json = {"transcription": str}
                tmp_json['points'] = poly.tolist()
                dt_boxes_json.append(tmp_json)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())
            src_img = cv2.imread(file)
            # Pick the visualization able to render the predicted text:
            # cv2.putText for Latin, PIL with a CJK font for Chinese.
            if global_config['infer_visual_type'] == 'EN':
                draw_e2e_res(points, strs, config, src_img, file)
            elif global_config['infer_visual_type'] == 'CN':
                src_img = Image.fromarray(
                    cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
                draw_e2e_res_for_chinese(
                    src_img,
                    points,
                    strs,
                    config,
                    file,
                    font_path="./doc/fonts/simfang.ttf")

    logger.info("success!")


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/infer_kie.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+ import paddle.nn.functional as F
21
+
22
+ import os
23
+ import sys
24
+
25
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
26
+ sys.path.append(__dir__)
27
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
28
+
29
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
30
+
31
+ import cv2
32
+ import paddle
33
+
34
+ from ppocr.data import create_operators, transform
35
+ from ppocr.modeling.architectures import build_model
36
+ from ppocr.utils.save_load import load_model
37
+ import tools.program as program
38
+ import time
39
+
40
+
41
def read_class_list(filepath):
    """Read class names from `filepath`, one name per line.

    Returns a dict mapping the 0-based line index to the class name
    (trailing newline removed; other whitespace preserved).
    """
    with open(filepath, "r") as f:
        return {idx: line.strip("\n") for idx, line in enumerate(f)}
48
+
49
+
50
def draw_kie_result(batch, node, idx_to_cls, count):
    """Visualize KIE node classification: boxes on the left, labels right.

    Args:
        batch: transformed sample; batch[6] is the source image and
            batch[7] the text-line boxes as [x1, y1, x2, y2].
        node: per-box class scores, shape (num_boxes, num_classes).
        idx_to_cls: mapping from class index to class name.
        count: running image index, used for the output file name.
    """
    img = batch[6].copy()
    boxes = batch[7]
    h, w = img.shape[:2]
    pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
    # Per-box argmax gives the predicted label and its confidence.
    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
    node_pred_label = max_idx.numpy().tolist()
    node_pred_score = max_value.numpy().tolist()

    for i, box in enumerate(boxes):
        # Predictions may cover fewer boxes than were detected.
        if i >= len(node_pred_label):
            break
        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
                   [box[0], box[3]]]
        Pts = np.array([new_box], np.int32)
        cv2.polylines(
            img, [Pts.reshape((-1, 1, 2))],
            True,
            color=(255, 255, 0),
            thickness=1)
        x_min = int(min([point[0] for point in new_box]))
        y_min = int(min([point[1] for point in new_box]))

        pred_label = node_pred_label[i]
        if pred_label in idx_to_cls:
            pred_label = idx_to_cls[pred_label]
        pred_score = '{:.2f}'.format(node_pred_score[i])
        text = pred_label + '(' + pred_score + ')'
        # The label panel is twice as wide as the image, hence x_min * 2.
        cv2.putText(pred_img, text, (x_min * 2, y_min),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
    # Side-by-side canvas: source image | double-width prediction panel.
    vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
    vis_img[:, :w] = img
    vis_img[:, w:] = pred_img
    save_kie_path = os.path.dirname(config['Global'][
        'save_res_path']) + "/kie_results/"
    if not os.path.exists(save_kie_path):
        os.makedirs(save_kie_path)
    save_path = os.path.join(save_kie_path, str(count) + ".png")
    cv2.imwrite(save_path, vis_img)
    logger.info("The Kie Image saved in {}".format(save_path))
90
+
91
def write_kie_result(fout, node, data):
    """
    Write infer result to output file, sorted by the predict label of each line.
    The format keeps the same as the input with additional score attribute.

    Args:
        fout: open text file handle; one JSON line is appended per call.
        node: per-box class scores, shape (num_boxes, num_classes).
        data: input sample; data['label'] holds the JSON annotation list
            aligned with the prediction order.
    """
    import json
    label = data['label']
    annotations = json.loads(label)
    # Per-box argmax gives the predicted label and its confidence.
    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
    node_pred_label = max_idx.numpy().tolist()
    node_pred_score = max_value.numpy().tolist()
    res = []
    for i, label in enumerate(node_pred_label):
        pred_score = '{:.2f}'.format(node_pred_score[i])
        pred_res = {
            'label': label,
            'transcription': annotations[i]['transcription'],
            'score': pred_score,
            'points': annotations[i]['points'],
        }
        res.append(pred_res)
    res.sort(key=lambda x: x['label'])
    fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
114
+
115
def main():
    """Run KIE inference over the label file listed in Global.infer_img,
    saving visualizations, JSON results, and a throughput figure."""
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])
    load_model(config, model)

    # create data ops (labels are needed here, so no ops are filtered out)
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        transforms.append(op)

    data_dir = config['Eval']['dataset']['data_dir']

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    class_path = config['Global']['class_path']
    idx_to_cls = read_class_list(class_path)
    os.makedirs(os.path.dirname(save_res_path), exist_ok=True)

    model.eval()

    # Timings before index `warmup_times` are excluded from the ips figure.
    warmup_times = 0
    count_t = []
    with open(save_res_path, "w") as fout:
        with open(config['Global']['infer_img'], "rb") as f:
            lines = f.readlines()
            for index, data_line in enumerate(lines):
                # NOTE(review): warmup_t is recorded but never used.
                if index == 10:
                    warmup_t = time.time()
                data_line = data_line.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path, label = data_dir + "/" + substr[0], substr[1]
                data = {'img_path': img_path, 'label': label}
                with open(data['img_path'], 'rb') as f:
                    img = f.read()
                    data['image'] = img
                st = time.time()
                batch = transform(data, ops)
                batch_pred = [0] * len(batch)
                for i in range(len(batch)):
                    batch_pred[i] = paddle.to_tensor(
                        np.expand_dims(
                            batch[i], axis=0))
                # NOTE(review): st is reset here, so only the forward pass
                # (not the transform above) is timed.
                st = time.time()
                node, edge = model(batch_pred)
                node = F.softmax(node, -1)
                count_t.append(time.time() - st)
                draw_kie_result(batch, node, idx_to_cls, index)
                write_kie_result(fout, node, data)
    # Redundant: the with-statement above already closed fout.
    fout.close()
    logger.info("success!")
    logger.info("It took {} s for predict {} images.".format(
        np.sum(count_t), len(count_t)))
    ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
    logger.info("The ips is {} images/s".format(ips))


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/infer_kie_token_ser.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+
21
+ import os
22
+ import sys
23
+
24
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
25
+ sys.path.append(__dir__)
26
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
27
+
28
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
29
+ import cv2
30
+ import json
31
+ import paddle
32
+
33
+ from ppocr.data import create_operators, transform
34
+ from ppocr.modeling.architectures import build_model
35
+ from ppocr.postprocess import build_post_process
36
+ from ppocr.utils.save_load import load_model
37
+ from ppocr.utils.visual import draw_ser_results
38
+ from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps
39
+ import tools.program as program
40
+
41
+
42
def to_tensor(data):
    """Batch the fields of one sample, tensor-izing the numeric ones.

    Each field of `data` is wrapped in a singleton list; fields that are
    arrays, tensors, or plain numbers are converted to paddle Tensors
    (gaining a batch dimension), while other fields (e.g. lists of dicts)
    are kept as Python lists. Field order is preserved.
    """
    import numbers
    from collections import defaultdict
    data_dict = defaultdict(list)
    to_tensor_idxs = []

    for idx, v in enumerate(data):
        if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
            if idx not in to_tensor_idxs:
                to_tensor_idxs.append(idx)
        data_dict[idx].append(v)
    for idx in to_tensor_idxs:
        data_dict[idx] = paddle.to_tensor(data_dict[idx])
    return list(data_dict.values())
56
+
57
+
58
class SerPredictor(object):
    """Semantic-entity-recognition (SER) predictor built from a config dict.

    Wraps model construction, weight loading, the PaddleOCR engine used for
    text detection/recognition, and the Eval data pipeline adapted for
    inference.
    """

    def __init__(self, config):
        global_config = config['Global']
        self.algorithm = config['Architecture']["algorithm"]

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        # Imported lazily so importing this module does not require paddleocr.
        from paddleocr import PaddleOCR

        self.ocr_engine = PaddleOCR(
            use_angle_cls=False,
            show_log=False,
            rec_model_dir=global_config.get("kie_rec_model_dir", None),
            det_model_dir=global_config.get("kie_det_model_dir", None),
            use_gpu=global_config['use_gpu'])

        # create data ops: label ops receive the OCR engine so they can
        # produce pseudo labels at inference time; KeepKeys is widened to
        # everything the SER model and post-processing need.
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                op[op_name]['ocr_engine'] = self.ocr_engine
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = [
                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]

            transforms.append(op)
        # Default to inference mode unless the config set it explicitly.
        if config["Global"].get("infer_mode", None) is None:
            global_config['infer_mode'] = True
        self.ops = create_operators(config['Eval']['dataset']['transforms'],
                                    global_config)
        self.model.eval()

    def __call__(self, data):
        """Run SER on one sample; `data` must contain 'img_path'.

        Returns (post_result, batch): decoded entities plus the raw model
        inputs, which the relation-extraction stage reuses.
        """
        with open(data["img_path"], 'rb') as f:
            img = f.read()
        data["image"] = img
        batch = transform(data, self.ops)
        batch = to_tensor(batch)
        preds = self.model(batch)

        post_result = self.post_process_class(
            preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
        return post_result, batch
113
+
114
+
115
if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_engine = SerPredictor(config)

    # infer_mode == False: infer_img is a label file with image\tlabel lines
    # relative to the Eval data_dir; otherwise it is an image path or dir.
    if config["Global"].get("infer_mode", None) is False:
        data_dir = config['Eval']['dataset']['data_dir']
        with open(config['Global']['infer_img'], "rb") as f:
            infer_imgs = f.readlines()
    else:
        infer_imgs = get_image_file_list(config['Global']['infer_img'])

    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if config["Global"].get("infer_mode", None) is False:
                data_line = info.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {'img_path': img_path, 'label': substr[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")

            # Only the first batch element is used (batch size is 1 here).
            result, _ = ser_engine(data)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ocr_info": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_ser_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))
tools/infer_kie_token_ser_re.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+
21
+ import os
22
+ import sys
23
+
24
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
25
+ sys.path.append(__dir__)
26
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
27
+
28
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
29
+ import cv2
30
+ import json
31
+ import paddle
32
+ import paddle.distributed as dist
33
+
34
+ from ppocr.data import create_operators, transform
35
+ from ppocr.modeling.architectures import build_model
36
+ from ppocr.postprocess import build_post_process
37
+ from ppocr.utils.save_load import load_model
38
+ from ppocr.utils.visual import draw_re_results
39
+ from ppocr.utils.logging import get_logger
40
+ from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
41
+ from tools.program import ArgsParser, load_config, merge_config
42
+ from tools.infer_kie_token_ser import SerPredictor
43
+
44
+
45
class ReArgsParser(ArgsParser):
    """ArgsParser extended with a second config (-c_ser / -o_ser) for the
    SER stage that feeds the relation-extraction model."""

    def __init__(self):
        super(ReArgsParser, self).__init__()
        self.add_argument(
            "-c_ser", "--config_ser", help="ser configuration file to use")
        self.add_argument(
            "-o_ser",
            "--opt_ser",
            nargs='+',
            help="set ser configuration options ")

    def parse_args(self, argv=None):
        # The SER config is mandatory; the overrides are parsed like -o.
        args = super(ReArgsParser, self).parse_args(argv)
        assert args.config_ser is not None, \
            "Please specify --config_ser=ser_configure_file_path."
        args.opt_ser = self._parse_opt(args.opt_ser)
        return args
62
+
63
+
64
def make_input(ser_inputs, ser_results):
    """Convert SER outputs into the packed entity/relation inputs expected
    by the relation-extraction (RE) model.

    Entities and relations are stored in fixed-size int64 arrays where
    row 0 holds the element counts and the payload starts at row 1;
    unused rows are filled with -1.

    Returns:
        (ser_inputs, entity_idx_dict_batch): RE model inputs plus, per batch
        element, a mapping from packed entity index back to the original
        SER result index.
    """
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
    batch_size, max_seq_len = ser_inputs[0].shape[:2]
    entities = ser_inputs[8][0]
    ser_results = ser_results[0]
    assert len(entities) == len(ser_results)

    # entities: keep only non-'O' predictions, remembering their source index.
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_results, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])

    entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
    entities[0, 0] = len(start)
    entities[1:len(start) + 1, 0] = start
    entities[0, 1] = len(end)
    entities[1:len(end) + 1, 1] = end
    entities[0, 2] = len(label)
    entities[1:len(label) + 1, 2] = label

    # relations: every QUESTION (1) paired with every ANSWER (2) is a
    # candidate relation for the RE model to score.
    head = []
    tail = []
    for i in range(len(label)):
        for j in range(len(label)):
            if label[i] == 1 and label[j] == 2:
                head.append(i)
                tail.append(j)

    relations = np.full([len(head) + 1, 2], fill_value=-1, dtype=np.int64)
    relations[0, 0] = len(head)
    relations[1:len(head) + 1, 0] = head
    relations[0, 1] = len(tail)
    relations[1:len(tail) + 1, 1] = tail

    # Broadcast the single packed sample across the batch dimension.
    entities = np.expand_dims(entities, axis=0)
    entities = np.repeat(entities, batch_size, axis=0)
    relations = np.expand_dims(relations, axis=0)
    relations = np.repeat(relations, batch_size, axis=0)

    # remove ocr_info segment_offset_id and label in ser input
    if isinstance(ser_inputs[0], paddle.Tensor):
        entities = paddle.to_tensor(entities)
        relations = paddle.to_tensor(relations)
    ser_inputs = ser_inputs[:5] + [entities, relations]

    # Same sample repeated, so the same index mapping applies to each element.
    entity_idx_dict_batch = []
    for b in range(batch_size):
        entity_idx_dict_batch.append(entity_idx_dict)
    return ser_inputs, entity_idx_dict_batch
122
+
123
+
124
class SerRePredictor(object):
    """Two-stage KIE predictor: runs SER first, then relation extraction."""

    def __init__(self, config, ser_config):
        global_config = config['Global']
        # Propagate infer_mode so both stages read their input the same way.
        if "infer_mode" in global_config:
            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]

        self.ser_engine = SerPredictor(ser_config)

        # init re model

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        self.model.eval()

    def __call__(self, data):
        """Run SER then RE on one sample; `data` must contain 'img_path'."""
        ser_results, ser_inputs = self.ser_engine(data)
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        # Backbones without a visual branch do not take the image tensor.
        if self.model.backbone.use_visual_backbone is False:
            re_input.pop(4)
        preds = self.model(re_input)
        post_result = self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch)
        return post_result
157
+
158
+
159
def preprocess():
    """Parse the RE and SER configs, select the device, and set up logging.

    Returns:
        (config, ser_config, device, logger)
    """
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)

    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger()

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    logger.info('{} re config {}'.format('*' * 10, '*' * 10))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
    print_dict(ser_config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, ser_config, device, logger
183
+
184
+
185
if __name__ == '__main__':
    config, ser_config, device, logger = preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    # infer_mode == False: infer_img is a label file with image\tlabel lines
    # relative to the Eval data_dir; otherwise it is an image path or dir.
    if config["Global"].get("infer_mode", None) is False:
        data_dir = config['Eval']['dataset']['data_dir']
        with open(config['Global']['infer_img'], "rb") as f:
            infer_imgs = f.readlines()
    else:
        infer_imgs = get_image_file_list(config['Global']['infer_img'])

    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if config["Global"].get("infer_mode", None) is False:
                data_line = info.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {'img_path': img_path, 'label': substr[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")

            # Only the first batch element is used (batch size is 1 here).
            result = ser_re_engine(data)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                result, ensure_ascii=False) + "\n")
            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))
tools/infer_rec.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+
21
+ import os
22
+ import sys
23
+ import json
24
+
25
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
26
+ sys.path.append(__dir__)
27
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
28
+
29
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
30
+
31
+ import paddle
32
+
33
+ from ppocr.data import create_operators, transform
34
+ from ppocr.modeling.architectures import build_model
35
+ from ppocr.postprocess import build_post_process
36
+ from ppocr.utils.save_load import load_model
37
+ from ppocr.utils.utility import get_image_file_list
38
+ import tools.program as program
39
+
40
+
41
def main():
    """Run text-recognition inference on the images under Global.infer_img.

    Reads ``config`` and ``logger`` from module globals (set by
    ``program.preprocess()`` in the ``__main__`` guard). Builds the
    post-process, patches head output channels from the charset size,
    builds/loads the model, then writes one "<img path>\\t<result>" line per
    image to Global.save_res_path.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model: derive the head's output channels from the charset size
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    out_channels_list = {}
                    # SAR decode reserves 2 extra symbols beyond the charset
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head loss
            out_channels_list = {}
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops: drop label ops, switch resize to inference mode and
    # keep only the tensors the chosen algorithm consumes
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name in ['RecResizeImg']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            if config['Architecture']['algorithm'] == "SRN":
                op[op_name]['keep_keys'] = [
                    'image', 'encoder_word_pos', 'gsrm_word_pos',
                    'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
                ]
            elif config['Architecture']['algorithm'] == "SAR":
                op[op_name]['keep_keys'] = ['image', 'valid_ratio']
            elif config['Architecture']['algorithm'] == "RobustScanner":
                op[op_name][
                    'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
            else:
                op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global'].get('save_res_path',
                                         "./output/rec/predicts_rec.txt")
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()

    with open(save_res_path, "w") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            # assemble the algorithm-specific extra model inputs; the batch
            # layout matches the keep_keys chosen above
            if config['Architecture']['algorithm'] == "SRN":
                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)

                others = [
                    paddle.to_tensor(encoder_word_pos_list),
                    paddle.to_tensor(gsrm_word_pos_list),
                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
                    paddle.to_tensor(gsrm_slf_attn_bias2_list)
                ]
            if config['Architecture']['algorithm'] == "SAR":
                valid_ratio = np.expand_dims(batch[-1], axis=0)
                img_metas = [paddle.to_tensor(valid_ratio)]
            if config['Architecture']['algorithm'] == "RobustScanner":
                valid_ratio = np.expand_dims(batch[1], axis=0)
                word_positons = np.expand_dims(batch[2], axis=0)
                img_metas = [
                    paddle.to_tensor(valid_ratio),
                    paddle.to_tensor(word_positons),
                ]
            if config['Architecture']['algorithm'] == "CAN":
                # all-ones mask/label placeholders for inference
                image_mask = paddle.ones(
                    (np.expand_dims(
                        batch[0], axis=0).shape), dtype='float32')
                label = paddle.ones((1, 36), dtype='int64')
            images = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(images)
            if config['Architecture']['algorithm'] == "SRN":
                preds = model(images, others)
            elif config['Architecture']['algorithm'] == "SAR":
                preds = model(images, img_metas)
            elif config['Architecture']['algorithm'] == "RobustScanner":
                preds = model(images, img_metas)
            elif config['Architecture']['algorithm'] == "CAN":
                preds = model([images, image_mask, label])
            else:
                preds = model(images)
            post_result = post_process_class(preds)
            # format the decoded result; shape of post_result depends on the
            # post-process class
            info = None
            if isinstance(post_result, dict):
                rec_info = dict()
                for key in post_result:
                    if len(post_result[key][0]) >= 2:
                        rec_info[key] = {
                            "label": post_result[key][0][0],
                            "score": float(post_result[key][0][1]),
                        }
                info = json.dumps(rec_info, ensure_ascii=False)
            elif isinstance(post_result, list) and isinstance(post_result[0],
                                                              int):
                # for RFLearning CNT branch
                info = str(post_result[0])
            else:
                if len(post_result[0]) >= 2:
                    info = post_result[0][0] + "\t" + str(post_result[0][1])

            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(file + "\t" + info + "\n")
    logger.info("success!")
184
+
185
+
186
if __name__ == '__main__':
    # preprocess() parses the yaml config, selects the device and sets up
    # logging; main() reads `config` and `logger` as module globals
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/infer_sr.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+
21
+ import os
22
+ import sys
23
+ import json
24
+ from PIL import Image
25
+ import cv2
26
+
27
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
28
+ sys.path.insert(0, __dir__)
29
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
30
+
31
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
32
+
33
+ import paddle
34
+
35
+ from ppocr.data import create_operators, transform
36
+ from ppocr.modeling.architectures import build_model
37
+ from ppocr.postprocess import build_post_process
38
+ from ppocr.utils.save_load import load_model
39
+ from ppocr.utils.utility import get_image_file_list
40
+ import tools.program as program
41
+
42
+
43
def main():
    """Run super-resolution inference on the images under Global.infer_img.

    Reads ``config`` and ``logger`` from module globals (set by
    ``program.preprocess()`` in the ``__main__`` guard). The upscaled
    images are written as ``sr_<name>`` into Global.save_visual
    (default ``"infer_result/"``).
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # sr transform
    config['Architecture']["Transform"]['infer_mode'] = True

    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops: drop label ops, switch the resize op to inference
    # mode, and keep only the low-resolution image tensor
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name in ['SRResize']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['img_lr']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_visual_path = config['Global'].get('save_visual', "infer_result/")
    # create the output directory itself; the previous
    # os.makedirs(os.path.dirname(...)) dropped the last path component
    # whenever the configured path had no trailing slash, so images were
    # written into a non-existent directory
    os.makedirs(save_visual_path, exist_ok=True)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        img = Image.open(file).convert("RGB")
        data = {'image_lr': img}
        batch = transform(data, ops)
        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)

        preds = model(images)
        sr_img = preds["sr_img"][0]
        # CHW float output scaled to HWC uint8; [:, :, ::-1] converts
        # RGB -> BGR for cv2.imwrite
        fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
        img_name_pure = os.path.split(file)[-1]
        cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure),
                    fm_sr[:, :, ::-1])
        # log the actual configured output location instead of the
        # previously hard-coded "infer_result/" path
        logger.info("The visualized image saved in {}/sr_{}".format(
            save_visual_path, img_name_pure))

    logger.info("success!")
96
+
97
+
98
if __name__ == '__main__':
    # preprocess() parses the yaml config, selects the device and sets up
    # logging; main() reads `config` and `logger` as module globals
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/infer_table.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+
21
+ import os
22
+ import sys
23
+ import json
24
+
25
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
26
+ sys.path.append(__dir__)
27
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
28
+
29
+ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
30
+
31
+ import paddle
32
+ from paddle.jit import to_static
33
+
34
+ from ppocr.data import create_operators, transform
35
+ from ppocr.modeling.architectures import build_model
36
+ from ppocr.postprocess import build_post_process
37
+ from ppocr.utils.save_load import load_model
38
+ from ppocr.utils.utility import get_image_file_list
39
+ from ppocr.utils.visual import draw_rectangle
40
+ from tools.infer.utility import draw_boxes
41
+ import tools.program as program
42
+ import cv2
43
+
44
+
45
@paddle.no_grad()
def main(config, device, logger, vdl_writer):
    """Run table-structure inference on the images under Global.infer_img.

    For each image, predicts the table structure tokens and cell boxes,
    appends one "result: ..." line to <save_res_path>/infer.txt and writes
    a visualization image next to it.

    Args:
        config (dict): parsed yaml configuration.
        device: paddle device handle (unused here; kept for a uniform
            entry-point signature).
        logger: logger produced by program.preprocess().
        vdl_writer: VisualDL writer (unused here; kept for a uniform
            entry-point signature).
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model: derive output channels from the structure charset size
    if hasattr(post_process_class, 'character'):
        config['Architecture']["Head"]['out_channels'] = len(
            getattr(post_process_class, 'character'))

    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops: drop encode ops and keep image + original shape
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Encode' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    os.makedirs(save_res_path, exist_ok=True)

    model.eval()
    with open(
            os.path.join(save_res_path, 'infer.txt'), mode='w',
            encoding='utf-8') as f_w:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)

            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, [shape_list])

            # single-image batch: take the first entry of each batch list
            structure_str_list = post_result['structure_batch_list'][0]
            bbox_list = post_result['bbox_batch_list'][0]
            structure_str_list = structure_str_list[0]
            structure_str_list = [
                '<html>', '<body>', '<table>'
            ] + structure_str_list + ['</table>', '</body>', '</html>']
            bbox_list_str = json.dumps(bbox_list.tolist())

            logger.info("result: {}, {}".format(structure_str_list,
                                                bbox_list_str))
            f_w.write("result: {}, {}\n".format(structure_str_list,
                                                bbox_list_str))

            # 4-value boxes are axis-aligned rectangles; anything else is
            # drawn as a polygon
            if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
                img = draw_rectangle(file, bbox_list)
            else:
                img = draw_boxes(cv2.imread(file), bbox_list)
            cv2.imwrite(
                os.path.join(save_res_path, os.path.basename(file)), img)
            logger.info('save result to {}'.format(save_res_path))
    logger.info("success!")
117
+
118
+
119
if __name__ == '__main__':
    # preprocess() parses the yaml config, selects the device and sets up
    # logging; this script passes everything to main() explicitly
    config, device, logger, vdl_writer = program.preprocess()
    main(config, device, logger, vdl_writer)
tools/program.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import os
20
+ import sys
21
+ import platform
22
+ import yaml
23
+ import time
24
+ import datetime
25
+ import paddle
26
+ import paddle.distributed as dist
27
+ from tqdm import tqdm
28
+ import cv2
29
+ import numpy as np
30
+ from argparse import ArgumentParser, RawDescriptionHelpFormatter
31
+
32
+ from ppocr.utils.stats import TrainingStats
33
+ from ppocr.utils.save_load import save_model
34
+ from ppocr.utils.utility import print_dict, AverageMeter
35
+ from ppocr.utils.logging import get_logger
36
+ from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
37
+ from ppocr.utils import profiler
38
+ from ppocr.data import build_dataloader
39
+
40
+
41
class ArgsParser(ArgumentParser):
    """Command-line parser for the training/inference scripts.

    Accepts a yaml config file (-c/--config), ad-hoc config overrides
    (-o/--opt, as ``key=value`` tokens) and profiler options
    (-p/--profiler_options).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format ' \
                '\"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """Parse argv, require --config and convert -o entries to a dict."""
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``['k=v', ...]`` into ``{k: parsed_v}`` with yaml-typed values.

        Returns an empty dict when no overrides were given.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # split only on the first '=' so that values may themselves
            # contain '=' (previously such overrides raised ValueError)
            k, v = s.split('=', 1)
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
73
+
74
+
75
def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: if the file extension is not .yml/.yaml.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # context manager guarantees the file handle is closed (the previous
    # bare open() leaked the handle)
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
86
+
87
+
88
def merge_config(config, opts):
    """
    Merge config into global config.
    Args:
        config (dict): Config to be merged.
        opts (dict): Overrides; a dotted key like ``A.B.C`` addresses the
            nested entry config['A']['B']['C'].
    Returns: global config (mutated in place)
    """
    for key, value in opts.items():
        if "." not in key:
            # plain key: merge dict values into an existing dict entry,
            # otherwise overwrite/insert
            if isinstance(value, dict) and key in config:
                config[key].update(value)
            else:
                config[key] = value
            continue
        # dotted key: walk down to the parent of the final component
        parts = key.split('.')
        assert (
            parts[0] in config
        ), "the sub_keys can only be one of global_config: {}, but get: " \
            "{}, please check your running command".format(
                config.keys(), parts[0])
        node = config[parts[0]]
        for part in parts[1:-1]:
            node = node[part]
        node[parts[-1]] = value
    return config
115
+
116
+
117
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.

    Args:
        use_gpu (bool): whether the config requests CUDA.
        use_xpu (bool): whether the config requests an XPU device.
        use_npu (bool): whether the config requests an NPU device.
        use_mlu (bool): whether the config requests an MLU device.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and use_xpu:
            # fixed typo in the user-facing message ("ture" -> "true")
            print("use_xpu and use_gpu can not both be true.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
        if use_npu:
            # pick the NPU probe matching the installed paddle version
            if int(paddle.version.major) != 0 and int(
                    paddle.version.major) <= 2 and int(
                        paddle.version.minor) <= 4:
                if not paddle.device.is_compiled_with_npu():
                    print(err.format("use_npu", "npu", "npu", "use_npu"))
                    sys.exit(1)
            # is_compiled_with_npu() has been updated after paddle-2.4
            else:
                if not paddle.device.is_compiled_with_custom_device("npu"):
                    print(err.format("use_npu", "npu", "npu", "use_npu"))
                    sys.exit(1)
        if use_mlu and not paddle.device.is_compiled_with_mlu():
            print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
            sys.exit(1)
    except Exception as e:
        # best-effort check: tolerate probe APIs missing on some paddle
        # builds (SystemExit still propagates — it is not an Exception)
        pass
154
+
155
+
156
def to_float32(preds):
    """Recursively cast every paddle.Tensor inside *preds* to float32.

    dict/list containers are modified in place (nested containers are
    visited recursively); a bare tensor argument yields a new float32
    tensor. Anything else is returned unchanged.
    """
    if isinstance(preds, paddle.Tensor):
        return preds.astype(paddle.float32)
    if isinstance(preds, dict):
        indices = preds
    elif isinstance(preds, list):
        indices = range(len(preds))
    else:
        return preds
    for i in indices:
        item = preds[i]
        if isinstance(item, (dict, list)):
            preds[i] = to_float32(item)
        elif isinstance(item, paddle.Tensor):
            preds[i] = item.astype(paddle.float32)
    return preds
174
+
175
+
176
def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None,
          amp_level='O2',
          amp_custom_black_list=[]):
    """Main training loop.

    Iterates ``train_dataloader`` for ``Global.epoch_num`` epochs, running
    the forward/backward pass (with optional AMP when ``scaler`` is given),
    optionally computing metrics during training, periodically evaluating
    on ``valid_dataloader`` and saving 'latest' / 'best_accuracy' /
    per-epoch checkpoints on rank 0.

    Args:
        config (dict): full parsed yaml configuration.
        train_dataloader / valid_dataloader: paddle dataloaders.
        device: paddle device handle (used when rebuilding the dataloader).
        model: model to train (modified in place).
        loss_class: callable returning a dict with at least key 'loss'.
        optimizer / lr_scheduler: paddle optimizer and LR schedule
            (lr_scheduler may be a plain float, in which case it is not
            stepped).
        post_process_class / eval_class: decoding and metric objects.
        pre_best_model_dict (dict): metrics/state restored from a
            checkpoint; may carry 'global_step' and 'start_epoch'.
        logger: logger instance.
        log_writer: optional metrics logger (VisualDL/W&B wrapper).
        scaler: optional paddle.amp.GradScaler; enables AMP when set.
        amp_level (str): AMP level passed to paddle.amp.auto_cast.
        amp_custom_black_list (list): ops excluded from AMP.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    # resume the step counter when the checkpoint provided it
    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    start_eval_step = 0
    if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training ' \
                'will be disabled'
            )
            # effectively "never": no step count will ever reach this
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, " \
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    # algorithms that need extra inputs beyond the image tensor
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
        "RobustScanner", "RFL", 'DRRG'
    ]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
        for key in config['Architecture']["Models"]:
            extra_input = extra_input or config['Architecture']['Models'][key][
                'algorithm'] in extra_input_models
    else:
        extra_input = config['Architecture']['algorithm'] in extra_input_models
    try:
        model_type = config['Architecture']['model_type']
    except:
        model_type = None

    algorithm = config['Architecture']['algorithm']

    start_epoch = best_model_dict[
        'start_epoch'] if 'start_epoch' in best_model_dict else 1

    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    reader_start = time.time()
    eta_meter = AverageMeter()

    # on Windows the last (possibly short) batch is skipped
    max_iter = len(train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            # rebuild the dataloader (e.g. datasets that reshuffle their
            # sources per epoch), reseeded with the epoch number
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)

        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True
            # use amp
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list):
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    elif model_type in ["kie"]:
                        preds = model(batch)
                    elif algorithm in ['CAN']:
                        preds = model(batch[:3])
                    else:
                        preds = model(images)
                preds = to_float32(preds)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'sr']:
                    preds = model(batch)
                elif algorithm in ['CAN']:
                    preds = model(batch[:3])
                else:
                    preds = model(images)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                avg_loss.backward()
                optimizer.step()

            optimizer.clear_grad()

            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['kie', 'sr']:
                    eval_class(preds, batch)
                elif model_type in ['table']:
                    post_result = post_process_class(preds, batch)
                    eval_class(post_result, batch)
                elif algorithm in ['CAN']:
                    model_type = 'can'
                    eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
                else:
                    if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
                                                  ]:  # for multi head loss
                        post_result = post_process_class(
                            preds['ctc'], batch[1])  # for CTC head out
                    elif config['Loss']['name'] in ['VLLoss']:
                        post_result = post_process_class(preds, batch[1],
                                                         batch[-1])
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()

                eta_sec = ((epoch_num + 1 - epoch) * \
                    len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                       'ips: {:.5f} samples/s, eta: {}'.format(
                           epoch, epoch_num, global_step, logs,
                           train_reader_cost / print_batch_step,
                           train_batch_cost / print_batch_step,
                           total_samples / print_batch_step,
                           total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input,
                    scaler=scaler,
                    amp_level=amp_level,
                    amp_custom_black_list=amp_custom_black_list)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step)

                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics={
                            "best_{}".format(main_indicator):
                            best_model_dict[main_indicator]
                        },
                        prefix="EVAL",
                        step=global_step)

                    log_writer.log_model(
                        is_best=True,
                        prefix="best_accuracy",
                        metadata=best_model_dict)

            reader_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
                log_writer.log_model(
                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return
477
+
478
+
479
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False,
         scaler=None,
         amp_level='O2',
         amp_custom_black_list=None):
    """Evaluate `model` over `valid_dataloader` and return the metric dict.

    The model is switched to eval mode for the duration of the loop and
    restored to train mode before returning.

    Args:
        model: paddle model to evaluate.
        valid_dataloader: iterable of evaluation batches; batch[0] holds
            the input images.
        post_process_class: callable turning raw predictions into decoded
            results, or None when `eval_class` consumes raw predictions.
        eval_class: metric accumulator; called per batch, finalized via
            get_metric().
        model_type: architecture family ('table', 'kie', 'can', 'sr', ...);
            selects how the forward pass is invoked.
        extra_input: when True, the extra batch items are fed to the model.
        scaler: paddle AMP GradScaler; when set, inference runs under
            auto_cast and predictions are cast back with to_float32.
        amp_level: AMP optimization level used for auto_cast.
        amp_custom_black_list: ops forced to fp32 under AMP. Defaults to []
            (kept as None in the signature to avoid a mutable default).

    Returns:
        dict of metric values plus 'fps' (frames / inference seconds).
    """
    if amp_custom_black_list is None:
        amp_custom_black_list = []

    def _forward(images, batch):
        # Single dispatch point for the model call; this logic was
        # previously duplicated in the AMP and non-AMP branches.
        if model_type == 'table' or extra_input:
            return model(images, data=batch[1:])
        elif model_type in ["kie"]:
            return model(batch)
        elif model_type in ['can']:
            return model(batch[:3])
        elif model_type in ['sr']:
            return model(batch)
        else:
            return model(images)

    model.eval()
    with paddle.no_grad():
        total_frame = 0.0
        total_time = 0.0
        pbar = tqdm(
            total=len(valid_dataloader),
            desc='eval model:',
            position=0,
            leave=True)
        # NOTE(review): the last batch is skipped on Windows, presumably to
        # work around a dataloader issue on that platform — confirm.
        max_iter = len(valid_dataloader) - 1 if platform.system(
        ) == "Windows" else len(valid_dataloader)
        sum_images = 0
        for idx, batch in enumerate(valid_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            start = time.time()

            # use amp
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list):
                    preds = _forward(images, batch)
                preds = to_float32(preds)
            else:
                preds = _forward(images, batch)

            # Convert tensors to numpy so post-processing/metrics can use them.
            batch_numpy = []
            for item in batch:
                if isinstance(item, paddle.Tensor):
                    batch_numpy.append(item.numpy())
                else:
                    batch_numpy.append(item)
            # Obtain usable results from post-processing methods
            total_time += time.time() - start
            # Evaluate the results of the current batch
            if model_type in ['table', 'kie']:
                if post_process_class is None:
                    eval_class(preds, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy)
                    eval_class(post_result, batch_numpy)
            elif model_type in ['sr']:
                eval_class(preds, batch_numpy)
            elif model_type in ['can']:
                eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
            else:
                post_result = post_process_class(preds, batch_numpy[1])
                eval_class(post_result, batch_numpy)

            pbar.update(1)
            total_frame += len(images)
            sum_images += 1
        # Get final metric,eg. acc or hmean
        metric = eval_class.get_metric()

    pbar.close()
    model.train()
    metric['fps'] = total_frame / total_time
    return metric
571
+
572
+
573
def update_center(char_center, post_result, preds):
    """Fold one batch of features into the per-character running means.

    `char_center` maps a character class index to [mean_feature, count].
    Only samples whose decoded first element matches the label contribute.

    Returns the (mutated) `char_center` dict.
    """
    results, labels = post_result
    feats, logits = preds
    class_ids = paddle.argmax(logits, axis=-1).numpy()
    feats = feats.numpy()

    for sample_idx in range(len(labels)):
        # Skip samples the model recognized incorrectly.
        if results[sample_idx][0] != labels[sample_idx][0]:
            continue
        sample_feat = feats[sample_idx]
        for step, char_idx in enumerate(class_ids[sample_idx]):
            if char_idx in char_center:
                mean, count = char_center[char_idx]
                # Incremental running mean of the feature vector.
                char_center[char_idx][0] = (mean * count + sample_feat[step]
                                            ) / (count + 1)
                char_center[char_idx][1] = count + 1
            else:
                char_center[char_idx] = [sample_feat[step], 1]
    return char_center
594
+
595
+
596
def get_center(model, eval_dataloader, post_process_class):
    """Compute the mean feature vector ("center") per character class.

    Runs `model` over `eval_dataloader`, decodes predictions with
    `post_process_class`, and averages the features of correctly
    recognized characters via `update_center`.

    Returns:
        dict mapping character class index -> mean feature vector.
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    # NOTE(review): the last batch is skipped on Windows, presumably to
    # work around a dataloader issue on that platform — confirm.
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    for idx, batch in enumerate(eval_dataloader):
        if idx >= max_iter:
            break
        images = batch[0]
        preds = model(images)

        batch = [item.numpy() for item in batch]
        # Obtain usable results from post-processing methods
        post_result = post_process_class(preds, batch[1])

        # update char_center with this batch's correct predictions
        char_center = update_center(char_center, post_result, preds)
        pbar.update(1)

    pbar.close()
    # Drop the counts, keeping only the mean feature per character.
    for key in char_center.keys():
        char_center[key] = char_center[key][0]
    return char_center
620
+
621
+
622
def preprocess(is_train=False):
    """Parse CLI args, load/merge the config, set up device and loggers.

    Args:
        is_train: when True, creates `Global.save_model_dir`, dumps the
            merged config there as config.yml and logs to train.log
            inside it; otherwise logging goes to the console only.

    Returns:
        tuple (config, device, logger, log_writer) where `log_writer`
        is a Loggers facade wrapping the enabled backends (VisualDL
        and/or W&B), or None when neither is enabled.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    # Propagate profiler options from the command line into the config.
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config: persist the fully merged config next to checkpoints.
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global'].get('use_gpu', False)
    use_xpu = config['Global'].get('use_xpu', False)
    use_npu = config['Global'].get('use_npu', False)
    use_mlu = config['Global'].get('use_mlu', False)

    # Fail fast on an algorithm name this training entry does not support.
    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN',
        'Telescope'
    ]

    # Pick the device string; env vars select the card id on XPU/NPU/MLU.
    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    elif use_npu:
        device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
    elif use_mlu:
        device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu, use_npu, use_mlu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    loggers = []

    if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        loggers.append(VDLLogger(vdl_writer_path))
    if ('use_wandb' in config['Global'] and
            config['Global']['use_wandb']) or 'wandb' in config:
        save_dir = config['Global']['save_model_dir']
        if "wandb" in config:
            wandb_params = config['wandb']
        else:
            wandb_params = dict()
        wandb_params.update({'save_dir': save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))
    print_dict(config, logger)

    # Wrap all enabled backends behind a single Loggers facade.
    if loggers:
        log_writer = Loggers(loggers)
    else:
        log_writer = None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer
tools/test_hubserving.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
17
+ sys.path.append(__dir__)
18
+ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
19
+
20
+ from ppocr.utils.logging import get_logger
21
+ logger = get_logger()
22
+
23
+ import cv2
24
+ import numpy as np
25
+ import time
26
+ from PIL import Image
27
+ from ppocr.utils.utility import get_image_file_list
28
+ from tools.infer.utility import draw_ocr, draw_boxes, str2bool
29
+ from ppstructure.utility import draw_structure_result
30
+ from ppstructure.predict_system import to_excel
31
+
32
+ import requests
33
+ import json
34
+ import base64
35
+
36
+
37
def cv2_to_base64(image):
    """Encode raw image bytes as a UTF-8 base64 string for the JSON payload."""
    encoded = base64.b64encode(image)
    return encoded.decode('utf8')
39
+
40
+
41
def draw_server_result(image_file, res):
    """Visualize hub-serving results on top of the source image.

    Returns an RGB numpy image (possibly unchanged when `res` is empty),
    or None when the result carries no box geometry (pure ocr_rec output).
    """
    bgr = cv2.imread(image_file)
    image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    if not res:
        return np.array(image)
    keys = res[0].keys()
    if 'text_region' not in keys:  # for ocr_rec, draw function is invalid
        logger.info("draw function is invalid for ocr_rec!")
        return None
    if 'text' not in keys:  # for ocr_det
        logger.info("draw text boxes only!")
        boxes = np.array([item['text_region'] for item in res])
        return draw_boxes(image, boxes)
    # for ocr_system: boxes plus recognized text and confidence
    logger.info("draw boxes and texts!")
    boxes = np.array([item['text_region'] for item in res])
    texts = [item['text'] for item in res]
    scores = np.array([item['confidence'] for item in res])
    return draw_ocr(
        image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
72
+
73
+
74
def save_structure_res(res, save_folder, image_file):
    """Persist structure-analysis results for one image.

    Under save_folder/<image basename>/: tables are exported to .xlsx,
    figure regions are cropped out of the original image and saved as
    .jpg, and all other regions' OCR results are appended to res.txt
    as one JSON object per line.
    """
    img = cv2.imread(image_file)
    excel_save_folder = os.path.join(save_folder, os.path.basename(image_file))
    os.makedirs(excel_save_folder, exist_ok=True)
    # save res
    with open(
            os.path.join(excel_save_folder, 'res.txt'), 'w',
            encoding='utf8') as f:
        for region in res:
            if region['type'] == 'Table':
                excel_path = os.path.join(excel_save_folder,
                                          '{}.xlsx'.format(region['bbox']))
                to_excel(region['res'], excel_path)
            elif region['type'] == 'Figure':
                # Crop the figure region out of the original image.
                # (A leftover debug print of the bbox was removed here.)
                x1, y1, x2, y2 = region['bbox']
                roi_img = img[y1:y2, x1:x2, :]
                img_path = os.path.join(excel_save_folder,
                                        '{}.jpg'.format(region['bbox']))
                cv2.imwrite(img_path, roi_img)
            else:
                for text_result in region['res']:
                    f.write('{}\n'.format(json.dumps(text_result)))
97
+
98
+
99
def main(args):
    """Send every image under args.image_dir to the hub-serving endpoint.

    Logs per-image latency, optionally visualizes/saves the returned
    results depending on the endpoint type encoded in args.server_url,
    and reports the average request time at the end.
    """
    image_file_list = get_image_file_list(args.image_dir)
    headers = {"Content-type": "application/json"}
    cnt = 0
    total_time = 0
    for image_file in image_file_list:
        # Read the raw bytes with a context manager so the handle is closed.
        with open(image_file, 'rb') as fp:
            img = fp.read()
        if not img:  # empty/unreadable file
            logger.info("error in loading image:{}".format(image_file))
            continue
        img_name = os.path.basename(image_file)
        # send http request
        starttime = time.time()
        data = {'images': [cv2_to_base64(img)]}
        r = requests.post(
            url=args.server_url, headers=headers, data=json.dumps(data))
        elapse = time.time() - starttime
        total_time += elapse
        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
        res = r.json()["results"][0]
        logger.info(res)

        if args.visualize:
            draw_img = None
            if 'structure_table' in args.server_url:
                to_excel(res['html'], './{}.xlsx'.format(img_name))
            elif 'structure_system' in args.server_url:
                save_structure_res(res['regions'], args.output, image_file)
            else:
                draw_img = draw_server_result(image_file, res)
            if draw_img is not None:
                if not os.path.exists(args.output):
                    os.makedirs(args.output)
                cv2.imwrite(
                    os.path.join(args.output, os.path.basename(image_file)),
                    draw_img[:, :, ::-1])
                logger.info("The visualized image saved in {}".format(
                    os.path.join(args.output, os.path.basename(image_file))))
        cnt += 1
        if cnt % 100 == 0:
            logger.info("{} processed".format(cnt))
    # Guard against ZeroDivisionError when no image was processed.
    if cnt > 0:
        logger.info("avg time cost: {}".format(float(total_time) / cnt))
142
+
143
+
144
def parse_args():
    """Build and parse the command-line arguments for hub-serving testing."""
    import argparse
    parser = argparse.ArgumentParser(description="args for hub serving")
    option_specs = [
        ("--server_url", dict(type=str, required=True)),
        ("--image_dir", dict(type=str, required=True)),
        ("--visualize", dict(type=str2bool, default=False)),
        ("--output", dict(type=str, default='./hubserving_result')),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
153
+
154
+
155
# Script entry point: parse CLI args and run the hub-serving client.
if __name__ == '__main__':
    args = parse_args()
    main(args)
tools/train.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import os
20
+ import sys
21
+
22
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
23
+ sys.path.append(__dir__)
24
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
25
+
26
+ import yaml
27
+ import paddle
28
+ import paddle.distributed as dist
29
+
30
+ from ppocr.data import build_dataloader
31
+ from ppocr.modeling.architectures import build_model
32
+ from ppocr.losses import build_loss
33
+ from ppocr.optimizer import build_optimizer
34
+ from ppocr.postprocess import build_post_process
35
+ from ppocr.metrics import build_metric
36
+ from ppocr.utils.save_load import load_model
37
+ from ppocr.utils.utility import set_seed
38
+ from ppocr.modeling.architectures import apply_to_static
39
+ import tools.program as program
40
+
41
+ dist.get_world_size()
42
+
43
+
44
def main(config, device, logger, vdl_writer):
    """Build all training components from `config` and launch training.

    Sets up (in order): distributed env, train/eval dataloaders,
    post-processing, the model (with head channels derived from the
    charset for rec algorithms), loss, optimizer, metric, optional AMP,
    pretrained weights, and finally calls program.train(...).
    """
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n" +
            "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            +
            "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm: derive the head's output channels from the charset
    # size exposed by the post-processor. The SAR decode variants reserve
    # two extra symbols, hence the +/-2 adjustments below.
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    # update SARLoss params
                    assert list(config['Loss']['loss_config_list'][-1].keys())[
                        0] == 'DistillationSARLoss'
                    config['Loss']['loss_config_list'][-1][
                        'DistillationSARLoss']['ignore_index'] = char_num + 1
                    out_channels_list = {}
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            # update SARLoss params
            assert list(config['Loss']['loss_config_list'][1].keys())[
                0] == 'SARLoss'
            if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
                config['Loss']['loss_config_list'][1]['SARLoss'] = {
                    'ignore_index': char_num + 1
                }
            else:
                config['Loss']['loss_config_list'][1]['SARLoss'][
                    'ignore_index'] = char_num + 1
            out_channels_list = {}
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

        if config['PostProcess']['name'] == 'SARLabelDecode':  # for SAR model
            config['Loss']['ignore_index'] = char_num - 1

    model = build_model(config['Architecture'])

    use_sync_bn = config["Global"].get("use_sync_bn", False)
    if use_sync_bn:
        model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info('convert_sync_batchnorm')

    model = apply_to_static(model, config, logger)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        model=model)

    # build metric
    eval_class = build_metric(config['Metric'])

    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    # Optional automatic mixed precision setup; at O2 the model and
    # optimizer are decorated so parameters are kept as master fp32 weights.
    use_amp = config["Global"].get("use_amp", False)
    amp_level = config["Global"].get("amp_level", 'O2')
    amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
        if paddle.is_compiled_with_cuda():
            AMP_RELATED_FLAGS_SETTING.update({
                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
            })
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
        if amp_level == "O2":
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level=amp_level,
                master_weight=True)
    else:
        scaler = None

    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer,
                                     config['Architecture']["model_type"])

    if config['Global']['distributed']:
        model = paddle.DataParallel(model)
    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
                  amp_level, amp_custom_black_list)
184
+
185
+
186
def test_reader(config, device, logger):
    """Smoke-test the training dataloader: iterate it and log per-batch timing."""
    import time
    loader = build_dataloader(config, 'Train', device, logger)
    tick = time.time()
    count = 0
    try:
        for data in loader():
            count += 1
            # Log every batch (count % 1 is always 0; kept for easy tuning).
            if count % 1 == 0:
                elapsed = time.time() - tick
                tick = time.time()
                logger.info("reader: {}, {}, {}".format(count,
                                                        len(data[0]), elapsed))
    except Exception as e:
        logger.info(e)
    logger.info("finish reader: {}, Success!".format(count))
202
+
203
+
204
# Script entry point: parse config/CLI, fix the random seed, then train.
if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    # Fall back to a fixed seed for reproducibility when none is configured.
    seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
    set_seed(seed)
    main(config, device, logger, vdl_writer)
    # test_reader(config, device, logger)