Spaces:
Build error
Build error
DeepLearning101
commited on
Commit
•
fa6fa48
1
Parent(s):
c93900f
Upload 39 files
Browse files- tools/__init__.py +14 -0
- tools/__pycache__/__init__.cpython-37.pyc +0 -0
- tools/__pycache__/__init__.cpython-38.pyc +0 -0
- tools/end2end/convert_ppocr_label.py +100 -0
- tools/end2end/draw_html.py +73 -0
- tools/end2end/eval_end2end.py +193 -0
- tools/end2end/readme.md +63 -0
- tools/eval.py +137 -0
- tools/export_center.py +77 -0
- tools/export_model.py +269 -0
- tools/infer/__pycache__/predict_cls.cpython-37.pyc +0 -0
- tools/infer/__pycache__/predict_cls.cpython-38.pyc +0 -0
- tools/infer/__pycache__/predict_det.cpython-37.pyc +0 -0
- tools/infer/__pycache__/predict_det.cpython-38.pyc +0 -0
- tools/infer/__pycache__/predict_rec.cpython-37.pyc +0 -0
- tools/infer/__pycache__/predict_rec.cpython-38.pyc +0 -0
- tools/infer/__pycache__/predict_system.cpython-37.pyc +0 -0
- tools/infer/__pycache__/predict_system.cpython-38.pyc +0 -0
- tools/infer/__pycache__/utility.cpython-37.pyc +0 -0
- tools/infer/__pycache__/utility.cpython-38.pyc +0 -0
- tools/infer/predict_cls.py +151 -0
- tools/infer/predict_det.py +353 -0
- tools/infer/predict_e2e.py +169 -0
- tools/infer/predict_rec.py +667 -0
- tools/infer/predict_sr.py +155 -0
- tools/infer/predict_system.py +262 -0
- tools/infer/utility.py +663 -0
- tools/infer_cls.py +85 -0
- tools/infer_det.py +134 -0
- tools/infer_e2e.py +174 -0
- tools/infer_kie.py +176 -0
- tools/infer_kie_token_ser.py +157 -0
- tools/infer_kie_token_ser_re.py +225 -0
- tools/infer_rec.py +188 -0
- tools/infer_sr.py +100 -0
- tools/infer_table.py +121 -0
- tools/program.py +702 -0
- tools/test_hubserving.py +157 -0
- tools/train.py +209 -0
tools/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
tools/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (107 Bytes). View file
|
|
tools/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (144 Bytes). View file
|
|
tools/end2end/convert_ppocr_label.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
|
18 |
+
|
19 |
+
def poly_to_string(poly):
    """Serialize polygon coordinates into a single tab-separated string.

    Args:
        poly: polygon coordinates — a numpy array of any shape or any
            nested sequence of numbers. Multi-dimensional input is
            flattened in row-major order.

    Returns:
        str: the coordinates joined by tabs, e.g. ``"0\\t0\\t1\\t0"``.
    """
    # np.asarray + ravel also accepts plain Python lists; the original
    # accessed `poly.shape` directly and raised AttributeError on lists.
    flat = np.asarray(poly).ravel()
    return "\t".join(str(v) for v in flat)
25 |
+
|
26 |
+
|
27 |
+
def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
    """Convert a PPOCR-style label file into one text file per image in
    the format consumed by ``tools/end2end/eval_end2end.py``.

    Each input line is ``<img_path>\\t<json annotation list>`` (tab
    separated; space separated accepted as a fallback). For every
    annotation one output line is written to
    ``<save_dir>/<image name>.txt``:

    * mode "gt":   ``x1 .. y4\\t<ignore_tag>\\t<text>``
    * other modes: ``x1 .. y4\\t<text>``

    Args:
        label_dir: path to the input label file.
        mode: "gt" emits the ignore-tag column; any other value emits
            the prediction format.
        save_dir: directory the per-image .txt files are written into.

    Raises:
        ValueError: if ``label_dir`` does not exist.
    """
    if not os.path.exists(label_dir):
        raise ValueError(f"The file {label_dir} does not exist!")

    assert label_dir != save_dir, \
        "label_dir and save_dir must differ to avoid clobbering the input"

    # `with` guarantees the handle is closed; the original leaked it.
    with open(label_dir, 'r') as label_file:
        data = label_file.readlines()

    gt_dict = {}

    for line in data:
        # Preferred layout is tab separated; fall back to space separated
        # (replaces the original bare try/except control flow).
        tmp = line.split('\t')
        if len(tmp) != 2:
            tmp = line.strip().split(' ')

        gt_lists = []
        img_path = tmp[0]
        anno = json.loads(tmp[1])
        for dic in anno:
            txt = dic['transcription']
            # Drop low-confidence predictions.
            if 'score' in dic and float(dic['score']) < 0.5:
                continue
            # Normalize ideographic (fullwidth) spaces to ASCII spaces.
            if u'\u3000' in txt:
                txt = txt.replace(u'\u3000', u' ')
            poly = np.array(dic['points']).flatten()
            # "###" marks a region that evaluation should ignore.
            txt_tag = 1 if txt == "###" else 0
            if mode == "gt":
                gt_label = poly_to_string(poly) + "\t" + str(
                    txt_tag) + "\t" + txt + "\n"
            else:
                gt_label = poly_to_string(poly) + "\t" + txt + "\n"
            gt_lists.append(gt_label)

        gt_dict[img_path] = gt_lists

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for img_name in gt_dict.keys():
        save_name = img_name.split("/")[-1]
        save_file = os.path.join(save_dir, save_name + ".txt")
        with open(save_file, "w") as f:
            f.writelines(gt_dict[img_name])

    print("The convert label saved in {}".format(save_dir))
86 |
+
|
87 |
+
|
88 |
+
def parse_args():
    """Parse the command line options of the label conversion script."""
    import argparse

    parser = argparse.ArgumentParser(description="args")
    parser.add_argument("--label_path", type=str, required=True)
    parser.add_argument("--save_folder", type=str, required=True)
    parser.add_argument("--mode", type=str, default=False)
    return parser.parse_args()
96 |
+
|
97 |
+
|
98 |
+
if __name__ == "__main__":
|
99 |
+
args = parse_args()
|
100 |
+
convert_label(args.label_path, args.mode, args.save_folder)
|
tools/end2end/draw_html.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
|
19 |
+
def str2bool(v):
    """Interpret the strings "true"/"t"/"1" (case-insensitive) as True."""
    lowered = v.lower()
    return lowered in ("true", "t", "1")
21 |
+
|
22 |
+
|
23 |
+
def init_args():
    """Build the CLI parser for the debug-HTML generator.

    Options:
        --image_dir: directory holding the images to embed (default "").
        --save_html_path: output HTML file (default "./default.html").
        --width: display width in pixels for each image (default 640).
    """
    parser = argparse.ArgumentParser()
    for flag, kind, fallback in (
        ("--image_dir", str, ""),
        ("--save_html_path", str, "./default.html"),
        ("--width", int, 640),
    ):
        parser.add_argument(flag, type=kind, default=fallback)
    return parser
29 |
+
|
30 |
+
|
31 |
+
def parse_args():
    """Parse command line arguments using the parser from init_args()."""
    return init_args().parse_args()
34 |
+
|
35 |
+
|
36 |
+
def draw_debug_img(args):
    """Render every image under ``args.image_dir`` into one HTML table
    for visual inspection and save it to ``args.save_html_path``.

    Args:
        args: parsed CLI namespace with ``image_dir``, ``save_html_path``
            and ``width`` attributes (see ``init_args``).
    """
    html_path = args.save_html_path

    with open(html_path, 'w') as html:
        html.write('<html>\n<body>\n')
        html.write('<table border="1">\n')
        html.write(
            "<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
        )
        path = args.image_dir
        for i, filename in enumerate(sorted(os.listdir(path))):
            # Skip label/result text files sitting next to the images.
            if filename.endswith("txt"):
                continue
            # The image path
            base = "{}/{}".format(path, filename)
            html.write("<tr>\n")
            html.write(f'<td> (unknown)\n GT')
            html.write(f'<td>GT\n<img src="{base}" width={args.width}></td>')
            html.write("</tr>\n")
        html.write('<style>\n')
        html.write('span {\n')
        html.write(' color: red;\n')
        html.write('}\n')
        html.write('</style>\n')
        html.write('</table>\n')
        # Close tags in proper nesting order (the original emitted
        # '</html>' before '</body>').
        html.write('</body>\n</html>\n')
    print(f"The html file saved in {html_path}")
    return
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
|
71 |
+
args = parse_args()
|
72 |
+
|
73 |
+
draw_debug_img(args)
|
tools/end2end/eval_end2end.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import re
|
17 |
+
import sys
|
18 |
+
import shapely
|
19 |
+
from shapely.geometry import Polygon
|
20 |
+
import numpy as np
|
21 |
+
from collections import defaultdict
|
22 |
+
import operator
|
23 |
+
import editdistance
|
24 |
+
|
25 |
+
|
26 |
+
def strQ2B(ustring):
    """Convert fullwidth (quanjiao) characters in *ustring* to their
    halfwidth (banjiao) equivalents and return the converted string."""
    out_chars = []
    for ch in ustring:
        code = ord(ch)
        if code == 12288:
            # Ideographic space -> ASCII space.
            code = 32
        elif 65281 <= code <= 65374:
            # Fullwidth ASCII block maps to ASCII by a fixed offset.
            code -= 65248
        out_chars.append(chr(code))
    return "".join(out_chars)
36 |
+
|
37 |
+
|
38 |
+
def polygon_from_str(polygon_points):
    """
    Create a shapely polygon object from gt or dt line.
    """
    # Eight flat coordinates -> four (x, y) vertices; convex_hull fixes
    # self-intersecting vertex orderings.
    pts = np.array(polygon_points).reshape(4, 2)
    return Polygon(pts).convex_hull
45 |
+
|
46 |
+
|
47 |
+
def polygon_iou(poly1, poly2):
    """
    Intersection over union between two shapely polygons.
    """
    # Cheap rejection: disjoint polygons have zero IoU, and the
    # intersects() test is fast.
    if not poly1.intersects(poly2):
        return 0
    try:
        inter_area = poly1.intersection(poly2).area
        union_area = poly1.area + poly2.area - inter_area
        return float(inter_area) / union_area
    except shapely.geos.TopologicalError:
        # Degenerate geometry; treat as no overlap.
        print('shapely.geos.TopologicalError occurred, iou set to 0')
        return 0
65 |
+
|
66 |
+
|
67 |
+
def ed(str1, str2):
    # Levenshtein edit distance between two strings, delegated to the
    # third-party `editdistance` package.
    return editdistance.eval(str1, str2)
69 |
+
|
70 |
+
|
71 |
+
def e2e_eval(gt_dir, res_dir, ignore_blank=False):
    """Compute end-to-end detection+recognition metrics.

    Compares per-image ground-truth files in ``gt_dir`` against
    prediction files of the same name in ``res_dir``. Each line in both
    is tab separated: 8 polygon coordinates, then (gt only) an ignore
    flag, then optionally the transcription. Boxes are matched greedily
    by descending IoU (threshold 0.5); precision/recall/f-measure count
    a hit only when the matched transcriptions are exactly equal.
    Results are printed, not returned.

    Args:
        gt_dir: directory of ground-truth .txt files.
        res_dir: directory of prediction .txt files.
        ignore_blank: if True, strip all spaces (after fullwidth
            normalization) before comparing texts.
    """
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for i, val_name in enumerate(val_names):
        # --- parse ground truth for this image ---
        with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
            gt_lines = [o.strip() for o in f.readlines()]
        gts = []
        ignore_masks = []
        for line in gt_lines:
            parts = line.strip().split('\t')
            # ignore illegal data
            if len(parts) < 9:
                continue
            assert (len(parts) < 11)
            # Normalize every gt entry to 8 coords + transcription;
            # missing transcription becomes the empty string.
            if len(parts) == 9:
                gts.append(parts[:8] + [''])
            else:
                gts.append(parts[:8] + [parts[-1]])

            # parts[8] is the ignore flag ('0' = evaluate, '1' = ignore).
            ignore_masks.append(parts[8])

        # --- parse predictions; missing file means no detections ---
        val_path = os.path.join(res_dir, val_name)
        if not os.path.exists(val_path):
            dt_lines = []
        else:
            with open(val_path, encoding='utf-8') as f:
                dt_lines = [o.strip() for o in f.readlines()]
        dts = []
        for line in dt_lines:
            # print(line)
            parts = line.strip().split("\t")
            assert (len(parts) < 10), "line error: {}".format(line)
            if len(parts) == 8:
                dts.append(parts + [''])
            else:
                dts.append(parts)

        # --- pairwise IoU between every gt and dt box ---
        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)
        all_ious = defaultdict(tuple)
        for index_gt, gt in enumerate(gts):
            gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
            gt_poly = polygon_from_str(gt_coors)
            for index_dt, dt in enumerate(dts):
                dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
                dt_poly = polygon_from_str(dt_coors)
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        # Greedy matching: highest-IoU pairs claim their boxes first.
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)
        sorted_gt_dt_pairs = [item[0] for item in sorted_ious]

        # matched gt and dt
        for gt_dt_pair in sorted_gt_dt_pairs:
            index_gt, index_dt = gt_dt_pair
            if gt_match[index_gt] == False and dt_match[index_dt] == False:
                gt_match[index_gt] = True
                dt_match[index_dt] = True
                if ignore_blank:
                    gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
                    dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
                else:
                    gt_str = strQ2B(gts[index_gt][8])
                    dt_str = strQ2B(dts[index_dt][8])
                # Only non-ignored gt regions contribute to the metrics.
                if ignore_masks[index_gt] == '0':
                    ed_sum += ed(gt_str, dt_str)
                    num_gt_chars += len(gt_str)
                    if gt_str == dt_str:
                        hit += 1
                    gt_count += 1
                    dt_count += 1

        # unmatched dt
        for tindex, dt_match_flag in enumerate(dt_match):
            if dt_match_flag == False:
                # A spurious detection costs its full text length in
                # edit distance and one false positive.
                dt_str = dts[tindex][8]
                gt_str = ''
                ed_sum += ed(dt_str, gt_str)
                dt_count += 1

        # unmatched gt
        for tindex, gt_match_flag in enumerate(gt_match):
            if gt_match_flag == False and ignore_masks[tindex] == '0':
                # A missed (non-ignored) gt region is a false negative.
                dt_str = ''
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, dt_str)
                num_gt_chars += len(gt_str)
                gt_count += 1

    # eps guards against division by zero when a count is 0.
    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    avg_edit_dist_img = ed_sum / len(val_names)
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")
183 |
+
|
184 |
+
|
185 |
+
if __name__ == '__main__':
    # Usage: python3 eval_end2end.py <gt_label_dir> <predict_label_dir>
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)
tools/end2end/readme.md
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# 简介
|
3 |
+
|
4 |
+
`tools/end2end`目录下存放了文本检测+文本识别pipeline串联预测的指标评测代码以及可视化工具。本节介绍文本检测+文本识别的端对端指标评估方式。
|
5 |
+
|
6 |
+
|
7 |
+
## 端对端评测步骤
|
8 |
+
|
9 |
+
**步骤一:**
|
10 |
+
|
11 |
+
运行`tools/infer/predict_system.py`,得到保存的结果:
|
12 |
+
|
13 |
+
```
|
14 |
+
python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True
|
15 |
+
```
|
16 |
+
|
17 |
+
文本检测识别可视化图默认保存在`./ch_PP-OCRv2_results/`目录下,预测结果默认保存在`./ch_PP-OCRv2_results/system_results.txt`中,格式如下:
|
18 |
+
```
|
19 |
+
all-sum-510/00224225.jpg [{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}]
|
20 |
+
```
|
21 |
+
|
22 |
+
|
23 |
+
**步骤二:**
|
24 |
+
|
25 |
+
将步骤一保存的数据转换为端对端评测需要的数据格式:
|
26 |
+
|
27 |
+
修改 `tools/end2end/convert_ppocr_label.py`中的代码,convert_label函数中设置输入标签路径,Mode,保存标签路径等,对预测数据的GTlabel和预测结果的label格式进行转换。
|
28 |
+
|
29 |
+
```
|
30 |
+
python3 tools/end2end/convert_ppocr_label.py --mode=gt --label_path=path/to/label_txt --save_folder=save_gt_label
|
31 |
+
|
32 |
+
python3 tools/end2end/convert_ppocr_label.py --mode=pred --label_path=path/to/pred_txt --save_folder=save_PPOCRV2_infer
|
33 |
+
```
|
34 |
+
|
35 |
+
得到如下结果:
|
36 |
+
```
|
37 |
+
├── ./save_gt_label/
|
38 |
+
├── ./save_PPOCRV2_infer/
|
39 |
+
```
|
40 |
+
|
41 |
+
**步骤三:**
|
42 |
+
|
43 |
+
执行端对端评测,运行`tools/end2end/eval_end2end.py`计算端对端指标,运行方式如下:
|
44 |
+
|
45 |
+
```
|
46 |
+
python3 tools/end2end/eval_end2end.py "gt_label_dir" "predict_label_dir"
|
47 |
+
```
|
48 |
+
|
49 |
+
比如:
|
50 |
+
|
51 |
+
```
|
52 |
+
python3 tools/end2end/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/
|
53 |
+
```
|
54 |
+
将得到如下结果,fmeasure为主要关注的指标:
|
55 |
+
```
|
56 |
+
hit, dt_count, gt_count 1557 2693 3283
|
57 |
+
character_acc: 61.77%
|
58 |
+
avg_edit_dist_field: 3.08
|
59 |
+
avg_edit_dist_img: 51.82
|
60 |
+
precision: 57.82%
|
61 |
+
recall: 47.43%
|
62 |
+
fmeasure: 52.11%
|
63 |
+
```
|
tools/eval.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import os
|
20 |
+
import sys
|
21 |
+
|
22 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
23 |
+
sys.path.insert(0, __dir__)
|
24 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
25 |
+
|
26 |
+
import paddle
|
27 |
+
from ppocr.data import build_dataloader
|
28 |
+
from ppocr.modeling.architectures import build_model
|
29 |
+
from ppocr.postprocess import build_post_process
|
30 |
+
from ppocr.metrics import build_metric
|
31 |
+
from ppocr.utils.save_load import load_model
|
32 |
+
import tools.program as program
|
33 |
+
|
34 |
+
|
35 |
+
def main():
    """Evaluate a trained OCR model on the Eval dataset from the config.

    Relies on the module-level globals ``config``, ``device`` and
    ``logger`` that are assigned in the ``__main__`` guard by
    ``program.preprocess()``. Results are logged, not returned.
    """
    global_config = config['Global']
    # build dataloader
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm: size the head output to the charset of the
    # post-process (decoder), including distillation/multi-head variants.
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    out_channels_list = {}
                    # SAR decoding reserves 2 symbols of the charset.
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
            out_channels_list = {}
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    # Algorithms that require extra (non-image) inputs at inference.
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
    ]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
        for key in config['Architecture']["Models"]:
            extra_input = extra_input or config['Architecture']['Models'][key][
                'algorithm'] in extra_input_models
    else:
        extra_input = config['Architecture']['algorithm'] in extra_input_models
    if "model_type" in config['Architecture'].keys():
        if config['Architecture']['algorithm'] == 'CAN':
            model_type = 'can'
        else:
            model_type = config['Architecture']['model_type']
    else:
        model_type = None

    # build metric
    eval_class = build_metric(config['Metric'])
    # amp: optional automatic mixed precision evaluation.
    use_amp = config["Global"].get("use_amp", False)
    amp_level = config["Global"].get("amp_level", 'O2')
    amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
        if amp_level == "O2":
            model = paddle.amp.decorate(
                models=model, level=amp_level, master_weight=True)
    else:
        scaler = None

    best_model_dict = load_model(
        config, model, model_type=config['Architecture']["model_type"])
    if len(best_model_dict):
        logger.info('metric in ckpt ***************')
        for k, v in best_model_dict.items():
            logger.info('{}:{}'.format(k, v))

    # start eval
    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class, model_type, extra_input, scaler,
                          amp_level, amp_custom_black_list)
    logger.info('metric eval ***************')
    for k, v in metric.items():
        logger.info('{}:{}'.format(k, v))
133 |
+
|
134 |
+
|
135 |
+
if __name__ == '__main__':
    # program.preprocess() parses the config file and prepares the device,
    # logger and VisualDL writer that main() reads as module globals.
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/export_center.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import os
|
20 |
+
import sys
|
21 |
+
import pickle
|
22 |
+
|
23 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
sys.path.append(__dir__)
|
25 |
+
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
26 |
+
|
27 |
+
from ppocr.data import build_dataloader
|
28 |
+
from ppocr.modeling.architectures import build_model
|
29 |
+
from ppocr.postprocess import build_post_process
|
30 |
+
from ppocr.utils.save_load import load_model
|
31 |
+
from ppocr.utils.utility import print_dict
|
32 |
+
import tools.program as program
|
33 |
+
|
34 |
+
|
35 |
+
def main():
    """Export per-character feature centers of a recognition model.

    Runs the model over the *training* data (routed through the Eval
    dataloader) and pickles the resulting character-center features to
    ``train_center.pkl``. Relies on the module-level globals ``config``,
    ``device`` and ``logger`` set in the ``__main__`` guard.
    """
    global_config = config['Global']
    # build dataloader: point the Eval pipeline at the Train dataset so
    # centers are computed from training samples.
    config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
    config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
        'data_dir']
    config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
        'label_file_list']
    eval_dataloader = build_dataloader(config, 'Eval', device, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm: head output size follows the decoder charset.
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num

    # set return_features = True so the head also emits its features.
    config['Architecture']["Head"]["return_feats"] = True

    model = build_model(config['Architecture'])

    best_model_dict = load_model(config, model)
    if len(best_model_dict):
        logger.info('metric in ckpt ***************')
        for k, v in best_model_dict.items():
            logger.info('{}:{}'.format(k, v))

    # get features from train data
    char_center = program.get_center(model, eval_dataloader, post_process_class)

    # serialize to disk
    with open("train_center.pkl", 'wb') as f:
        pickle.dump(char_center, f)
    return
73 |
+
|
74 |
+
|
75 |
+
if __name__ == '__main__':
    # Load config and runtime context, then export the feature centers.
    config, device, logger, vdl_writer = program.preprocess()
    main()
tools/export_model.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
|
18 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
19 |
+
sys.path.append(__dir__)
|
20 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
|
21 |
+
|
22 |
+
import argparse
|
23 |
+
|
24 |
+
import paddle
|
25 |
+
from paddle.jit import to_static
|
26 |
+
|
27 |
+
from ppocr.modeling.architectures import build_model
|
28 |
+
from ppocr.postprocess import build_post_process
|
29 |
+
from ppocr.utils.save_load import load_model
|
30 |
+
from ppocr.utils.logging import get_logger
|
31 |
+
from tools.program import load_config, merge_config, ArgsParser
|
32 |
+
|
33 |
+
|
34 |
+
def export_single_model(model,
                        arch_config,
                        save_path,
                        logger,
                        input_shape=None,
                        quanter=None):
    """Trace one dygraph model to a static graph and save it for inference.

    Args:
        model: built (and weight-loaded) Paddle dygraph model to export.
        arch_config: the "Architecture" section of the config; its
            "algorithm" / "model_type" / "Head" keys select which static
            InputSpec is used below.
        save_path: path prefix passed to paddle.jit.save (or the quanter).
        logger: logger for progress / warning messages.
        input_shape: [C, H, W] list; only consulted by the plain (non
            MultiHead) SVTR branch.
        quanter: optional PaddleSlim quanter; when given, the quantized
            model is saved instead of the plain jit-traced one.
    """
    # Every algorithm that takes extra inputs (attention masks, word
    # position ids, ...) needs its own InputSpec list, hence the long
    # per-algorithm dispatch below.
    if arch_config["algorithm"] == "SRN":
        max_text_length = arch_config["Head"]["max_text_length"]
        # SRN consumes the image plus a nested list of four auxiliary
        # tensors (encoder word positions, gsrm word positions, and two
        # self-attention bias masks).
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 64, 256], dtype="float32"), [
                    paddle.static.InputSpec(
                        shape=[None, 256, 1],
                        dtype="int64"), paddle.static.InputSpec(
                            shape=[None, max_text_length, 1], dtype="int64"),
                    paddle.static.InputSpec(
                        shape=[None, 8, max_text_length, max_text_length],
                        dtype="int64"), paddle.static.InputSpec(
                            shape=[None, 8, max_text_length, max_text_length],
                            dtype="int64")
                ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SAR":
        # SAR also takes a valid-width-ratio tensor next to the image.
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 48, 160], dtype="float32"),
            [paddle.static.InputSpec(
                shape=[None], dtype="float32")]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SVTR":
        if arch_config["Head"]["name"] == 'MultiHead':
            # MultiHead SVTR accepts variable width (last dim -1).
            other_shape = [
                paddle.static.InputSpec(
                    shape=[None, 3, 48, -1], dtype="float32"),
            ]
        else:
            # Plain SVTR is traced with the exact eval image_shape that the
            # caller extracted from the Eval transforms.
            other_shape = [
                paddle.static.InputSpec(
                    shape=[None] + input_shape, dtype="float32"),
            ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "PREN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["model_type"] == "sr":
        # Super-resolution models take the low-resolution input image.
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 16, 64], dtype="float32")
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ViTSTR":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 224, 224], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ABINet":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 32, 128], dtype="float32"),
        ]
        # print([None, 3, 32, 128])
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
        # These three share a single-channel 32x100 input.
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "RobustScanner":
        max_text_length = arch_config["Head"]["max_text_length"]
        # Image plus [valid-ratio, word-position] auxiliary inputs.
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 48, 160], dtype="float32"), [
                    paddle.static.InputSpec(
                        shape=[None, ], dtype="float32"),
                    paddle.static.InputSpec(
                        shape=[None, max_text_length], dtype="int64")
                ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "CAN":
        # CAN takes a single list argument: [image, image_mask, label].
        other_shape = [[
            paddle.static.InputSpec(
                shape=[None, 1, None, None],
                dtype="float32"), paddle.static.InputSpec(
                    shape=[None, 1, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, arch_config['Head']['max_text_length']],
                dtype="int64")
        ]]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
        # KIE transformer models: token ids, boxes, masks, and (optionally)
        # the image branch; the spec is wrapped in one outer list because
        # the forward takes a single sequence argument.
        input_spec = [
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # input_ids
        paddle.static.InputSpec(
                shape=[None, 512, 4], dtype="int64"),  # bbox
        paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # attention_mask
        paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # token_type_ids
        paddle.static.InputSpec(
                shape=[None, 3, 224, 224], dtype="int64"),  # image
        ]
        if 'Re' in arch_config['Backbone']['name']:
            # Relation-extraction variants additionally take entity spans
            # and candidate relation pairs.
            input_spec.extend([
                paddle.static.InputSpec(
                    shape=[None, 512, 3], dtype="int64"),  # entities
                paddle.static.InputSpec(
                    shape=[None, None, 2], dtype="int64"),  # relations
            ])
        if model.backbone.use_visual_backbone is False:
            # No visual branch -> drop the image spec (index 4).
            input_spec.pop(4)
        model = to_static(model, input_spec=[input_spec])
    else:
        # Default: detection / recognition / table models with a single
        # image input; only H/W constraints differ.
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
            infer_shape = [3, 32, -1]  # for rec model, H must be 32
            if "Transform" in arch_config and arch_config[
                    "Transform"] is not None and arch_config["Transform"][
                        "name"] == "TPS":
                logger.info(
                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                )
                infer_shape[-1] = 100
        elif arch_config["model_type"] == "table":
            infer_shape = [3, 488, 488]
            if arch_config["algorithm"] == "TableMaster":
                infer_shape = [3, 480, 480]
            if arch_config["algorithm"] == "SLANet":
                infer_shape = [3, -1, -1]
        model = to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + infer_shape, dtype="float32")
            ])

    if quanter is None:
        paddle.jit.save(model, save_path)
    else:
        quanter.save_quantized_model(model, save_path)
    logger.info("inference model is saved to {}".format(save_path))
    return
|
190 |
+
|
191 |
+
|
192 |
+
def main():
    """Entry point: build the configured model, load weights, and export it.

    Reads the YAML config plus -o overrides, patches the Architecture
    section with recognition-head output channels derived from the
    post-process character dict, then delegates the actual tracing/saving
    to export_single_model (once per sub-model for distillation).
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    logger = get_logger()
    # build post process

    post_process_class = build_post_process(config["PostProcess"],
                                            config["Global"])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        # The head's output channel count must match the character dict
        # carried by the post-process, so it is injected into the config
        # before build_model.
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if config["Architecture"]["Models"][key]["Head"][
                        "name"] == 'MultiHead':  # multi head
                    out_channels_list = {}
                    # SAR decoding reserves two extra symbols, so the raw
                    # dict size is adjusted before filling both heads.
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"] = char_num
                # just one final tensor needs to exported for inference
                config["Architecture"]["Models"][key][
                    "return_all_feats"] = False
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, 'character'))
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

    # for sr algorithm
    if config["Architecture"]["model_type"] == "sr":
        # Switch the transform branch to inference mode before building.
        config['Architecture']["Transform"]['infer_mode'] = True
    model = build_model(config["Architecture"])
    load_model(config, model, model_type=config['Architecture']["model_type"])
    model.eval()

    save_path = config["Global"]["save_inference_dir"]

    arch_config = config["Architecture"]

    if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
            "name"] != 'MultiHead':
        # Plain SVTR needs the exact eval image shape for tracing.
        # NOTE(review): assumes 'SVTRRecResizeImg' is the second-to-last
        # Eval transform — verify against the config layout.
        input_shape = config["Eval"]["dataset"]["transforms"][-2][
            'SVTRRecResizeImg']['image_shape']
    else:
        input_shape = None

    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
        # Export each student/teacher sub-model into its own directory.
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(model.model_list[idx], archs[idx],
                                sub_model_save_path, logger)
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
            model, arch_config, save_path, logger, input_shape=input_shape)


if __name__ == "__main__":
    main()
|
tools/infer/__pycache__/predict_cls.cpython-37.pyc
ADDED
Binary file (4.07 kB). View file
|
|
tools/infer/__pycache__/predict_cls.cpython-38.pyc
ADDED
Binary file (4.11 kB). View file
|
|
tools/infer/__pycache__/predict_det.cpython-37.pyc
ADDED
Binary file (8.48 kB). View file
|
|
tools/infer/__pycache__/predict_det.cpython-38.pyc
ADDED
Binary file (8.61 kB). View file
|
|
tools/infer/__pycache__/predict_rec.cpython-37.pyc
ADDED
Binary file (13.8 kB). View file
|
|
tools/infer/__pycache__/predict_rec.cpython-38.pyc
ADDED
Binary file (13.8 kB). View file
|
|
tools/infer/__pycache__/predict_system.cpython-37.pyc
ADDED
Binary file (7.04 kB). View file
|
|
tools/infer/__pycache__/predict_system.cpython-38.pyc
ADDED
Binary file (7.13 kB). View file
|
|
tools/infer/__pycache__/utility.cpython-37.pyc
ADDED
Binary file (17.4 kB). View file
|
|
tools/infer/__pycache__/utility.cpython-38.pyc
ADDED
Binary file (17.4 kB). View file
|
|
tools/infer/predict_cls.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
|
17 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
sys.path.append(__dir__)
|
19 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
20 |
+
|
21 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
22 |
+
|
23 |
+
import cv2
|
24 |
+
import copy
|
25 |
+
import numpy as np
|
26 |
+
import math
|
27 |
+
import time
|
28 |
+
import traceback
|
29 |
+
|
30 |
+
import tools.infer.utility as utility
|
31 |
+
from ppocr.postprocess import build_post_process
|
32 |
+
from ppocr.utils.logging import get_logger
|
33 |
+
from ppocr.utils.utility import get_image_file_list, check_and_read
|
34 |
+
|
35 |
+
logger = get_logger()
|
36 |
+
|
37 |
+
|
38 |
+
class TextClassifier(object):
    """Text-direction classifier for cropped text-line images.

    Runs a small classification model over each crop and rotates any image
    predicted as upside-down ('180') above ``cls_thresh`` by 180 degrees,
    so downstream recognition sees upright text.
    """

    def __init__(self, args):
        """Build the post-process op and the inference predictor.

        Args:
            args: parsed CLI namespace (see tools/infer/utility.parse_args);
                uses cls_image_shape, cls_batch_num, cls_thresh, label_list
                and the generic predictor options.
        """
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
        postprocess_params = {
            'name': 'ClsPostProcess',
            "label_list": args.label_list,
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, _ = \
            utility.create_predictor(args, 'cls', logger)
        self.use_onnx = args.use_onnx

    def resize_norm_img(self, img):
        """Resize ``img`` to the model input shape and normalize to [-1, 1].

        The image is resized to the target height with aspect ratio kept
        (capped at the target width) and right-padded with zeros.

        Returns:
            np.ndarray of shape (C, H, W), dtype float32.
        """
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if self.cls_image_shape[0] == 1:
            # Grayscale model: keep HxW and add the channel axis.
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        # Map [0, 1] -> [-1, 1].
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Classify a list of images and rotate the upside-down ones.

        Args:
            img_list: list of HxWxC (or HxW) numpy images; the input list is
                deep-copied and never mutated.

        Returns:
            Tuple of (possibly-rotated image list in original order,
            [[label, score], ...] per image, total inference seconds).
        """
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars; sorting by it groups
        # similarly shaped crops into the same batch, which speeds up cls.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        cls_res = [['', 0.0]] * img_num
        batch_num = self.cls_batch_num
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):

            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            starttime = time.time()
            # FIX: the original also computed a per-batch max_wh_ratio here
            # that was never used (resize_norm_img ignores it); that dead
            # loop has been removed.
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]])
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()

            if self.use_onnx:
                input_dict = {}
                input_dict[self.input_tensor.name] = norm_img_batch
                outputs = self.predictor.run(self.output_tensors, input_dict)
                prob_out = outputs[0]
            else:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.run()
                prob_out = self.output_tensors[0].copy_to_cpu()
                self.predictor.try_shrink_memory()
            cls_result = self.postprocess_op(prob_out)
            elapse += time.time() - starttime
            for rno in range(len(cls_result)):
                label, score = cls_result[rno]
                # Scatter the batch result back to the pre-sort position.
                cls_res[indices[beg_img_no + rno]] = [label, score]
                if '180' in label and score > self.cls_thresh:
                    # cv2.ROTATE_180 (== 1, the original magic constant):
                    # flip images classified as upside-down.
                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
                        img_list[indices[beg_img_no + rno]], cv2.ROTATE_180)
        return img_list, cls_res, elapse
|
123 |
+
|
124 |
+
|
125 |
+
def main(args):
    """Run direction classification over every image under args.image_dir.

    Loads each image (with GIF handling via check_and_read), classifies the
    whole batch at once, and logs one [label, score] line per image.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag, _ = check_and_read(image_file)
        if not flag:
            # Not a GIF/PDF special case: plain imread.
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        # FIX: was the bare site-module exit(); sys.exit() is the proper
        # way to terminate in non-interactive scripts.
        sys.exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               cls_res[ino]))


if __name__ == "__main__":
    main(utility.parse_args())
|
tools/infer/predict_det.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
|
17 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
sys.path.append(__dir__)
|
19 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
20 |
+
|
21 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
22 |
+
|
23 |
+
import cv2
|
24 |
+
import numpy as np
|
25 |
+
import time
|
26 |
+
import sys
|
27 |
+
|
28 |
+
import tools.infer.utility as utility
|
29 |
+
from ppocr.utils.logging import get_logger
|
30 |
+
from ppocr.utils.utility import get_image_file_list, check_and_read
|
31 |
+
from ppocr.data import create_operators, transform
|
32 |
+
from ppocr.postprocess import build_post_process
|
33 |
+
import json
|
34 |
+
logger = get_logger()
|
35 |
+
|
36 |
+
|
37 |
+
class TextDetector(object):
    """Text detector wrapping a Paddle/ONNX inference model.

    Builds the per-algorithm pre-process pipeline and post-process op in
    __init__, then __call__ maps a BGR image to an array of detected text
    boxes (quads, or polygons when det_box_type == 'poly').
    """

    def __init__(self, args):
        """Configure pre/post-processing for args.det_algorithm and create
        the predictor.

        Args:
            args: parsed CLI namespace (see tools/infer/utility.parse_args).
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.use_onnx = args.use_onnx
        # Default DB-style pipeline; individual entries are overwritten
        # below for algorithms that need different resize/normalize steps.
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm == "DB":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
            postprocess_params["box_type"] = args.det_box_type
        elif self.det_algorithm == "DB++":
            # Same post-process as DB, but DB++ was trained with its own
            # normalization statistics, so NormalizeImage is replaced.
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
            postprocess_params["box_type"] = args.det_box_type
            pre_process_list[1] = {
                'NormalizeImage': {
                    'std': [1.0, 1.0, 1.0],
                    'mean':
                    [0.48109378172549, 0.45752457890196, 0.40787054090196],
                    'scale': '1./255.',
                    'order': 'hwc'
                }
            }
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            # SAST resizes by the long side instead of limiting a side.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh

            if args.det_box_type == 'poly':
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3

        elif self.det_algorithm == "PSE":
            postprocess_params['name'] = 'PSEPostProcess'
            postprocess_params["thresh"] = args.det_pse_thresh
            postprocess_params["box_thresh"] = args.det_pse_box_thresh
            postprocess_params["min_area"] = args.det_pse_min_area
            postprocess_params["box_type"] = args.det_box_type
            postprocess_params["scale"] = args.det_pse_scale
        elif self.det_algorithm == "FCE":
            # FCE uses a fixed rescale target instead of side-length limits.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'rescale_img': [1080, 736]
                }
            }
            postprocess_params['name'] = 'FCEPostProcess'
            postprocess_params["scales"] = args.scales
            postprocess_params["alpha"] = args.alpha
            postprocess_params["beta"] = args.beta
            postprocess_params["fourier_degree"] = args.fourier_degree
            postprocess_params["box_type"] = args.det_box_type
        elif self.det_algorithm == "CT":
            pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
            postprocess_params['name'] = 'CTPostProcess'
        else:
            logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(0)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
            args, 'det', logger)

        if self.use_onnx:
            # If the exported ONNX model has a fixed spatial input shape,
            # force the resize step to match it and rebuild the pipeline.
            img_h, img_w = self.input_tensor.shape[2:]
            if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
                pre_process_list[0] = {
                    'DetResizeForTest': {
                        'image_shape': [img_h, img_w]
                    }
                }
                self.preprocess_op = create_operators(pre_process_list)

        if args.benchmark:
            # Optional auto_log profiler: records preprocess / inference /
            # postprocess timings per call; reported by the caller.
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    def order_points_clockwise(self, pts):
        """Order 4 points as top-left, top-right, bottom-right, bottom-left.

        Top-left has the smallest x+y sum, bottom-right the largest; the
        remaining two are split by the sign of (y - x).
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp box points (in place) to lie inside the image bounds."""
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Normalize quads: order points, clip to the image, and drop boxes
        narrower/shorter than 4 px."""
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if type(box) is list:
                box = np.array(box)
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Polygon variant: only clip points to the image, keep all boxes."""
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if type(box) is list:
                box = np.array(box)
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        """Detect text regions in a BGR numpy image.

        Returns:
            (dt_boxes, elapsed_seconds); dt_boxes is an ndarray of boxes,
            or (None, 0) when preprocessing rejects the image.
        """
        ori_im = img.copy()
        data = {'image': img}

        st = time.time()

        if self.args.benchmark:
            self.autolog.times.start()

        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, 0
        # Add the batch dimension expected by the predictor.
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        if self.args.benchmark:
            self.autolog.times.stamp()
        if self.use_onnx:
            input_dict = {}
            input_dict[self.input_tensor.name] = img
            outputs = self.predictor.run(self.output_tensors, input_dict)
        else:
            self.input_tensor.copy_from_cpu(img)
            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
        if self.args.benchmark:
            self.autolog.times.stamp()

        # Repack raw outputs into the dict layout each post-process expects.
        # NOTE(review): assumes the exported model's output tensor order
        # matches these assignments — confirm against the export config.
        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs[0]
            preds['f_score'] = outputs[1]
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs[0]
            preds['f_score'] = outputs[1]
            preds['f_tco'] = outputs[2]
            preds['f_tvo'] = outputs[3]
        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
            preds['maps'] = outputs[0]
        elif self.det_algorithm == 'FCE':
            for i, output in enumerate(outputs):
                preds['level_{}'.format(i)] = output
        elif self.det_algorithm == "CT":
            preds['maps'] = outputs[0]
            preds['score'] = outputs[1]
        else:
            raise NotImplementedError

        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']

        if self.args.det_box_type == 'poly':
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        if self.args.benchmark:
            self.autolog.times.end(stamp=True)
        et = time.time()
        return dt_boxes, et - st
|
283 |
+
|
284 |
+
|
285 |
+
if __name__ == "__main__":
    # Script entry: run text detection over every image (or PDF page) under
    # args.image_dir, save visualizations and a det_results.txt with the
    # predicted boxes in JSON per line.
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextDetector(args)
    total_time = 0
    draw_img_save_dir = args.draw_img_save_dir
    os.makedirs(draw_img_save_dir, exist_ok=True)

    if args.warmup:
        # Two throwaway runs on random data to warm up the predictor.
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(2):
            res = text_detector(img)

    save_results = []
    for idx, image_file in enumerate(image_file_list):
        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if not flag_pdf:
            if img is None:
                logger.debug("error in loading image:{}".format(image_file))
                continue
            imgs = [img]
        else:
            # PDFs decode to a list of page images; page_num == 0 means all.
            page_num = args.page_num
            if page_num > len(img) or page_num == 0:
                page_num = len(img)
            imgs = img[:page_num]
        for index, img in enumerate(imgs):
            st = time.time()
            dt_boxes, _ = text_detector(img)
            elapse = time.time() - st
            total_time += elapse
            if len(imgs) > 1:
                # Multi-page input: suffix the page index onto the filename.
                save_pred = os.path.basename(image_file) + '_' + str(
                    index) + "\t" + str(
                        json.dumps([x.tolist() for x in dt_boxes])) + "\n"
            else:
                save_pred = os.path.basename(image_file) + "\t" + str(
                    json.dumps([x.tolist() for x in dt_boxes])) + "\n"
            save_results.append(save_pred)
            logger.info(save_pred)
            if len(imgs) > 1:
                logger.info("{}_{} The predict time of {}: {}".format(
                    idx, index, image_file, elapse))
            else:
                logger.info("{} The predict time of {}: {}".format(
                    idx, image_file, elapse))

            src_im = utility.draw_text_det_res(dt_boxes, img)

            if flag_gif:
                save_file = image_file[:-3] + "png"
            elif flag_pdf:
                save_file = image_file.replace('.pdf',
                                               '_' + str(index) + '.png')
            else:
                save_file = image_file
            img_path = os.path.join(
                draw_img_save_dir,
                "det_res_{}".format(os.path.basename(save_file)))
            cv2.imwrite(img_path, src_im)
            logger.info("The visualized image saved in {}".format(img_path))

    with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
        f.writelines(save_results)
        # FIX: dropped the original's explicit f.close() — the with block
        # already closes the file, and calling close() inside it was
        # redundant.
    if args.benchmark:
        text_detector.autolog.report()
|
tools/infer/predict_e2e.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
|
17 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
sys.path.append(__dir__)
|
19 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
20 |
+
|
21 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
22 |
+
|
23 |
+
import cv2
|
24 |
+
import numpy as np
|
25 |
+
import time
|
26 |
+
import sys
|
27 |
+
|
28 |
+
import tools.infer.utility as utility
|
29 |
+
from ppocr.utils.logging import get_logger
|
30 |
+
from ppocr.utils.utility import get_image_file_list, check_and_read
|
31 |
+
from ppocr.data import create_operators, transform
|
32 |
+
from ppocr.postprocess import build_post_process
|
33 |
+
|
34 |
+
logger = get_logger()
|
35 |
+
|
36 |
+
|
37 |
+
class TextE2E(object):
    """End-to-end text spotter: one model produces both polygons and text.

    Only the PGNet algorithm is supported; any other value of
    ``--e2e_algorithm`` logs a message and exits the process.
    """

    def __init__(self, args):
        """Build the preprocessing pipeline, postprocessor and predictor.

        Args:
            args: parsed command-line namespace from ``utility.parse_args()``.
        """
        self.args = args
        self.e2e_algorithm = args.e2e_algorithm
        self.use_onnx = args.use_onnx
        # Generic pipeline; the resize op (index 0) is specialized per
        # algorithm below.
        pre_process_list = [{
            'E2EResizeForTest': {}
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.e2e_algorithm == "PGNet":
            pre_process_list[0] = {
                'E2EResizeForTest': {
                    'max_side_len': args.e2e_limit_side_len,
                    'valid_set': 'totaltext'
                }
            }
            postprocess_params['name'] = 'PGPostProcess'
            postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
            postprocess_params["character_dict_path"] = args.e2e_char_dict_path
            postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
            postprocess_params["mode"] = args.e2e_pgnet_mode
        else:
            logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
            sys.exit(0)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, _ = utility.create_predictor(
            args, 'e2e', logger)  # paddle.jit.load(args.det_model_dir)
        # self.predictor.eval()

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every vertex of one polygon into the image bounds (in place)."""
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip all detected polygons to the source image.

        Unlike the detector's full filter, this only clips — no area or
        shape filtering is applied.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        """Run end-to-end inference on one BGR image.

        Returns:
            (dt_boxes, strs, elapse): clipped polygons, decoded strings and
            inference time in seconds.
            NOTE(review): on preprocessing failure this returns the 2-tuple
            ``(None, 0)`` instead of a 3-tuple — callers must handle both;
            verify this asymmetry is intended.
        """
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, 0
        # Add the batch dimension expected by the predictor.
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        starttime = time.time()

        if self.use_onnx:
            input_dict = {}
            input_dict[self.input_tensor.name] = img
            outputs = self.predictor.run(self.output_tensors, input_dict)
            # PGNet heads: border geometry, character logits, direction, score.
            preds = {}
            preds['f_border'] = outputs[0]
            preds['f_char'] = outputs[1]
            preds['f_direction'] = outputs[2]
            preds['f_score'] = outputs[3]
        else:
            self.input_tensor.copy_from_cpu(img)
            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)

            preds = {}
            if self.e2e_algorithm == 'PGNet':
                preds['f_border'] = outputs[0]
                preds['f_char'] = outputs[1]
                preds['f_direction'] = outputs[2]
                preds['f_score'] = outputs[3]
            else:
                raise NotImplementedError
        post_result = self.postprocess_op(preds, shape_list)
        points, strs = post_result['points'], post_result['texts']
        dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
        elapse = time.time() - starttime
        return dt_boxes, strs, elapse
|
139 |
+
|
140 |
+
|
141 |
+
if __name__ == "__main__":
    # CLI entry point: run end-to-end spotting on every image under
    # --image_dir and write visualizations to ./inference_results.
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextE2E(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"
    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    for image_file in image_file_list:
        # check_and_read handles gif/pdf inputs; fall back to plain imread.
        img, flag, _ = check_and_read(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        points, strs, elapse = text_detector(img)
        # The first prediction is excluded from the average: it carries
        # one-time warmup/initialization cost.
        if count > 0:
            total_time += elapse
        count += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_e2e_res(points, strs, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "e2e_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))
    if count > 1:
        logger.info("Avg Time: {}".format(total_time / (count - 1)))
|
tools/infer/predict_rec.py
ADDED
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
from PIL import Image
|
17 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
sys.path.append(__dir__)
|
19 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
20 |
+
|
21 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
22 |
+
|
23 |
+
import cv2
|
24 |
+
import numpy as np
|
25 |
+
import math
|
26 |
+
import time
|
27 |
+
import traceback
|
28 |
+
import paddle
|
29 |
+
|
30 |
+
import tools.infer.utility as utility
|
31 |
+
from ppocr.postprocess import build_post_process
|
32 |
+
from ppocr.utils.logging import get_logger
|
33 |
+
from ppocr.utils.utility import get_image_file_list, check_and_read
|
34 |
+
|
35 |
+
logger = get_logger()
|
36 |
+
|
37 |
+
|
38 |
+
class TextRecognizer(object):
|
39 |
+
def __init__(self, args):
|
40 |
+
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
41 |
+
self.rec_batch_num = args.rec_batch_num
|
42 |
+
self.rec_algorithm = args.rec_algorithm
|
43 |
+
postprocess_params = {
|
44 |
+
'name': 'CTCLabelDecode',
|
45 |
+
"character_dict_path": args.rec_char_dict_path,
|
46 |
+
"use_space_char": args.use_space_char
|
47 |
+
}
|
48 |
+
if self.rec_algorithm == "SRN":
|
49 |
+
postprocess_params = {
|
50 |
+
'name': 'SRNLabelDecode',
|
51 |
+
"character_dict_path": args.rec_char_dict_path,
|
52 |
+
"use_space_char": args.use_space_char
|
53 |
+
}
|
54 |
+
elif self.rec_algorithm == "RARE":
|
55 |
+
postprocess_params = {
|
56 |
+
'name': 'AttnLabelDecode',
|
57 |
+
"character_dict_path": args.rec_char_dict_path,
|
58 |
+
"use_space_char": args.use_space_char
|
59 |
+
}
|
60 |
+
elif self.rec_algorithm == 'NRTR':
|
61 |
+
postprocess_params = {
|
62 |
+
'name': 'NRTRLabelDecode',
|
63 |
+
"character_dict_path": args.rec_char_dict_path,
|
64 |
+
"use_space_char": args.use_space_char
|
65 |
+
}
|
66 |
+
elif self.rec_algorithm == "SAR":
|
67 |
+
postprocess_params = {
|
68 |
+
'name': 'SARLabelDecode',
|
69 |
+
"character_dict_path": args.rec_char_dict_path,
|
70 |
+
"use_space_char": args.use_space_char
|
71 |
+
}
|
72 |
+
elif self.rec_algorithm == "VisionLAN":
|
73 |
+
postprocess_params = {
|
74 |
+
'name': 'VLLabelDecode',
|
75 |
+
"character_dict_path": args.rec_char_dict_path,
|
76 |
+
"use_space_char": args.use_space_char
|
77 |
+
}
|
78 |
+
elif self.rec_algorithm == 'ViTSTR':
|
79 |
+
postprocess_params = {
|
80 |
+
'name': 'ViTSTRLabelDecode',
|
81 |
+
"character_dict_path": args.rec_char_dict_path,
|
82 |
+
"use_space_char": args.use_space_char
|
83 |
+
}
|
84 |
+
elif self.rec_algorithm == 'ABINet':
|
85 |
+
postprocess_params = {
|
86 |
+
'name': 'ABINetLabelDecode',
|
87 |
+
"character_dict_path": args.rec_char_dict_path,
|
88 |
+
"use_space_char": args.use_space_char
|
89 |
+
}
|
90 |
+
elif self.rec_algorithm == "SPIN":
|
91 |
+
postprocess_params = {
|
92 |
+
'name': 'SPINLabelDecode',
|
93 |
+
"character_dict_path": args.rec_char_dict_path,
|
94 |
+
"use_space_char": args.use_space_char
|
95 |
+
}
|
96 |
+
elif self.rec_algorithm == "RobustScanner":
|
97 |
+
postprocess_params = {
|
98 |
+
'name': 'SARLabelDecode',
|
99 |
+
"character_dict_path": args.rec_char_dict_path,
|
100 |
+
"use_space_char": args.use_space_char,
|
101 |
+
"rm_symbol": True
|
102 |
+
}
|
103 |
+
elif self.rec_algorithm == 'RFL':
|
104 |
+
postprocess_params = {
|
105 |
+
'name': 'RFLLabelDecode',
|
106 |
+
"character_dict_path": None,
|
107 |
+
"use_space_char": args.use_space_char
|
108 |
+
}
|
109 |
+
elif self.rec_algorithm == "PREN":
|
110 |
+
postprocess_params = {'name': 'PRENLabelDecode'}
|
111 |
+
elif self.rec_algorithm == "CAN":
|
112 |
+
self.inverse = args.rec_image_inverse
|
113 |
+
postprocess_params = {
|
114 |
+
'name': 'CANLabelDecode',
|
115 |
+
"character_dict_path": args.rec_char_dict_path,
|
116 |
+
"use_space_char": args.use_space_char
|
117 |
+
}
|
118 |
+
self.postprocess_op = build_post_process(postprocess_params)
|
119 |
+
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
120 |
+
utility.create_predictor(args, 'rec', logger)
|
121 |
+
self.benchmark = args.benchmark
|
122 |
+
self.use_onnx = args.use_onnx
|
123 |
+
if args.benchmark:
|
124 |
+
import auto_log
|
125 |
+
pid = os.getpid()
|
126 |
+
gpu_id = utility.get_infer_gpuid()
|
127 |
+
self.autolog = auto_log.AutoLogger(
|
128 |
+
model_name="rec",
|
129 |
+
model_precision=args.precision,
|
130 |
+
batch_size=args.rec_batch_num,
|
131 |
+
data_shape="dynamic",
|
132 |
+
save_path=None, #args.save_log_path,
|
133 |
+
inference_config=self.config,
|
134 |
+
pids=pid,
|
135 |
+
process_name=None,
|
136 |
+
gpu_ids=gpu_id if args.use_gpu else None,
|
137 |
+
time_keys=[
|
138 |
+
'preprocess_time', 'inference_time', 'postprocess_time'
|
139 |
+
],
|
140 |
+
warmup=0,
|
141 |
+
logger=logger)
|
142 |
+
|
143 |
+
def resize_norm_img(self, img, max_wh_ratio):
|
144 |
+
imgC, imgH, imgW = self.rec_image_shape
|
145 |
+
if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
|
146 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
147 |
+
# return padding_im
|
148 |
+
image_pil = Image.fromarray(np.uint8(img))
|
149 |
+
if self.rec_algorithm == 'ViTSTR':
|
150 |
+
img = image_pil.resize([imgW, imgH], Image.BICUBIC)
|
151 |
+
else:
|
152 |
+
img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
|
153 |
+
img = np.array(img)
|
154 |
+
norm_img = np.expand_dims(img, -1)
|
155 |
+
norm_img = norm_img.transpose((2, 0, 1))
|
156 |
+
if self.rec_algorithm == 'ViTSTR':
|
157 |
+
norm_img = norm_img.astype(np.float32) / 255.
|
158 |
+
else:
|
159 |
+
norm_img = norm_img.astype(np.float32) / 128. - 1.
|
160 |
+
return norm_img
|
161 |
+
elif self.rec_algorithm == 'RFL':
|
162 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
163 |
+
resized_image = cv2.resize(
|
164 |
+
img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
|
165 |
+
resized_image = resized_image.astype('float32')
|
166 |
+
resized_image = resized_image / 255
|
167 |
+
resized_image = resized_image[np.newaxis, :]
|
168 |
+
resized_image -= 0.5
|
169 |
+
resized_image /= 0.5
|
170 |
+
return resized_image
|
171 |
+
|
172 |
+
assert imgC == img.shape[2]
|
173 |
+
imgW = int((imgH * max_wh_ratio))
|
174 |
+
if self.use_onnx:
|
175 |
+
w = self.input_tensor.shape[3:][0]
|
176 |
+
if w is not None and w > 0:
|
177 |
+
imgW = w
|
178 |
+
|
179 |
+
h, w = img.shape[:2]
|
180 |
+
ratio = w / float(h)
|
181 |
+
if math.ceil(imgH * ratio) > imgW:
|
182 |
+
resized_w = imgW
|
183 |
+
else:
|
184 |
+
resized_w = int(math.ceil(imgH * ratio))
|
185 |
+
if self.rec_algorithm == 'RARE':
|
186 |
+
if resized_w > self.rec_image_shape[2]:
|
187 |
+
resized_w = self.rec_image_shape[2]
|
188 |
+
imgW = self.rec_image_shape[2]
|
189 |
+
resized_image = cv2.resize(img, (resized_w, imgH))
|
190 |
+
resized_image = resized_image.astype('float32')
|
191 |
+
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
192 |
+
resized_image -= 0.5
|
193 |
+
resized_image /= 0.5
|
194 |
+
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
195 |
+
padding_im[:, :, 0:resized_w] = resized_image
|
196 |
+
return padding_im
|
197 |
+
|
198 |
+
def resize_norm_img_vl(self, img, image_shape):
|
199 |
+
|
200 |
+
imgC, imgH, imgW = image_shape
|
201 |
+
img = img[:, :, ::-1] # bgr2rgb
|
202 |
+
resized_image = cv2.resize(
|
203 |
+
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
204 |
+
resized_image = resized_image.astype('float32')
|
205 |
+
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
206 |
+
return resized_image
|
207 |
+
|
208 |
+
def resize_norm_img_srn(self, img, image_shape):
|
209 |
+
imgC, imgH, imgW = image_shape
|
210 |
+
|
211 |
+
img_black = np.zeros((imgH, imgW))
|
212 |
+
im_hei = img.shape[0]
|
213 |
+
im_wid = img.shape[1]
|
214 |
+
|
215 |
+
if im_wid <= im_hei * 1:
|
216 |
+
img_new = cv2.resize(img, (imgH * 1, imgH))
|
217 |
+
elif im_wid <= im_hei * 2:
|
218 |
+
img_new = cv2.resize(img, (imgH * 2, imgH))
|
219 |
+
elif im_wid <= im_hei * 3:
|
220 |
+
img_new = cv2.resize(img, (imgH * 3, imgH))
|
221 |
+
else:
|
222 |
+
img_new = cv2.resize(img, (imgW, imgH))
|
223 |
+
|
224 |
+
img_np = np.asarray(img_new)
|
225 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
226 |
+
img_black[:, 0:img_np.shape[1]] = img_np
|
227 |
+
img_black = img_black[:, :, np.newaxis]
|
228 |
+
|
229 |
+
row, col, c = img_black.shape
|
230 |
+
c = 1
|
231 |
+
|
232 |
+
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
233 |
+
|
234 |
+
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
|
235 |
+
|
236 |
+
imgC, imgH, imgW = image_shape
|
237 |
+
feature_dim = int((imgH / 8) * (imgW / 8))
|
238 |
+
|
239 |
+
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
240 |
+
(feature_dim, 1)).astype('int64')
|
241 |
+
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
242 |
+
(max_text_length, 1)).astype('int64')
|
243 |
+
|
244 |
+
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
245 |
+
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
246 |
+
[-1, 1, max_text_length, max_text_length])
|
247 |
+
gsrm_slf_attn_bias1 = np.tile(
|
248 |
+
gsrm_slf_attn_bias1,
|
249 |
+
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
250 |
+
|
251 |
+
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
252 |
+
[-1, 1, max_text_length, max_text_length])
|
253 |
+
gsrm_slf_attn_bias2 = np.tile(
|
254 |
+
gsrm_slf_attn_bias2,
|
255 |
+
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
256 |
+
|
257 |
+
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
258 |
+
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
259 |
+
|
260 |
+
return [
|
261 |
+
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
262 |
+
gsrm_slf_attn_bias2
|
263 |
+
]
|
264 |
+
|
265 |
+
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
|
266 |
+
norm_img = self.resize_norm_img_srn(img, image_shape)
|
267 |
+
norm_img = norm_img[np.newaxis, :]
|
268 |
+
|
269 |
+
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
270 |
+
self.srn_other_inputs(image_shape, num_heads, max_text_length)
|
271 |
+
|
272 |
+
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
273 |
+
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
274 |
+
encoder_word_pos = encoder_word_pos.astype(np.int64)
|
275 |
+
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
|
276 |
+
|
277 |
+
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
278 |
+
gsrm_slf_attn_bias2)
|
279 |
+
|
280 |
+
    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        """Aspect-preserving resize + right-pad for SAR / RobustScanner.

        Args:
            img: BGR (or grayscale-compatible) image array.
            image_shape: (C, H, W_min, W_max) target specification.
            width_downsample_ratio: the backbone's width downsampling;
                the resized width is rounded to a multiple of its inverse.

        Returns:
            (padded CHW image, resized shape, padded shape, valid_ratio)
            where valid_ratio is the fraction of the padded width holding
            real image content.
        """
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        # Clamp into [imgW_min, imgW_max]; valid_ratio is computed from the
        # pre-clamp width so over-wide crops report < 1.0.
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # norm
        if image_shape[0] == 1:
            # Single-channel model: scale and add the channel axis.
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        # Pad with -1 (the normalized minimum) out to the maximum width.
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio
|
314 |
+
|
315 |
+
def resize_norm_img_spin(self, img):
|
316 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
317 |
+
# return padding_im
|
318 |
+
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
|
319 |
+
img = np.array(img, np.float32)
|
320 |
+
img = np.expand_dims(img, -1)
|
321 |
+
img = img.transpose((2, 0, 1))
|
322 |
+
mean = [127.5]
|
323 |
+
std = [127.5]
|
324 |
+
mean = np.array(mean, dtype=np.float32)
|
325 |
+
std = np.array(std, dtype=np.float32)
|
326 |
+
mean = np.float32(mean.reshape(1, -1))
|
327 |
+
stdinv = 1 / np.float32(std.reshape(1, -1))
|
328 |
+
img -= mean
|
329 |
+
img *= stdinv
|
330 |
+
return img
|
331 |
+
|
332 |
+
def resize_norm_img_svtr(self, img, image_shape):
|
333 |
+
|
334 |
+
imgC, imgH, imgW = image_shape
|
335 |
+
resized_image = cv2.resize(
|
336 |
+
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
337 |
+
resized_image = resized_image.astype('float32')
|
338 |
+
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
339 |
+
resized_image -= 0.5
|
340 |
+
resized_image /= 0.5
|
341 |
+
return resized_image
|
342 |
+
|
343 |
+
def resize_norm_img_abinet(self, img, image_shape):
|
344 |
+
|
345 |
+
imgC, imgH, imgW = image_shape
|
346 |
+
|
347 |
+
resized_image = cv2.resize(
|
348 |
+
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
349 |
+
resized_image = resized_image.astype('float32')
|
350 |
+
resized_image = resized_image / 255.
|
351 |
+
|
352 |
+
mean = np.array([0.485, 0.456, 0.406])
|
353 |
+
std = np.array([0.229, 0.224, 0.225])
|
354 |
+
resized_image = (
|
355 |
+
resized_image - mean[None, None, ...]) / std[None, None, ...]
|
356 |
+
resized_image = resized_image.transpose((2, 0, 1))
|
357 |
+
resized_image = resized_image.astype('float32')
|
358 |
+
|
359 |
+
return resized_image
|
360 |
+
|
361 |
+
def norm_img_can(self, img, image_shape):
|
362 |
+
|
363 |
+
img = cv2.cvtColor(
|
364 |
+
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
|
365 |
+
|
366 |
+
if self.inverse:
|
367 |
+
img = 255 - img
|
368 |
+
|
369 |
+
if self.rec_image_shape[0] == 1:
|
370 |
+
h, w = img.shape
|
371 |
+
_, imgH, imgW = self.rec_image_shape
|
372 |
+
if h < imgH or w < imgW:
|
373 |
+
padding_h = max(imgH - h, 0)
|
374 |
+
padding_w = max(imgW - w, 0)
|
375 |
+
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
|
376 |
+
'constant',
|
377 |
+
constant_values=(255))
|
378 |
+
img = img_padded
|
379 |
+
|
380 |
+
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
|
381 |
+
img = img.astype('float32')
|
382 |
+
|
383 |
+
return img
|
384 |
+
|
385 |
+
    def __call__(self, img_list):
        """Recognize text in a list of cropped images.

        Images are processed in width-sorted batches so per-batch padding is
        minimal; results are written back in the caller's original order.

        Returns:
            (rec_res, elapse): decoder outputs aligned with ``img_list``
            (typically [text, score] pairs) and total wall time in seconds.
        """
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars.
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting by width groups similar shapes and speeds up recognition.
        indices = np.argsort(np.array(width_list))
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        st = time.time()
        if self.benchmark:
            self.autolog.times.start()
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            # Algorithm-specific auxiliary inputs collected per batch.
            if self.rec_algorithm == "SRN":
                encoder_word_pos_list = []
                gsrm_word_pos_list = []
                gsrm_slf_attn_bias1_list = []
                gsrm_slf_attn_bias2_list = []
            if self.rec_algorithm == "SAR":
                valid_ratios = []
            imgC, imgH, imgW = self.rec_image_shape[:3]
            max_wh_ratio = imgW / imgH
            # The widest image in the batch dictates the padded width.
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                if self.rec_algorithm == "SAR":
                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                        img_list[indices[ino]], self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
                    valid_ratios.append(valid_ratio)
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == "SRN":
                    # process_image_srn returns (img, pos ids, attn biases).
                    norm_img = self.process_image_srn(
                        img_list[indices[ino]], self.rec_image_shape, 8, 25)
                    encoder_word_pos_list.append(norm_img[1])
                    gsrm_word_pos_list.append(norm_img[2])
                    gsrm_slf_attn_bias1_list.append(norm_img[3])
                    gsrm_slf_attn_bias2_list.append(norm_img[4])
                    norm_img_batch.append(norm_img[0])
                elif self.rec_algorithm == "SVTR":
                    norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
                                                         self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm in ["VisionLAN", "PREN"]:
                    norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
                                                       self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == 'SPIN':
                    norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == "ABINet":
                    norm_img = self.resize_norm_img_abinet(
                        img_list[indices[ino]], self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                elif self.rec_algorithm == "RobustScanner":
                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                        img_list[indices[ino]],
                        self.rec_image_shape,
                        width_downsample_ratio=0.25)
                    norm_img = norm_img[np.newaxis, :]
                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
                    # NOTE(review): these two lists are re-created on every
                    # iteration, so for rec_batch_num > 1 only the last
                    # image's valid_ratio/word_positions survive — verify
                    # this is intended (SAR above accumulates instead).
                    valid_ratios = []
                    valid_ratios.append(valid_ratio)
                    norm_img_batch.append(norm_img)
                    word_positions_list = []
                    word_positions = np.array(range(0, 40)).astype('int64')
                    word_positions = np.expand_dims(word_positions, axis=0)
                    word_positions_list.append(word_positions)
                elif self.rec_algorithm == "CAN":
                    norm_img = self.norm_img_can(img_list[indices[ino]],
                                                 max_wh_ratio)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                    # CAN also needs an all-ones image mask and word labels.
                    # NOTE(review): like RobustScanner above, these lists are
                    # re-created per image — only one entry survives.
                    norm_image_mask = np.ones(norm_img.shape, dtype='float32')
                    word_label = np.ones([1, 36], dtype='int64')
                    norm_img_mask_batch = []
                    word_label_list = []
                    norm_img_mask_batch.append(norm_image_mask)
                    word_label_list.append(word_label)
                else:
                    # Default path: CTC-style models (CRNN, PP-OCR, ...).
                    norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                    max_wh_ratio)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            if self.benchmark:
                self.autolog.times.stamp()

            if self.rec_algorithm == "SRN":
                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
                gsrm_slf_attn_bias1_list = np.concatenate(
                    gsrm_slf_attn_bias1_list)
                gsrm_slf_attn_bias2_list = np.concatenate(
                    gsrm_slf_attn_bias2_list)

                inputs = [
                    norm_img_batch,
                    encoder_word_pos_list,
                    gsrm_word_pos_list,
                    gsrm_slf_attn_bias1_list,
                    gsrm_slf_attn_bias2_list,
                ]
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = {"predict": outputs[2]}
                else:
                    # Feed each named graph input from the inputs list.
                    input_names = self.predictor.get_input_names()
                    for i in range(len(input_names)):
                        input_tensor = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor.copy_from_cpu(inputs[i])
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = {"predict": outputs[2]}
            elif self.rec_algorithm == "SAR":
                valid_ratios = np.concatenate(valid_ratios)
                inputs = [
                    norm_img_batch,
                    np.array(
                        [valid_ratios], dtype=np.float32),
                ]
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs[0]
                else:
                    input_names = self.predictor.get_input_names()
                    for i in range(len(input_names)):
                        input_tensor = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor.copy_from_cpu(inputs[i])
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = outputs[0]
            elif self.rec_algorithm == "RobustScanner":
                valid_ratios = np.concatenate(valid_ratios)
                word_positions_list = np.concatenate(word_positions_list)
                inputs = [norm_img_batch, valid_ratios, word_positions_list]

                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs[0]
                else:
                    input_names = self.predictor.get_input_names()
                    for i in range(len(input_names)):
                        input_tensor = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor.copy_from_cpu(inputs[i])
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    preds = outputs[0]
            elif self.rec_algorithm == "CAN":
                norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
                word_label_list = np.concatenate(word_label_list)
                inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs
                else:
                    input_names = self.predictor.get_input_names()
                    input_tensor = []
                    for i in range(len(input_names)):
                        input_tensor_i = self.predictor.get_input_handle(
                            input_names[i])
                        input_tensor_i.copy_from_cpu(inputs[i])
                        input_tensor.append(input_tensor_i)
                    self.input_tensor = input_tensor
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    # CAN's decoder consumes the full output list.
                    preds = outputs
            else:
                if self.use_onnx:
                    input_dict = {}
                    input_dict[self.input_tensor.name] = norm_img_batch
                    outputs = self.predictor.run(self.output_tensors,
                                                 input_dict)
                    preds = outputs[0]
                else:
                    self.input_tensor.copy_from_cpu(norm_img_batch)
                    self.predictor.run()
                    outputs = []
                    for output_tensor in self.output_tensors:
                        output = output_tensor.copy_to_cpu()
                        outputs.append(output)
                    if self.benchmark:
                        self.autolog.times.stamp()
                    # Multi-head models return several tensors; single-head
                    # models are unwrapped for the decoder.
                    if len(outputs) != 1:
                        preds = outputs
                    else:
                        preds = outputs[0]
            rec_result = self.postprocess_op(preds)
            # Scatter batch results back to the caller's original order.
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
            if self.benchmark:
                self.autolog.times.end(stamp=True)
        return rec_res, time.time() - st
|
625 |
+
|
626 |
+
|
627 |
+
def main(args):
    """Recognize text in every image under ``args.image_dir`` and log results.

    Loads each image (with GIF handling via check_and_read), runs batched
    recognition via TextRecognizer, and logs one prediction per image.
    Exits the process on any inference exception.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []

    logger.info(
        "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
        # Bug fix: the hint message was missing its closing quote.
        "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320'"
    )
    # warmup 2 times
    if args.warmup:
        img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)
        for i in range(2):
            res = text_recognizer([img] * int(args.rec_batch_num))

    for image_file in image_file_list:
        img, flag, _ = check_and_read(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, _ = text_recognizer(img_list)

    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    for ino in range(len(img_list)):
        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                               rec_res[ino]))
    if args.benchmark:
        text_recognizer.autolog.report()
|
664 |
+
|
665 |
+
|
666 |
+
# Script entry point: parse the shared tools/infer CLI flags and run.
if __name__ == "__main__":
    main(utility.parse_args())
|
tools/infer/predict_sr.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
from PIL import Image
|
17 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
sys.path.insert(0, __dir__)
|
19 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
20 |
+
|
21 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
22 |
+
|
23 |
+
import cv2
|
24 |
+
import numpy as np
|
25 |
+
import math
|
26 |
+
import time
|
27 |
+
import traceback
|
28 |
+
import paddle
|
29 |
+
|
30 |
+
import tools.infer.utility as utility
|
31 |
+
from ppocr.postprocess import build_post_process
|
32 |
+
from ppocr.utils.logging import get_logger
|
33 |
+
from ppocr.utils.utility import get_image_file_list, check_and_read
|
34 |
+
|
35 |
+
logger = get_logger()
|
36 |
+
|
37 |
+
|
38 |
+
class TextSR(object):
    """Text-image super-resolution predictor wrapping a Paddle inference model."""

    def __init__(self, args):
        # Target (super-resolved) image shape as [C, H, W].
        self.sr_image_shape = [int(v) for v in args.sr_image_shape.split(",")]
        self.sr_batch_num = args.sr_batch_num

        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'sr', logger)
        self.benchmark = args.benchmark
        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="sr",
                model_precision=args.precision,
                batch_size=args.sr_batch_num,
                data_shape="dynamic",
                save_path=None,  #args.save_log_path,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=0,
                logger=logger)

    def resize_norm_img(self, img):
        """Resize a PIL image to half the SR output size and normalize.

        Returns a float32 CHW array scaled to [0, 1]; the model upscales the
        half-size input back to ``sr_image_shape``.
        """
        imgC, imgH, imgW = self.sr_image_shape
        img = img.resize((imgW // 2, imgH // 2), Image.BICUBIC)
        img_numpy = np.array(img).astype("float32")
        img_numpy = img_numpy.transpose((2, 0, 1)) / 255
        return img_numpy

    def __call__(self, img_list):
        """Run SR inference over ``img_list`` in batches.

        Returns (all_result, elapsed_seconds) where each entry of
        ``all_result`` is the raw list of output arrays for one batch.
        """
        img_num = len(img_list)
        batch_num = self.sr_batch_num
        # Bug fix: removed a duplicated `st = time.time()` line and replaced
        # `[] * img_num` (list repetition of an empty list is always []).
        st = time.time()
        all_result = []
        if self.benchmark:
            self.autolog.times.start()
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[ino])
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)

            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            if self.benchmark:
                self.autolog.times.stamp()
            self.input_tensor.copy_from_cpu(norm_img_batch)
            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
            # Bug fix: dropped the dead `preds = outputs[0]` branch — the
            # value appended below was always the raw `outputs` list.
            all_result.append(outputs)
        if self.benchmark:
            self.autolog.times.end(stamp=True)
        return all_result, time.time() - st
|
108 |
+
|
109 |
+
|
110 |
+
def main(args):
    """Run SR inference on images from ``args.image_dir`` and save outputs.

    Super-resolved images are written to ``infer_result/sr_<name>``.
    Exits the process on any inference exception.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextSR(args)
    valid_image_file_list = []
    img_list = []

    # warmup 2 times
    if args.warmup:
        # Bug fix: resize_norm_img calls PIL's Image.resize with a PIL
        # resample filter, so wrap the random array in a PIL image instead
        # of passing a raw ndarray (whose .resize has different semantics).
        img = Image.fromarray(
            np.random.uniform(0, 255, [16, 64, 3]).astype(np.uint8))
        for i in range(2):
            res = text_recognizer([img] * int(args.sr_batch_num))

    for image_file in image_file_list:
        img, flag, _ = check_and_read(image_file)
        if not flag:
            img = Image.open(image_file).convert("RGB")
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        preds, _ = text_recognizer(img_list)
        # Bug fix: create the output directory up front — cv2.imwrite
        # fails silently when the target directory does not exist.
        os.makedirs("infer_result", exist_ok=True)
        for beg_no in range(len(preds)):
            # NOTE(review): assumes model output order is [lr, sr] —
            # confirm against the exported SR model.
            sr_img = preds[beg_no][1]
            lr_img = preds[beg_no][0]
            for i in range(sr_img.shape[0]):
                fm_sr = (sr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
                fm_lr = (lr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
                img_name_pure = os.path.split(valid_image_file_list[
                    beg_no * args.sr_batch_num + i])[-1]
                # RGB -> BGR for OpenCV before writing.
                cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
                            fm_sr[:, :, ::-1])
                logger.info("The visualized image saved in infer_result/sr_{}".
                            format(img_name_pure))

    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        exit()
    if args.benchmark:
        text_recognizer.autolog.report()
|
152 |
+
|
153 |
+
|
154 |
+
# Script entry point: parse the shared tools/infer CLI flags and run.
if __name__ == "__main__":
    main(utility.parse_args())
|
tools/infer/predict_system.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import subprocess
|
17 |
+
|
18 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
19 |
+
sys.path.append(__dir__)
|
20 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
21 |
+
|
22 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
23 |
+
|
24 |
+
import cv2
|
25 |
+
import copy
|
26 |
+
import numpy as np
|
27 |
+
import json
|
28 |
+
import time
|
29 |
+
import logging
|
30 |
+
from PIL import Image
|
31 |
+
import tools.infer.utility as utility
|
32 |
+
import tools.infer.predict_rec as predict_rec
|
33 |
+
import tools.infer.predict_det as predict_det
|
34 |
+
import tools.infer.predict_cls as predict_cls
|
35 |
+
from ppocr.utils.utility import get_image_file_list, check_and_read
|
36 |
+
from ppocr.utils.logging import get_logger
|
37 |
+
from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
|
38 |
+
logger = get_logger()
|
39 |
+
|
40 |
+
|
41 |
+
class TextSystem(object):
    """End-to-end OCR pipeline: detection -> optional angle classification
    of each crop -> recognition, with score-based filtering."""

    def __init__(self, args):
        if not args.show_log:
            logger.setLevel(logging.INFO)

        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

        self.args = args
        # Running counter so crop filenames stay unique across calls.
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        """Save each cropped text image to ``output_dir`` (debug aid) and
        log its recognition result."""
        os.makedirs(output_dir, exist_ok=True)
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
            cv2.imwrite(
                os.path.join(output_dir,
                             f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
                img_crop_list[bno])
            logger.debug(f"{bno}, {rec_res[bno]}")
        self.crop_image_res_index += bbox_num

    def __call__(self, img, cls=True):
        """Run the full pipeline on one image.

        Returns (filter_boxes, filter_rec_res, time_dict); boxes/results with
        recognition score below ``self.drop_score`` are dropped.
        """
        # Bug fix: the key was misspelled 'csl' — the classification branch
        # below writes time_dict['cls'], which left a stale 'csl' entry.
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
        start = time.time()
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse
        logger.debug("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))
        if dt_boxes is None:
            # Bug fix: return three values so callers that unpack
            # (boxes, rec_res, time_dict) do not crash when detection
            # yields nothing.
            end = time.time()
            time_dict['all'] = end - start
            return None, None, time_dict
        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
                img_crop_list)
            time_dict['cls'] = elapse
            logger.debug("cls num : {}, elapse : {}".format(
                len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict['rec'] = elapse
        logger.debug("rec_res num : {}, elapse : {}".format(
            len(rec_res), elapse))
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
                                   rec_res)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict['all'] = end - start
        return filter_boxes, filter_rec_res, time_dict
|
111 |
+
|
112 |
+
|
113 |
+
def sorted_boxes(dt_boxes):
    """
    Sort text boxes into reading order: top to bottom, then left to right.
    args:
        dt_boxes(array): detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    count = dt_boxes.shape[0]
    ordered = list(sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0])))

    # Stable local pass: boxes whose top-left y values differ by fewer than
    # 10 pixels are treated as the same text line and re-ordered by x.
    for idx in range(count - 1):
        for k in range(idx, -1, -1):
            on_same_line = abs(ordered[k + 1][0][1] - ordered[k][0][1]) < 10
            if on_same_line and ordered[k + 1][0][0] < ordered[k][0][0]:
                ordered[k], ordered[k + 1] = ordered[k + 1], ordered[k]
            else:
                break
    return ordered
|
135 |
+
|
136 |
+
|
137 |
+
def main(args):
    """Run the full OCR system over ``args.image_dir``.

    Handles plain images, GIFs and multi-page PDFs (via check_and_read),
    saves per-image visualizations to ``args.draw_img_save_dir`` and writes
    all predictions to ``system_results.txt`` in that directory.
    """
    image_file_list = get_image_file_list(args.image_dir)
    # Interleaved slice so multiple worker processes split the work evenly.
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    draw_img_save_dir = args.draw_img_save_dir
    os.makedirs(draw_img_save_dir, exist_ok=True)
    save_results = []

    logger.info(
        "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
        # Bug fix: the hint message was missing its closing quote.
        "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320'"
    )

    # warm up 10 times
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(10):
            res = text_sys(img)

    total_time = 0
    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
    _st = time.time()
    count = 0
    for idx, image_file in enumerate(image_file_list):

        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if not flag_pdf:
            if img is None:
                logger.debug("error in loading image:{}".format(image_file))
                continue
            imgs = [img]
        else:
            # PDFs come back as a list of page images; honor --page_num
            # (0 means "all pages").
            page_num = args.page_num
            if page_num > len(img) or page_num == 0:
                page_num = len(img)
            imgs = img[:page_num]
        for index, img in enumerate(imgs):
            starttime = time.time()
            dt_boxes, rec_res, time_dict = text_sys(img)
            elapse = time.time() - starttime
            total_time += elapse
            if len(imgs) > 1:
                logger.debug(
                    str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
                    % (image_file, elapse))
            else:
                logger.debug(
                    str(idx) + " Predict time of %s: %.3fs" % (image_file,
                                                               elapse))
            for text, score in rec_res:
                logger.debug("{}, {:.3f}".format(text, score))

            res = [{
                "transcription": rec_res[i][0],
                "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
            } for i in range(len(dt_boxes))]
            if len(imgs) > 1:
                save_pred = os.path.basename(image_file) + '_' + str(
                    index) + "\t" + json.dumps(
                        res, ensure_ascii=False) + "\n"
            else:
                save_pred = os.path.basename(image_file) + "\t" + json.dumps(
                    res, ensure_ascii=False) + "\n"
            save_results.append(save_pred)

            if is_visualize:
                image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                boxes = dt_boxes
                txts = [rec_res[i][0] for i in range(len(rec_res))]
                scores = [rec_res[i][1] for i in range(len(rec_res))]

                draw_img = draw_ocr_box_txt(
                    image,
                    boxes,
                    txts,
                    scores,
                    drop_score=drop_score,
                    font_path=font_path)
                if flag_gif:
                    save_file = image_file[:-3] + "png"
                elif flag_pdf:
                    save_file = image_file.replace('.pdf',
                                                   '_' + str(index) + '.png')
                else:
                    save_file = image_file
                # draw_img is RGB; reverse channels for OpenCV (BGR).
                cv2.imwrite(
                    os.path.join(draw_img_save_dir,
                                 os.path.basename(save_file)),
                    draw_img[:, :, ::-1])
                logger.debug("The visualized image saved in {}".format(
                    os.path.join(draw_img_save_dir, os.path.basename(
                        save_file))))

    logger.info("The predict total time is {}".format(time.time() - _st))
    if args.benchmark:
        text_sys.text_detector.autolog.report()
        text_sys.text_recognizer.autolog.report()

    with open(
            os.path.join(draw_img_save_dir, "system_results.txt"),
            'w',
            encoding='utf-8') as f:
        f.writelines(save_results)
|
245 |
+
|
246 |
+
|
247 |
+
# Script entry point. In multi-process mode the parent re-launches this
# same script once per worker (with use_mp disabled in the children to
# avoid recursive spawning); each child processes an interleaved slice of
# the image list selected by --process_id in main().
if __name__ == "__main__":
    args = utility.parse_args()
    if args.use_mp:
        p_list = []
        total_process_num = args.total_process_num
        for process_id in range(total_process_num):
            cmd = [sys.executable, "-u"] + sys.argv + [
                "--process_id={}".format(process_id),
                "--use_mp={}".format(False)
            ]
            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
            p_list.append(p)
        # Wait for all workers so the parent's exit reflects completion.
        for p in p_list:
            p.wait()
    else:
        main(args)
|
tools/infer/utility.py
ADDED
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
import platform
|
19 |
+
import cv2
|
20 |
+
import numpy as np
|
21 |
+
import paddle
|
22 |
+
from PIL import Image, ImageDraw, ImageFont
|
23 |
+
import math
|
24 |
+
from paddle import inference
|
25 |
+
import time
|
26 |
+
import random
|
27 |
+
from ppocr.utils.logging import get_logger
|
28 |
+
|
29 |
+
|
30 |
+
def str2bool(v):
    """Interpret a command-line string as a boolean flag (case-insensitive).

    Only "true", "t" and "1" (any casing) are truthy; everything else is False.
    """
    truthy = {"true", "t", "1"}
    return v.lower() in truthy
|
32 |
+
|
33 |
+
|
34 |
+
def init_args():
    """Build the argparse parser shared by all tools/infer prediction scripts.

    Groups flags by pipeline stage: engine/device, detection (per-algorithm
    thresholds), recognition, end-to-end, classification, super-resolution,
    output/visualization and multi-process options.
    """
    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--use_xpu", type=str2bool, default=False)
    parser.add_argument("--use_npu", type=str2bool, default=False)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=15)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--gpu_id", type=int, default=0)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--page_num", type=int, default=0)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')
    parser.add_argument("--det_box_type", type=str, default='quad')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)

    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE params
    parser.add_argument("--scales", type=list, default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    # CPU / serving / warmup options
    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--cpu_threads", type=int, default=10)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # SR params
    parser.add_argument("--sr_model_dir", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    # output / visualization
    parser.add_argument(
        "--draw_img_save_dir", type=str, default="./inference_results")
    parser.add_argument("--save_crop_res", type=str2bool, default=False)
    parser.add_argument("--crop_res_save_dir", type=str, default="./output")

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")

    parser.add_argument("--show_log", type=str2bool, default=True)
    parser.add_argument("--use_onnx", type=str2bool, default=False)
    return parser
|
149 |
+
|
150 |
+
|
151 |
+
def parse_args():
    """Parse command-line arguments using the shared inference parser."""
    args_parser = init_args()
    return args_parser.parse_args()
|
154 |
+
|
155 |
+
|
156 |
+
def create_predictor(args, mode, logger):
    """Build an inference session for one pipeline stage.

    Args:
        args: parsed CLI namespace (see init_args) holding model dirs and
            device/engine switches.
        mode: stage name ('det', 'cls', 'rec', 'table', 'ser', 're', 'sr',
            'layout'; anything else selects the e2e model dir).
        logger: logger used for warnings and TensorRT shape-collection info.

    Returns:
        (predictor, input_tensor(s), output_tensors, config). With
        --use_onnx the tuple is (ort session, first input meta, None, None).

    Exits the process when the selected model dir is unset; raises
    ValueError when model/param files are missing.
    """
    # Map the stage name to its model directory argument.
    if mode == "det":
        model_dir = args.det_model_dir
    elif mode == 'cls':
        model_dir = args.cls_model_dir
    elif mode == 'rec':
        model_dir = args.rec_model_dir
    elif mode == 'table':
        model_dir = args.table_model_dir
    elif mode == 'ser':
        model_dir = args.ser_model_dir
    elif mode == 're':
        model_dir = args.re_model_dir
    elif mode == "sr":
        model_dir = args.sr_model_dir
    elif mode == 'layout':
        model_dir = args.layout_model_dir
    else:
        # Fallback: end-to-end (PGNet-style) model.
        model_dir = args.e2e_model_dir

    if model_dir is None:
        logger.info("not find {} model file path {}".format(mode, model_dir))
        sys.exit(0)
    if args.use_onnx:
        import onnxruntime as ort
        # For ONNX, model_dir is the .onnx file path itself.
        model_file_path = model_dir
        if not os.path.exists(model_file_path):
            raise ValueError("not find model file path {}".format(
                model_file_path))
        sess = ort.InferenceSession(model_file_path)
        return sess, sess.get_inputs()[0], None, None

    else:
        # Accept either naming convention for exported Paddle models.
        file_names = ['model', 'inference']
        for file_name in file_names:
            model_file_path = '{}/{}.pdmodel'.format(model_dir, file_name)
            params_file_path = '{}/{}.pdiparams'.format(model_dir, file_name)
            if os.path.exists(model_file_path) and os.path.exists(
                    params_file_path):
                break
        if not os.path.exists(model_file_path):
            raise ValueError(
                "not find model.pdmodel or inference.pdmodel in {}".format(
                    model_dir))
        if not os.path.exists(params_file_path):
            raise ValueError(
                "not find model.pdiparams or inference.pdiparams in {}".format(
                    model_dir))

        config = inference.Config(model_file_path, params_file_path)

        # Resolve numeric precision; fp16 only takes effect with TensorRT.
        if hasattr(args, 'precision'):
            if args.precision == "fp16" and args.use_tensorrt:
                precision = inference.PrecisionType.Half
            elif args.precision == "int8":
                precision = inference.PrecisionType.Int8
            else:
                precision = inference.PrecisionType.Float32
        else:
            precision = inference.PrecisionType.Float32

        if args.use_gpu:
            gpu_id = get_infer_gpuid()
            if gpu_id is None:
                logger.warning(
                    "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
                )
            config.enable_use_gpu(args.gpu_mem, args.gpu_id)
            if args.use_tensorrt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 30,
                    precision_mode=precision,
                    max_batch_size=args.max_batch_size,
                    min_subgraph_size=args.
                    min_subgraph_size,  # skip the minmum trt subgraph
                    use_calib_mode=False)

                # collect shape: dynamic-shape ranges are recorded to a file
                # on the first run, then reused to tune the TRT engine.
                trt_shape_f = os.path.join(model_dir,
                                           f"{mode}_trt_dynamic_shape.txt")

                if not os.path.exists(trt_shape_f):
                    config.collect_shape_range_info(trt_shape_f)
                    logger.info(
                        f"collect dynamic shape info into : {trt_shape_f}")
                try:
                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f,
                                                               True)
                except Exception as E:
                    # API only available in newer Paddle builds.
                    logger.info(E)
                    logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")

        elif args.use_npu:
            config.enable_custom_device("npu")
        elif args.use_xpu:
            config.enable_xpu(10 * 1024 * 1024)
        else:
            config.disable_gpu()
            if args.enable_mkldnn:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if args.precision == "fp16":
                    config.enable_mkldnn_bfloat16()
            if hasattr(args, "cpu_threads"):
                config.set_cpu_math_library_num_threads(args.cpu_threads)
            else:
                # default cpu threads as 10
                config.set_cpu_math_library_num_threads(10)
        # enable memory optim
        config.enable_memory_optim()
        config.disable_glog_info()
        # These fusion passes are known to misbehave for OCR models.
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
        config.delete_pass("matmul_transpose_reshape_fuse_pass")
        if mode == 're':
            config.delete_pass("simplify_with_basic_ops_pass")
        if mode == 'table':
            config.delete_pass("fc_fuse_pass")  # not supported for table
        config.switch_use_feed_fetch_ops(False)
        config.switch_ir_optim(True)

        # create predictor
        predictor = inference.create_predictor(config)
        input_names = predictor.get_input_names()
        if mode in ['ser', 're']:
            # Multi-input models: return one handle per input.
            input_tensor = []
            for name in input_names:
                input_tensor.append(predictor.get_input_handle(name))
        else:
            # Single-input models: keep the last (only) handle.
            for name in input_names:
                input_tensor = predictor.get_input_handle(name)
        output_tensors = get_output_tensors(args, mode, predictor)
        return predictor, input_tensor, output_tensors, config
|
289 |
+
|
290 |
+
|
291 |
+
def get_output_tensors(args, mode, predictor):
    """Return the predictor's output handles for the given stage.

    For CTC-based recognition models ('CRNN', 'SVTR_LCNet') only the
    softmax output is needed when it exists, so just that one handle is
    returned; otherwise all output handles are returned in declared order.

    Args:
        args: CLI namespace; only `rec_algorithm` is read.
        mode: stage name, e.g. 'rec', 'det'.
        predictor: Paddle inference predictor (or any object exposing
            get_output_names()/get_output_handle(name)).

    Returns:
        list of output tensor handles.
    """
    output_names = predictor.get_output_names()
    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
        output_name = 'softmax_0.tmp_0'
        if output_name in output_names:
            # The CTC head's softmax is the only output consumed downstream.
            return [predictor.get_output_handle(output_name)]
    # Default: every declared output (the original duplicated this loop in
    # two branches; a single comprehension covers both).
    return [predictor.get_output_handle(name) for name in output_names]
|
307 |
+
|
308 |
+
|
309 |
+
def get_infer_gpuid():
    """Best-effort detection of the GPU id to run inference on.

    Reads CUDA_VISIBLE_DEVICES (or HIP_VISIBLE_DEVICES on ROCm builds)
    via a shell pipeline; returns 0 on Windows or when the variable is
    unset.

    NOTE(review): `int(gpu_id[0])` keeps only the first character, so
    "10" yields 1 and "2,3" yields 2 — confirm whether multi-digit /
    multi-device values are expected here.
    NOTE(review): `paddle.fluid.core` is the legacy namespace; newer
    Paddle exposes this under `paddle.device` — verify the pinned version.
    """
    sysstr = platform.system()
    if sysstr == "Windows":
        # No POSIX shell pipeline available; assume device 0.
        return 0

    # Pick the env var matching the compiled backend (CUDA vs ROCm).
    if not paddle.fluid.core.is_compiled_with_rocm():
        cmd = "env | grep CUDA_VISIBLE_DEVICES"
    else:
        cmd = "env | grep HIP_VISIBLE_DEVICES"
    env_cuda = os.popen(cmd).readlines()
    if len(env_cuda) == 0:
        # Variable not set: default to the first device.
        return 0
    else:
        gpu_id = env_cuda[0].strip().split("=")[1]
        return int(gpu_id[0])
|
324 |
+
|
325 |
+
|
326 |
+
def draw_e2e_res(dt_boxes, strs, img_path):
    """Visualize end-to-end OCR results on the image at `img_path`.

    Draws each polygon in cyan and writes its recognized text in green at
    the polygon's first vertex.

    Args:
        dt_boxes: iterable of polygon arrays (each reshapable to (-1, 1, 2)).
        strs: recognized text per polygon, aligned with dt_boxes.
        img_path: path of the source image to draw on.

    Returns:
        BGR image array with the overlays.
    """
    src_im = cv2.imread(img_path)
    # Renamed loop variable `str` -> `text`: shadowing the builtin made
    # str() unusable inside the loop.
    for box, text in zip(dt_boxes, strs):
        box = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        cv2.putText(
            src_im,
            text,
            org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)
    return src_im
|
340 |
+
|
341 |
+
|
342 |
+
def draw_text_det_res(dt_boxes, img):
    """Draw every detected text polygon on `img` (in place) and return it."""
    for polygon in dt_boxes:
        pts = np.array(polygon).astype(np.int32).reshape(-1, 2)
        cv2.polylines(img, [pts], True, color=(255, 255, 0), thickness=2)
    return img
|
347 |
+
|
348 |
+
|
349 |
+
def resize_img(img, input_size=600):
    """Scale `img` so its longest side equals `input_size`, keeping aspect.

    Accepts anything np.array() can convert; returns the resized array.
    """
    arr = np.array(img)
    longest = float(max(arr.shape[:2]))
    ratio = float(input_size) / longest
    return cv2.resize(arr, None, None, fx=ratio, fy=ratio)
|
359 |
+
|
360 |
+
|
361 |
+
def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/fonts/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): txts corresponding scores
        drop_score(float): only scores greater than drop_threshold will be visualized
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img; when txts is given, the image is resized to
        600px and a rendered text panel is concatenated on the right.
    """
    if scores is None:
        # No scores supplied: treat every box as fully confident.
        scores = [1] * len(boxes)
    box_num = len(boxes)
    for i in range(box_num):
        # Skip low-confidence and NaN-scored boxes.
        if scores is not None and (scores[i] < drop_score or
                                   math.isnan(scores[i])):
            continue
        box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    if txts is not None:
        # Side panel listing the recognized texts with their scores.
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path)
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image
|
400 |
+
|
401 |
+
|
402 |
+
def draw_ocr_box_txt(image,
                     boxes,
                     txts=None,
                     scores=None,
                     drop_score=0.5,
                     font_path="./doc/fonts/simfang.ttf"):
    """Render a side-by-side visualization: boxes on the left, text on the
    right.

    Args:
        image: PIL Image (RGB) — `.height`/`.width` are read.
        boxes: list of 4-point quadrilaterals.
        txts: recognized text per box; ignored if length mismatches boxes.
        scores: optional confidence per box; boxes below drop_score skipped.
        drop_score: visibility threshold.
        font_path: TrueType font used by draw_box_txt_fine.

    Returns:
        numpy array of width 2*w: blended colored boxes | warped text.
    """
    h, w = image.height, image.width
    img_left = image.copy()
    # Right canvas starts all-white; text patches are AND-ed in below.
    img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
    # Fixed seed => deterministic (reproducible) box colors.
    random.seed(0)

    draw_left = ImageDraw.Draw(img_left)
    if txts is None or len(txts) != len(boxes):
        txts = [None] * len(boxes)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255))
        draw_left.polygon(box, fill=color)
        # Text rendered into the box's quad on a white background.
        img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
        pts = np.array(box, np.int32).reshape((-1, 1, 2))
        cv2.polylines(img_right_text, [pts], True, color, 1)
        # bitwise_and keeps dark text/outlines from every patch.
        img_right = cv2.bitwise_and(img_right, img_right_text)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
    return np.array(img_show)
|
431 |
+
|
432 |
+
|
433 |
+
def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
    """Render `txt` into quadrilateral `box` on a white `img_size` canvas.

    The text is drawn horizontally on an axis-aligned patch sized from the
    box's edge lengths, rotated 270° for tall narrow (vertical) boxes, then
    perspective-warped onto the box. Returns the full-size array, white
    everywhere outside the box.
    """
    # Edge lengths: height = corner 0 -> 3, width = corner 0 -> 1.
    box_height = int(
        math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2))
    box_width = int(
        math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2))

    if box_height > 2 * box_width and box_height > 30:
        # Vertical text: draw on a swapped-dimension patch, then rotate it
        # so the warp below maps it upright into the tall box.
        img_text = Image.new('RGB', (box_height, box_width), (255, 255, 255))
        draw_text = ImageDraw.Draw(img_text)
        if txt:
            font = create_font(txt, (box_height, box_width), font_path)
            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
        img_text = img_text.transpose(Image.ROTATE_270)
    else:
        img_text = Image.new('RGB', (box_width, box_height), (255, 255, 255))
        draw_text = ImageDraw.Draw(img_text)
        if txt:
            font = create_font(txt, (box_width, box_height), font_path)
            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)

    # Map the patch's corners onto the (possibly rotated) target quad.
    pts1 = np.float32(
        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
    pts2 = np.array(box, dtype=np.float32)
    M = cv2.getPerspectiveTransform(pts1, pts2)

    img_text = np.array(img_text, dtype=np.uint8)
    img_right_text = cv2.warpPerspective(
        img_text,
        M,
        img_size,
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255))
    return img_right_text
|
467 |
+
|
468 |
+
|
469 |
+
def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
    """Return a TrueType font sized so `txt` fits a (width, height)=`sz` box.

    Starts at ~99% of the box height and shrinks proportionally when the
    rendered text would overflow the box width.
    """
    font_size = int(sz[1] * 0.99)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    # NOTE(review): FreeTypeFont.getsize is deprecated and removed in
    # Pillow >= 10 (getbbox/getlength replace it) — confirm the Pillow pin.
    length = font.getsize(txt)[0]
    if length > sz[0]:
        # Scale the font down so the text width matches the box width.
        font_size = int(font_size * sz[0] / length)
        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    return font
|
477 |
+
|
478 |
+
|
479 |
+
def str_count(s):
    """
    Count the number of Chinese characters,
    a single English character and a single number
    equal to half the length of Chinese characters.
    args:
        s(string): the input of string
    return(int):
        the number of Chinese characters
    """
    import string
    # Half-width characters: ASCII letters, digits and whitespace.
    # (The original also tallied other-alpha and punctuation counts into
    # locals that were never read; those dead counters are removed.)
    en_dg_count = sum(1 for c in s
                      if c in string.ascii_letters or c.isdigit() or
                      c.isspace())
    # Each half-width char counts as half a full-width one, rounded up.
    return len(s) - math.ceil(en_dg_count / 2)
|
501 |
+
|
502 |
+
|
503 |
+
def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    create new blank img and draw txt on it
    args:
        texts(list): the text will be draw
        scores(list|None): corresponding score of each txt
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        threshold(float): texts scoring below this (or NaN) are skipped
        font_path: the path of font which is used to draw text
    return(array):
        the rendered text panel; multiple panels are concatenated
        horizontally when the text overflows one panel.
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        # White canvas with a 1px black strip on the right edge as a
        # separator between concatenated panels.
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5  # vertical line pitch in pixels
    txt_img_list = []
    # count: current line slot on the panel; index: 1-based label of the
    # next visible (non-skipped) text.
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        index += 1
        if scores[idx] < threshold or math.isnan(scores[idx]):
            # Skipped texts do not consume a label number.
            index -= 1
            continue
        first_line = True
        # Wrap long texts: emit full-width slices until the rest fits.
        while str_count(txt) >= img_w // font_size - 4:
            tmp = txt
            txt = tmp[:img_w // font_size - 4]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = '    ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[img_w // font_size - 4:]
            if count >= img_h // gap - 1:
                # Panel full: archive it and start a fresh one.
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        # Last (or only) line of this text carries its score.
        if first_line:
            new_txt = str(index) + ': ' + txt + '   ' + '%.3f' % (scores[idx])
        else:
            new_txt = "  " + txt + "  " + '%.3f' % (scores[idx])
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= img_h // gap - 1 and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)
|
577 |
+
|
578 |
+
|
579 |
+
def base64_to_cv2(b64str):
    """Decode a base64-encoded image string into a BGR OpenCV array."""
    import base64
    raw = base64.b64decode(b64str.encode('utf8'))
    buf = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
585 |
+
|
586 |
+
|
587 |
+
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    """Overlay red detection polygons whose score reaches `drop_score`.

    Missing scores default to 1 (always drawn). Returns the drawn array.
    """
    if scores is None:
        scores = [1] * len(boxes)
    for box, score in zip(boxes, scores):
        if score < drop_score:
            continue
        pts = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [pts], True, (255, 0, 0), 2)
    return image
|
596 |
+
|
597 |
+
|
598 |
+
def get_rotate_crop_image(img, points):
    """Perspective-crop a quadrilateral text region into an upright patch.

    `points` must be 4 corners ordered top-left, top-right, bottom-right,
    bottom-left (width is taken from edges 0-1/2-3, height from 0-3/1-2,
    and corner 0 maps to the patch origin). Patches taller than 1.5x their
    width are rotated 90° so the text reads horizontally.

    (Replaces an older axis-aligned bounding-box crop that was kept here
    as dead code inside the docstring.)
    """
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    # Destination rectangle for the perspective transform.
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        # Likely vertical text: rotate to horizontal for recognition.
        dst_img = np.rot90(dst_img)
    return dst_img
|
631 |
+
|
632 |
+
|
633 |
+
def get_minarea_rect_crop(img, points):
    """Crop the minimum-area rotated rectangle enclosing `points`.

    Computes the min-area rect of the polygon, reorders its corners into
    top-left, top-right, bottom-right, bottom-left, and delegates the
    perspective crop to get_rotate_crop_image.
    """
    bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
    # boxPoints gives 4 corners in arbitrary rotation; sort by x so the
    # first two are the left pair and the last two the right pair.
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

    index_a, index_b, index_c, index_d = 0, 1, 2, 3
    # Left pair: smaller y is the top-left (a), larger y bottom-left (d).
    if points[1][1] > points[0][1]:
        index_a = 0
        index_d = 1
    else:
        index_a = 1
        index_d = 0
    # Right pair: smaller y is the top-right (b), larger y bottom-right (c).
    if points[3][1] > points[2][1]:
        index_b = 2
        index_c = 3
    else:
        index_b = 3
        index_c = 2

    box = [points[index_a], points[index_b], points[index_c], points[index_d]]
    crop_img = get_rotate_crop_image(img, np.array(box))
    return crop_img
|
654 |
+
|
655 |
+
|
656 |
+
def check_gpu(use_gpu):
    """Return the GPU flag, forced to False when paddle lacks CUDA support."""
    if not use_gpu:
        return use_gpu
    return use_gpu if paddle.is_compiled_with_cuda() else False
|
660 |
+
|
661 |
+
|
662 |
+
if __name__ == '__main__':
    # Utility module: intended to be imported; direct execution is a no-op.
    pass
|
tools/infer_cls.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
|
24 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
25 |
+
sys.path.append(__dir__)
|
26 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
27 |
+
|
28 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
29 |
+
|
30 |
+
import paddle
|
31 |
+
|
32 |
+
from ppocr.data import create_operators, transform
|
33 |
+
from ppocr.modeling.architectures import build_model
|
34 |
+
from ppocr.postprocess import build_post_process
|
35 |
+
from ppocr.utils.save_load import load_model
|
36 |
+
from ppocr.utils.utility import get_image_file_list
|
37 |
+
import tools.program as program
|
38 |
+
|
39 |
+
|
40 |
+
def main():
    """Run text-direction classification inference over Global.infer_img.

    Uses the module-level `config` and `logger` set up by
    tools.program.preprocess() in the __main__ block. For each image:
    read bytes -> Eval transforms (labels stripped) -> model -> post
    process -> log the predicted class/score.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops: reuse Eval transforms, minus label handling, and
    # keep only the image tensor.
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        elif op_name == "SSLRotateResize":
            # Rotation-SSL resize must run in test mode at inference time.
            op[op_name]["mode"] = "test"
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            img = f.read()
            data = {'image': img}
        batch = transform(data, ops)

        # Add the batch dimension expected by the model.
        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)
        preds = model(images)
        post_result = post_process_class(preds)
        for rec_result in post_result:
            logger.info('\t result: {}'.format(rec_result))
    logger.info("success!")
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == '__main__':
    # program.preprocess() parses the YAML config plus CLI overrides and
    # prepares the device, logger and optional VisualDL writer used above.
    config, device, logger, vdl_writer = program.preprocess()
    main()
|
tools/infer_det.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
|
24 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
25 |
+
sys.path.append(__dir__)
|
26 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
27 |
+
|
28 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
29 |
+
|
30 |
+
import cv2
|
31 |
+
import json
|
32 |
+
import paddle
|
33 |
+
|
34 |
+
from ppocr.data import create_operators, transform
|
35 |
+
from ppocr.modeling.architectures import build_model
|
36 |
+
from ppocr.postprocess import build_post_process
|
37 |
+
from ppocr.utils.save_load import load_model
|
38 |
+
from ppocr.utils.utility import get_image_file_list
|
39 |
+
import tools.program as program
|
40 |
+
|
41 |
+
|
42 |
+
def draw_det_res(dt_boxes, config, img, img_name, save_path):
    """Draw detection polygons on `img` and save it under `save_path`.

    Args:
        dt_boxes: iterable of polygons, each array-like of (x, y) points.
        config: global config dict (unused here; kept for call
            compatibility).
        img: BGR image array to draw on; modified in place.
        img_name: source image path — its basename names the output file.
        save_path: output directory, created if missing.

    No-op (nothing drawn or saved) when `dt_boxes` is empty.
    """
    if len(dt_boxes) > 0:
        # cv2 is imported at module level; the old function-local
        # `import cv2` was redundant and has been removed.
        src_im = img
        for box in dt_boxes:
            box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        save_path = os.path.join(save_path, os.path.basename(img_name))
        cv2.imwrite(save_path, src_im)
        logger.info("The detected Image saved in {}".format(save_path))
|
54 |
+
|
55 |
+
|
56 |
+
@paddle.no_grad()
def main():
    """Run text detection inference over Global.infer_img.

    Uses the module-level `config` and `logger` from program.preprocess().
    For each image: read bytes -> Eval transforms (labels stripped, image
    + shape kept) -> model -> post process -> draw boxes and append a
    `path\\tjson` line to Global.save_res_path.
    """
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])

    load_model(config, model)
    # build post process
    post_process_class = build_post_process(config['PostProcess'])

    # create data ops: reuse Eval transforms minus label handling; keep
    # the image plus its shape info (needed to rescale boxes).
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)

            # Add the batch dimension expected by the model.
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)

            src_img = cv2.imread(file)

            dt_boxes_json = []
            # parser boxes if post_result is dict (multi-head detectors
            # return one result set per key).
            if isinstance(post_result, dict):
                det_box_json = {}
                for k in post_result.keys():
                    boxes = post_result[k][0]['points']
                    dt_boxes_list = []
                    for box in boxes:
                        tmp_json = {"transcription": ""}
                        tmp_json['points'] = np.array(box).tolist()
                        dt_boxes_list.append(tmp_json)
                    det_box_json[k] = dt_boxes_list
                    save_det_path = os.path.dirname(config['Global'][
                        'save_res_path']) + "/det_results_{}/".format(k)
                    draw_det_res(boxes, config, src_img, file, save_det_path)
                # NOTE(review): in this branch dt_boxes_json stays empty, so
                # the result file records no boxes (det_box_json is only
                # used for drawing) — confirm whether that is intended.
            else:
                boxes = post_result[0]['points']
                dt_boxes_json = []
                # write result
                for box in boxes:
                    tmp_json = {"transcription": ""}
                    tmp_json['points'] = np.array(box).tolist()
                    dt_boxes_json.append(tmp_json)
                save_det_path = os.path.dirname(config['Global'][
                    'save_res_path']) + "/det_results/"
                draw_det_res(boxes, config, src_img, file, save_det_path)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())

    logger.info("success!")
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == '__main__':
    # program.preprocess() parses the YAML config plus CLI overrides and
    # prepares the device, logger and optional VisualDL writer used above.
    config, device, logger, vdl_writer = program.preprocess()
    main()
|
tools/infer_e2e.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
|
24 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
25 |
+
sys.path.append(__dir__)
|
26 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
27 |
+
|
28 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
29 |
+
|
30 |
+
import cv2
|
31 |
+
import json
|
32 |
+
import paddle
|
33 |
+
|
34 |
+
from ppocr.data import create_operators, transform
|
35 |
+
from ppocr.modeling.architectures import build_model
|
36 |
+
from ppocr.postprocess import build_post_process
|
37 |
+
from ppocr.utils.save_load import load_model
|
38 |
+
from ppocr.utils.utility import get_image_file_list
|
39 |
+
import tools.program as program
|
40 |
+
from PIL import Image, ImageDraw, ImageFont
|
41 |
+
import math
|
42 |
+
|
43 |
+
|
44 |
+
def draw_e2e_res_for_chinese(image,
                             boxes,
                             txts,
                             config,
                             img_name,
                             font_path="./doc/simfang.ttf"):
    """Save a two-panel e2e visualization that supports CJK text.

    Left panel: source image blended with colored filled polygons. Right
    panel: polygon outlines with the recognized text rendered via a
    TrueType font (cv2.putText cannot draw Chinese). The combined image is
    written to <dirname(Global.save_res_path)>/e2e_results/<basename>.
    """
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = Image.new('RGB', (w, h), (255, 255, 255))

    import random

    # Fixed seed => deterministic per-box colors across runs.
    random.seed(0)
    draw_left = ImageDraw.Draw(img_left)
    draw_right = ImageDraw.Draw(img_right)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        box = np.array(box)
        # PIL polygon wants a flat list of (x, y) tuples.
        box = [tuple(x) for x in box]
        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255))
        draw_left.polygon(box, fill=color)
        draw_right.polygon(box, outline=color)
        font = ImageFont.truetype(font_path, 15, encoding="utf-8")
        draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(img_right, (w, 0, w * 2, h))

    save_e2e_path = os.path.dirname(config['Global'][
        'save_res_path']) + "/e2e_results/"
    if not os.path.exists(save_e2e_path):
        os.makedirs(save_e2e_path)
    save_path = os.path.join(save_e2e_path, os.path.basename(img_name))
    # PIL is RGB; reverse channels for cv2's BGR writer.
    cv2.imwrite(save_path, np.array(img_show)[:, :, ::-1])
    logger.info("The e2e Image saved in {}".format(save_path))
|
80 |
+
|
81 |
+
|
82 |
+
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
    """Visualize end-to-end OCR results for Latin-script text.

    Draws every detected polygon and its transcription directly onto ``img``
    with OpenCV, then writes the annotated image to
    ``<dirname(save_res_path)>/e2e_results/<basename(img_name)>``.
    Does nothing (no file written) when ``dt_boxes`` is empty.

    Args:
        dt_boxes: detected polygons; each is a numpy array of (x, y) points.
        strs: recognized transcriptions, aligned with ``dt_boxes``.
        config: full config dict; ``config['Global']['save_res_path']``
            determines the output directory.
        img: BGR image (numpy array) to draw on; modified in place.
        img_name: source image path; its basename names the output file.
    """
    if len(dt_boxes) > 0:
        src_im = img
        # Use `txt`, not `str`, so the builtin type is not shadowed.
        for box, txt in zip(dt_boxes, strs):
            box = box.astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
            cv2.putText(
                src_im,
                txt,
                org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
                fontFace=cv2.FONT_HERSHEY_COMPLEX,
                fontScale=0.7,
                color=(0, 255, 0),
                thickness=1)
        save_det_path = os.path.dirname(config['Global'][
            'save_res_path']) + "/e2e_results/"
        if not os.path.exists(save_det_path):
            os.makedirs(save_det_path)
        save_path = os.path.join(save_det_path, os.path.basename(img_name))
        cv2.imwrite(save_path, src_im)
        logger.info("The e2e Image saved in {}".format(save_path))
|
103 |
+
|
104 |
+
|
105 |
+
def main():
    """Run end-to-end (detection + recognition) inference on all images in
    ``Global.infer_img``, dump one JSON result line per image to
    ``Global.save_res_path``, and save a visualization per image."""
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])

    load_model(config, model)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # create data ops: reuse the Eval transforms, but drop label decoding
    # (inference has no ground truth) and keep only image + shape.
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()
    # Results file is opened in binary mode; each line is encoded manually.
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            # Add a leading batch dimension of 1 for single-image inference.
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)
            points, strs = post_result['points'], post_result['texts']
            # write result: one "<path>\t<json>" line per image
            dt_boxes_json = []
            for poly, str in zip(points, strs):
                tmp_json = {"transcription": str}
                tmp_json['points'] = poly.tolist()
                dt_boxes_json.append(tmp_json)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())
            src_img = cv2.imread(file)
            # 'EN' draws with OpenCV fonts; 'CN' needs a TTF font via PIL.
            if global_config['infer_visual_type'] == 'EN':
                draw_e2e_res(points, strs, config, src_img, file)
            elif global_config['infer_visual_type'] == 'CN':
                src_img = Image.fromarray(
                    cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
                draw_e2e_res_for_chinese(
                    src_img,
                    points,
                    strs,
                    config,
                    file,
                    font_path="./doc/fonts/simfang.ttf")

    logger.info("success!")
|
170 |
+
|
171 |
+
|
172 |
+
if __name__ == '__main__':
    # program.preprocess() parses CLI args and supplies the module-level
    # `config` and `logger` used by main() and the draw helpers above.
    config, device, logger, vdl_writer = program.preprocess()
    main()
|
tools/infer_kie.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import paddle.nn.functional as F
|
21 |
+
|
22 |
+
import os
|
23 |
+
import sys
|
24 |
+
|
25 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
26 |
+
sys.path.append(__dir__)
|
27 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
28 |
+
|
29 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
30 |
+
|
31 |
+
import cv2
|
32 |
+
import paddle
|
33 |
+
|
34 |
+
from ppocr.data import create_operators, transform
|
35 |
+
from ppocr.modeling.architectures import build_model
|
36 |
+
from ppocr.utils.save_load import load_model
|
37 |
+
import tools.program as program
|
38 |
+
import time
|
39 |
+
|
40 |
+
|
41 |
+
def read_class_list(filepath):
    """Load class names from a text file, one class per line.

    Args:
        filepath: path to the class-list file.

    Returns:
        dict mapping the zero-based line index to the class name
        (trailing newline removed).
    """
    with open(filepath, "r") as fin:
        return {pos: text.strip("\n")
                for pos, text in enumerate(fin.readlines())}
|
48 |
+
|
49 |
+
|
50 |
+
def draw_kie_result(batch, node, idx_to_cls, count):
    """Visualize KIE node classification for one sample.

    Builds a 3-panels-wide image (input with boxes | predicted labels panel)
    and saves it to ``<dirname(save_res_path)>/kie_results/<count>.png``.

    Args:
        batch: transformed sample; batch[6] is the image array and batch[7]
            the boxes — assumed each box is [x1, y1, x2, y2]; TODO confirm
            against the transform pipeline.
        node: per-box class logits/probabilities tensor.
        idx_to_cls: dict mapping class index -> class name.
        count: sample index, used as the output file name.
    """
    img = batch[6].copy()
    boxes = batch[7]
    h, w = img.shape[:2]
    pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
    # Per-box predicted class = argmax over the last axis; score = its prob.
    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
    node_pred_label = max_idx.numpy().tolist()
    node_pred_score = max_value.numpy().tolist()

    for i, box in enumerate(boxes):
        # There may be more boxes than predictions; ignore the excess.
        if i >= len(node_pred_label):
            break
        # Expand the axis-aligned box into its four corner points.
        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
                   [box[0], box[3]]]
        Pts = np.array([new_box], np.int32)
        cv2.polylines(
            img, [Pts.reshape((-1, 1, 2))],
            True,
            color=(255, 255, 0),
            thickness=1)
        x_min = int(min([point[0] for point in new_box]))
        y_min = int(min([point[1] for point in new_box]))

        pred_label = node_pred_label[i]
        # Fall back to the raw index when the class map lacks the entry.
        if pred_label in idx_to_cls:
            pred_label = idx_to_cls[pred_label]
        pred_score = '{:.2f}'.format(node_pred_score[i])
        text = pred_label + '(' + pred_score + ')'
        # pred_img is twice as wide as img, hence the x_min * 2 placement.
        cv2.putText(pred_img, text, (x_min * 2, y_min),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
    # Final canvas: [input image | 2x-wide label panel].
    vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
    vis_img[:, :w] = img
    vis_img[:, w:] = pred_img
    save_kie_path = os.path.dirname(config['Global'][
        'save_res_path']) + "/kie_results/"
    if not os.path.exists(save_kie_path):
        os.makedirs(save_kie_path)
    save_path = os.path.join(save_kie_path, str(count) + ".png")
    cv2.imwrite(save_path, vis_img)
    logger.info("The Kie Image saved in {}".format(save_path))
|
90 |
+
|
91 |
+
def write_kie_result(fout, node, data):
    """
    Write infer result to output file, sorted by the predict label of each line.
    The format keeps the same as the input with additional score attribute.
    """
    import json
    label = data['label']
    # The input label field is a JSON list of per-line annotations.
    annotations = json.loads(label)
    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
    node_pred_label = max_idx.numpy().tolist()
    node_pred_score = max_value.numpy().tolist()
    res = []
    # NOTE(review): the loop variable rebinds `label` (the raw JSON string
    # read above); harmless here since the string is already parsed.
    for i, label in enumerate(node_pred_label):
        pred_score = '{:.2f}'.format(node_pred_score[i])
        pred_res = {
            'label': label,
            'transcription': annotations[i]['transcription'],
            'score': pred_score,
            'points': annotations[i]['points'],
        }
        res.append(pred_res)
    # Group lines of the same predicted class together in the output.
    res.sort(key=lambda x: x['label'])
    fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
|
114 |
+
|
115 |
+
def main():
    """Run KIE inference over a label file (`Global.infer_img`, one
    ``<image>\\t<json-label>`` line per sample), writing predictions to
    ``Global.save_res_path`` and a visualization per sample."""
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])
    load_model(config, model)

    # create data ops (Eval transforms are reused unchanged)
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        transforms.append(op)

    data_dir = config['Eval']['dataset']['data_dir']

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    class_path = config['Global']['class_path']
    idx_to_cls = read_class_list(class_path)
    os.makedirs(os.path.dirname(save_res_path), exist_ok=True)

    model.eval()

    # NOTE(review): warmup_times stays 0, so the ips below is computed over
    # all samples; warmup_t (set at index 10) is never read.
    warmup_times = 0
    count_t = []
    with open(save_res_path, "w") as fout:
        with open(config['Global']['infer_img'], "rb") as f:
            lines = f.readlines()
            for index, data_line in enumerate(lines):
                if index == 10:
                    warmup_t = time.time()
                data_line = data_line.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path, label = data_dir + "/" + substr[0], substr[1]
                data = {'img_path': img_path, 'label': label}
                with open(data['img_path'], 'rb') as f:
                    img = f.read()
                    data['image'] = img
                st = time.time()
                batch = transform(data, ops)
                # Add a leading batch dimension of 1 to every field.
                batch_pred = [0] * len(batch)
                for i in range(len(batch)):
                    batch_pred[i] = paddle.to_tensor(
                        np.expand_dims(
                            batch[i], axis=0))
                # NOTE(review): st is reset here, so only the forward pass
                # (not the transform) is timed.
                st = time.time()
                node, edge = model(batch_pred)
                node = F.softmax(node, -1)
                count_t.append(time.time() - st)
                draw_kie_result(batch, node, idx_to_cls, index)
                write_kie_result(fout, node, data)
    # NOTE(review): redundant — the `with` block above already closed fout.
    fout.close()
    logger.info("success!")
    logger.info("It took {} s for predict {} images.".format(
        np.sum(count_t), len(count_t)))
    ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
    logger.info("The ips is {} images/s".format(ips))
|
172 |
+
|
173 |
+
|
174 |
+
if __name__ == '__main__':
    # program.preprocess() parses CLI args and supplies the module-level
    # `config` and `logger` used by main() and the helpers above.
    config, device, logger, vdl_writer = program.preprocess()
    main()
|
tools/infer_kie_token_ser.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
|
24 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
25 |
+
sys.path.append(__dir__)
|
26 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
27 |
+
|
28 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
29 |
+
import cv2
|
30 |
+
import json
|
31 |
+
import paddle
|
32 |
+
|
33 |
+
from ppocr.data import create_operators, transform
|
34 |
+
from ppocr.modeling.architectures import build_model
|
35 |
+
from ppocr.postprocess import build_post_process
|
36 |
+
from ppocr.utils.save_load import load_model
|
37 |
+
from ppocr.utils.visual import draw_ser_results
|
38 |
+
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps
|
39 |
+
import tools.program as program
|
40 |
+
|
41 |
+
|
42 |
+
def to_tensor(data):
    """Wrap each sample field in a length-1 batch, converting numeric
    fields (ndarray / paddle.Tensor / plain number) to paddle Tensors and
    leaving every other field as a single-element Python list.

    Args:
        data: sequence of per-sample fields produced by the transforms.

    Returns:
        list with one entry per input field, in the original order.
    """
    import numbers
    from collections import defaultdict

    batched = defaultdict(list)
    numeric_positions = []
    for pos, field in enumerate(data):
        is_numeric = isinstance(field,
                                (np.ndarray, paddle.Tensor, numbers.Number))
        if is_numeric and pos not in numeric_positions:
            numeric_positions.append(pos)
        batched[pos].append(field)
    for pos in numeric_positions:
        batched[pos] = paddle.to_tensor(batched[pos])
    return list(batched.values())
|
56 |
+
|
57 |
+
|
58 |
+
class SerPredictor(object):
    """Semantic Entity Recognition (SER) predictor.

    Wraps model construction, an internal PaddleOCR engine (used by the
    label transforms for text det/rec), the data pipeline and the
    post-process into a single callable object.
    """

    def __init__(self, config):
        global_config = config['Global']
        self.algorithm = config['Architecture']["algorithm"]

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        # Imported lazily so the module can be used without paddleocr
        # installed until a predictor is actually constructed.
        from paddleocr import PaddleOCR

        self.ocr_engine = PaddleOCR(
            use_angle_cls=False,
            show_log=False,
            rec_model_dir=global_config.get("kie_rec_model_dir", None),
            det_model_dir=global_config.get("kie_det_model_dir", None),
            use_gpu=global_config['use_gpu'])

        # create data ops: inject the OCR engine into label ops and fix the
        # key set handed to the model / post-process.
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                op[op_name]['ocr_engine'] = self.ocr_engine
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = [
                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]

            transforms.append(op)
        # Default to inference mode unless the config set it explicitly.
        if config["Global"].get("infer_mode", None) is None:
            global_config['infer_mode'] = True
        self.ops = create_operators(config['Eval']['dataset']['transforms'],
                                    global_config)
        self.model.eval()

    def __call__(self, data):
        """Run SER on one sample.

        Args:
            data: dict with at least 'img_path' (and 'label' in eval mode).

        Returns:
            (post_result, batch): post-processed predictions and the
            tensorized input batch (reused by the RE stage).
        """
        with open(data["img_path"], 'rb') as f:
            img = f.read()
        data["image"] = img
        batch = transform(data, self.ops)
        batch = to_tensor(batch)
        preds = self.model(batch)

        # batch[6]/batch[7] are segment_offset_id / ocr_info per the
        # KeepKeys order configured in __init__.
        post_result = self.post_process_class(
            preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
        return post_result, batch
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_engine = SerPredictor(config)

    # infer_mode False: infer_img is a label file (one "<img>\t<json>" line
    # per sample); otherwise it is an image path/dir.
    if config["Global"].get("infer_mode", None) is False:
        data_dir = config['Eval']['dataset']['data_dir']
        with open(config['Global']['infer_img'], "rb") as f:
            infer_imgs = f.readlines()
    else:
        infer_imgs = get_image_file_list(config['Global']['infer_img'])

    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if config["Global"].get("infer_mode", None) is False:
                data_line = info.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {'img_path': img_path, 'label': substr[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")

            # Only the first batch element is kept (batch size is 1 here).
            result, _ = ser_engine(data)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ocr_info": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_ser_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))
|
tools/infer_kie_token_ser_re.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
|
24 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
25 |
+
sys.path.append(__dir__)
|
26 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
27 |
+
|
28 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
29 |
+
import cv2
|
30 |
+
import json
|
31 |
+
import paddle
|
32 |
+
import paddle.distributed as dist
|
33 |
+
|
34 |
+
from ppocr.data import create_operators, transform
|
35 |
+
from ppocr.modeling.architectures import build_model
|
36 |
+
from ppocr.postprocess import build_post_process
|
37 |
+
from ppocr.utils.save_load import load_model
|
38 |
+
from ppocr.utils.visual import draw_re_results
|
39 |
+
from ppocr.utils.logging import get_logger
|
40 |
+
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
|
41 |
+
from tools.program import ArgsParser, load_config, merge_config
|
42 |
+
from tools.infer_kie_token_ser import SerPredictor
|
43 |
+
|
44 |
+
|
45 |
+
class ReArgsParser(ArgsParser):
    """Argument parser for the SER+RE pipeline: extends the base parser
    with a second config (-c_ser) and option overrides (-o_ser) for the
    SER stage."""

    def __init__(self):
        super(ReArgsParser, self).__init__()
        self.add_argument(
            "-c_ser", "--config_ser", help="ser configuration file to use")
        self.add_argument(
            "-o_ser",
            "--opt_ser",
            nargs='+',
            help="set ser configuration options ")

    def parse_args(self, argv=None):
        """Parse args, requiring --config_ser and normalizing -o_ser
        key=value pairs via the inherited _parse_opt helper."""
        args = super(ReArgsParser, self).parse_args(argv)
        assert args.config_ser is not None, \
            "Please specify --config_ser=ser_configure_file_path."
        args.opt_ser = self._parse_opt(args.opt_ser)
        return args
|
62 |
+
|
63 |
+
|
64 |
+
def make_input(ser_inputs, ser_results):
    """Convert SER inputs + predictions into RE (relation extraction) input.

    Entities and candidate relations are packed into fixed-size int64
    arrays using a length-prefix layout: row 0 holds the counts, rows
    1..count hold the values, remaining rows are -1 padding.

    Args:
        ser_inputs: tensorized SER batch; ser_inputs[0] fixes
            (batch_size, max_seq_len), ser_inputs[8][0] is the entity list.
        ser_results: per-entity SER predictions (batch of size 1).

    Returns:
        (ser_inputs', entity_idx_dict_batch): the first 5 SER fields plus
        the packed entities/relations arrays, and a per-batch-item mapping
        from packed entity position back to the original entity index.
    """
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
    batch_size, max_seq_len = ser_inputs[0].shape[:2]
    entities = ser_inputs[8][0]
    ser_results = ser_results[0]
    assert len(entities) == len(ser_results)

    # entities: keep only non-'O' predictions, remembering original indices
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_results, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])

    # Columns: 0 = start, 1 = end, 2 = label; row 0 stores the counts.
    entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
    entities[0, 0] = len(start)
    entities[1:len(start) + 1, 0] = start
    entities[0, 1] = len(end)
    entities[1:len(end) + 1, 1] = end
    entities[0, 2] = len(label)
    entities[1:len(label) + 1, 2] = label

    # relations: every QUESTION (1) paired with every ANSWER (2) candidate
    head = []
    tail = []
    for i in range(len(label)):
        for j in range(len(label)):
            if label[i] == 1 and label[j] == 2:
                head.append(i)
                tail.append(j)

    relations = np.full([len(head) + 1, 2], fill_value=-1, dtype=np.int64)
    relations[0, 0] = len(head)
    relations[1:len(head) + 1, 0] = head
    relations[0, 1] = len(tail)
    relations[1:len(tail) + 1, 1] = tail

    # Broadcast the packed arrays across the batch dimension.
    entities = np.expand_dims(entities, axis=0)
    entities = np.repeat(entities, batch_size, axis=0)
    relations = np.expand_dims(relations, axis=0)
    relations = np.repeat(relations, batch_size, axis=0)

    # remove ocr_info segment_offset_id and label in ser input
    if isinstance(ser_inputs[0], paddle.Tensor):
        entities = paddle.to_tensor(entities)
        relations = paddle.to_tensor(relations)
    ser_inputs = ser_inputs[:5] + [entities, relations]

    # NOTE(review): the same dict object is shared by all batch items.
    entity_idx_dict_batch = []
    for b in range(batch_size):
        entity_idx_dict_batch.append(entity_idx_dict)
    return ser_inputs, entity_idx_dict_batch
|
122 |
+
|
123 |
+
|
124 |
+
class SerRePredictor(object):
    """Two-stage KIE predictor: runs SER first, then relation extraction
    (RE) on the SER outputs."""

    def __init__(self, config, ser_config):
        global_config = config['Global']
        # Keep the SER stage's infer_mode in sync with the RE config.
        if "infer_mode" in global_config:
            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]

        self.ser_engine = SerPredictor(ser_config)

        # init re model

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        self.model.eval()

    def __call__(self, data):
        """Run SER then RE on one sample dict (see SerPredictor.__call__)
        and return the post-processed relation predictions."""
        ser_results, ser_inputs = self.ser_engine(data)
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        # Index 4 is the image field; drop it for text-only backbones.
        if self.model.backbone.use_visual_backbone is False:
            re_input.pop(4)
        preds = self.model(re_input)
        post_result = self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch)
        return post_result
|
157 |
+
|
158 |
+
|
159 |
+
def preprocess():
    """Parse CLI args, load + merge the RE and SER configs, select the
    paddle device and build the logger.

    Returns:
        (config, ser_config, device, logger)
    """
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)

    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger()

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    logger.info('{} re config {}'.format('*' * 10, '*' * 10))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
    print_dict(ser_config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, ser_config, device, logger
|
183 |
+
|
184 |
+
|
185 |
+
if __name__ == '__main__':
    config, ser_config, device, logger = preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    # infer_mode False: infer_img is a label file (one "<img>\t<json>" line
    # per sample); otherwise it is an image path/dir.
    if config["Global"].get("infer_mode", None) is False:
        data_dir = config['Eval']['dataset']['data_dir']
        with open(config['Global']['infer_img'], "rb") as f:
            infer_imgs = f.readlines()
    else:
        infer_imgs = get_image_file_list(config['Global']['infer_img'])

    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if config["Global"].get("infer_mode", None) is False:
                data_line = info.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {'img_path': img_path, 'label': substr[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")

            # Only the first batch element is kept (batch size is 1 here).
            result = ser_re_engine(data)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                result, ensure_ascii=False) + "\n")
            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))
|
tools/infer_rec.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
import json
|
24 |
+
|
25 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
26 |
+
sys.path.append(__dir__)
|
27 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
28 |
+
|
29 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
30 |
+
|
31 |
+
import paddle
|
32 |
+
|
33 |
+
from ppocr.data import create_operators, transform
|
34 |
+
from ppocr.modeling.architectures import build_model
|
35 |
+
from ppocr.postprocess import build_post_process
|
36 |
+
from ppocr.utils.save_load import load_model
|
37 |
+
from ppocr.utils.utility import get_image_file_list
|
38 |
+
import tools.program as program
|
39 |
+
|
40 |
+
|
41 |
+
def main():
    """Run text-recognition inference over ``Global.infer_img`` and write
    "<image_path>\\t<result>" lines to ``Global.save_res_path``.

    Relies on the module-level ``config`` and ``logger`` created by
    ``program.preprocess()`` in the ``__main__`` guard.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model: size the head output from the decoder's character set
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    out_channels_list = {}
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        # SAR decode reserves two extra symbols, so the raw
                        # charset length includes them already
                        char_num = char_num - 2
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head loss
            out_channels_list = {}
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops; label ops are skipped because inference has no GT
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name in ['RecResizeImg']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            # keep only the inputs each algorithm actually consumes
            if config['Architecture']['algorithm'] == "SRN":
                op[op_name]['keep_keys'] = [
                    'image', 'encoder_word_pos', 'gsrm_word_pos',
                    'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
                ]
            elif config['Architecture']['algorithm'] == "SAR":
                op[op_name]['keep_keys'] = ['image', 'valid_ratio']
            elif config['Architecture']['algorithm'] == "RobustScanner":
                op[op_name][
                    'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
            else:
                op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global'].get('save_res_path',
                                         "./output/rec/predicts_rec.txt")
    save_res_dir = os.path.dirname(save_res_path)
    # exist_ok avoids a crash when the directory already exists (or is
    # created concurrently); the guard skips a bare filename whose
    # dirname is '' — os.makedirs('') raises FileNotFoundError.
    if save_res_dir:
        os.makedirs(save_res_dir, exist_ok=True)

    model.eval()

    with open(save_res_path, "w") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            # algorithm-specific auxiliary inputs
            if config['Architecture']['algorithm'] == "SRN":
                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)

                others = [
                    paddle.to_tensor(encoder_word_pos_list),
                    paddle.to_tensor(gsrm_word_pos_list),
                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
                    paddle.to_tensor(gsrm_slf_attn_bias2_list)
                ]
            if config['Architecture']['algorithm'] == "SAR":
                valid_ratio = np.expand_dims(batch[-1], axis=0)
                img_metas = [paddle.to_tensor(valid_ratio)]
            if config['Architecture']['algorithm'] == "RobustScanner":
                valid_ratio = np.expand_dims(batch[1], axis=0)
                word_positons = np.expand_dims(batch[2], axis=0)
                img_metas = [
                    paddle.to_tensor(valid_ratio),
                    paddle.to_tensor(word_positons),
                ]
            if config['Architecture']['algorithm'] == "CAN":
                image_mask = paddle.ones(
                    (np.expand_dims(
                        batch[0], axis=0).shape), dtype='float32')
                label = paddle.ones((1, 36), dtype='int64')
            images = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(images)
            if config['Architecture']['algorithm'] == "SRN":
                preds = model(images, others)
            elif config['Architecture']['algorithm'] == "SAR":
                preds = model(images, img_metas)
            elif config['Architecture']['algorithm'] == "RobustScanner":
                preds = model(images, img_metas)
            elif config['Architecture']['algorithm'] == "CAN":
                preds = model([images, image_mask, label])
            else:
                preds = model(images)
            post_result = post_process_class(preds)
            info = None
            if isinstance(post_result, dict):
                # distillation: one (label, score) pair per sub-model
                rec_info = dict()
                for key in post_result:
                    if len(post_result[key][0]) >= 2:
                        rec_info[key] = {
                            "label": post_result[key][0][0],
                            "score": float(post_result[key][0][1]),
                        }
                info = json.dumps(rec_info, ensure_ascii=False)
            elif isinstance(post_result, list) and isinstance(post_result[0],
                                                              int):
                # for RFLearning CNT branch
                info = str(post_result[0])
            else:
                if len(post_result[0]) >= 2:
                    info = post_result[0][0] + "\t" + str(post_result[0][1])

            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(file + "\t" + info + "\n")
    logger.info("success!")
|
184 |
+
|
185 |
+
|
186 |
+
if __name__ == '__main__':
    # Parse CLI/config, pick the device and set up logging before inference.
    config, device, logger, vdl_writer = program.preprocess()
    main()
|
tools/infer_sr.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
import json
|
24 |
+
from PIL import Image
|
25 |
+
import cv2
|
26 |
+
|
27 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
28 |
+
sys.path.insert(0, __dir__)
|
29 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
30 |
+
|
31 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
32 |
+
|
33 |
+
import paddle
|
34 |
+
|
35 |
+
from ppocr.data import create_operators, transform
|
36 |
+
from ppocr.modeling.architectures import build_model
|
37 |
+
from ppocr.postprocess import build_post_process
|
38 |
+
from ppocr.utils.save_load import load_model
|
39 |
+
from ppocr.utils.utility import get_image_file_list
|
40 |
+
import tools.program as program
|
41 |
+
|
42 |
+
|
43 |
+
def main():
    """Run super-resolution inference over ``Global.infer_img`` and save
    each upscaled image as ``<save_visual>/sr_<image_name>``.

    Relies on the module-level ``config`` and ``logger`` created by
    ``program.preprocess()`` in the ``__main__`` guard.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # sr transform
    config['Architecture']["Transform"]['infer_mode'] = True

    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops; label ops are skipped because inference has no GT
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name in ['SRResize']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['img_lr']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_visual_path = config['Global'].get('save_visual', "infer_result/")
    # Create the output directory itself. The old code created only
    # os.path.dirname(save_visual_path), which worked for the default
    # trailing-slash value by accident but failed for nested paths or a
    # bare directory name; exist_ok makes the call idempotent.
    os.makedirs(save_visual_path, exist_ok=True)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        img = Image.open(file).convert("RGB")
        data = {'image_lr': img}
        batch = transform(data, ops)
        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)

        preds = model(images)
        sr_img = preds["sr_img"][0]
        lr_img = preds["lr_img"][0]
        # CHW float in [0, 1] -> HWC uint8
        fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
        fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
        img_name_pure = os.path.split(file)[-1]
        # RGB -> BGR for OpenCV
        cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure),
                    fm_sr[:, :, ::-1])
        # report the directory actually configured, not a hard-coded one
        logger.info("The visualized image saved in {}/sr_{}".format(
            save_visual_path, img_name_pure))

    logger.info("success!")
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == '__main__':
    # Parse CLI/config, pick the device and set up logging before inference.
    config, device, logger, vdl_writer = program.preprocess()
    main()
|
tools/infer_table.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
import json
|
24 |
+
|
25 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
26 |
+
sys.path.append(__dir__)
|
27 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
28 |
+
|
29 |
+
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
30 |
+
|
31 |
+
import paddle
|
32 |
+
from paddle.jit import to_static
|
33 |
+
|
34 |
+
from ppocr.data import create_operators, transform
|
35 |
+
from ppocr.modeling.architectures import build_model
|
36 |
+
from ppocr.postprocess import build_post_process
|
37 |
+
from ppocr.utils.save_load import load_model
|
38 |
+
from ppocr.utils.utility import get_image_file_list
|
39 |
+
from ppocr.utils.visual import draw_rectangle
|
40 |
+
from tools.infer.utility import draw_boxes
|
41 |
+
import tools.program as program
|
42 |
+
import cv2
|
43 |
+
|
44 |
+
|
45 |
+
@paddle.no_grad()
def main(config, device, logger, vdl_writer):
    """Run table-structure inference over ``Global.infer_img``.

    Writes an ``infer.txt`` with the predicted HTML structure tokens and
    cell bounding boxes, plus one visualization image per input, into
    ``Global.save_res_path``.

    Args:
        config: parsed global configuration dict.
        device: paddle device handle (unused here; kept for the common
            ``program.preprocess()`` call signature).
        logger: configured logger instance.
        vdl_writer: VisualDL writer (unused here).
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    if hasattr(post_process_class, 'character'):
        # size the head output from the structure-token vocabulary
        config['Architecture']["Head"]['out_channels'] = len(
            getattr(post_process_class, 'character'))

    model = build_model(config['Architecture'])
    algorithm = config['Architecture']['algorithm']

    load_model(config, model)

    # create data ops; encode ops are skipped because inference has no GT
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Encode' in op_name:
            continue
        if op_name == 'KeepKeys':
            # 'shape' is needed by the post-process to map boxes back to
            # original image coordinates
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    os.makedirs(save_res_path, exist_ok=True)

    model.eval()
    with open(
            os.path.join(save_res_path, 'infer.txt'), mode='w',
            encoding='utf-8') as f_w:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)

            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, [shape_list])

            structure_str_list = post_result['structure_batch_list'][0]
            bbox_list = post_result['bbox_batch_list'][0]
            structure_str_list = structure_str_list[0]
            # wrap the predicted tokens into a complete HTML document
            structure_str_list = [
                '<html>', '<body>', '<table>'
            ] + structure_str_list + ['</table>', '</body>', '</html>']
            bbox_list_str = json.dumps(bbox_list.tolist())

            logger.info("result: {}, {}".format(structure_str_list,
                                                bbox_list_str))
            f_w.write("result: {}, {}\n".format(structure_str_list,
                                                bbox_list_str))

            # 4 values per box -> axis-aligned rectangles; otherwise the
            # boxes are polygons and need the generic drawer
            if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
                img = draw_rectangle(file, bbox_list)
            else:
                img = draw_boxes(cv2.imread(file), bbox_list)
            cv2.imwrite(
                os.path.join(save_res_path, os.path.basename(file)), img)
            logger.info('save result to {}'.format(save_res_path))
    logger.info("success!")
|
117 |
+
|
118 |
+
|
119 |
+
if __name__ == '__main__':
    # Parse CLI/config, pick the device and set up logging before inference.
    config, device, logger, vdl_writer = program.preprocess()
    main(config, device, logger, vdl_writer)
|
tools/program.py
ADDED
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import os
|
20 |
+
import sys
|
21 |
+
import platform
|
22 |
+
import yaml
|
23 |
+
import time
|
24 |
+
import datetime
|
25 |
+
import paddle
|
26 |
+
import paddle.distributed as dist
|
27 |
+
from tqdm import tqdm
|
28 |
+
import cv2
|
29 |
+
import numpy as np
|
30 |
+
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
31 |
+
|
32 |
+
from ppocr.utils.stats import TrainingStats
|
33 |
+
from ppocr.utils.save_load import save_model
|
34 |
+
from ppocr.utils.utility import print_dict, AverageMeter
|
35 |
+
from ppocr.utils.logging import get_logger
|
36 |
+
from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
|
37 |
+
from ppocr.utils import profiler
|
38 |
+
from ppocr.data import build_dataloader
|
39 |
+
|
40 |
+
|
41 |
+
class ArgsParser(ArgumentParser):
    """Command-line parser shared by the training/inference tools.

    Supports ``-c/--config`` (required), repeated ``-o/--opt KEY=VALUE``
    configuration overrides (values parsed as YAML scalars), and
    ``-p/--profiler_options``.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format ' \
                 '\"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` (defaults to sys.argv) and normalize ``opt``.

        Raises:
            AssertionError: if ``--config`` was not supplied.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``["key=value", ...]`` override strings into a dict.

        Values are parsed with YAML so ``-o Global.use_gpu=false`` yields a
        bool and ``-o Global.epoch=10`` an int.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split only on the first '=' so values that themselves contain
            # '=' (e.g. paths or expressions) don't raise ValueError.
            k, v = s.split('=', 1)
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
|
73 |
+
|
74 |
+
|
75 |
+
def load_config(file_path):
    """
    Load config from yml/yaml file.

    Args:
        file_path (str): Path of the config file to be loaded.

    Returns:
        dict: parsed global config.

    Raises:
        AssertionError: if the file extension is not .yml/.yaml.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is always closed (the
    # previous bare open() leaked the handle).
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config
|
86 |
+
|
87 |
+
|
88 |
+
def merge_config(config, opts):
    """
    Merge override options into the global config, in place.

    Args:
        config (dict): the global config to update.
        opts (dict): overrides; keys are either top-level names or
            dotted paths such as ``"Global.epoch_num"``.

    Returns:
        dict: the same ``config`` object, after merging.
    """
    for key, value in opts.items():
        if "." not in key:
            # Top-level key: merge dict values into an existing section,
            # otherwise overwrite/insert directly.
            if isinstance(value, dict) and key in config:
                config[key].update(value)
            else:
                config[key] = value
            continue
        # Dotted key: walk down to the parent node, then assign the leaf.
        parts = key.split('.')
        assert parts[0] in config, (
            "the sub_keys can only be one of global_config: {}, but get: "
            "{}, please check your running command".format(config.keys(),
                                                           parts[0]))
        node = config[parts[0]]
        for part in parts[1:-1]:
            node = node[part]
        node[parts[-1]] = value
    return config
|
115 |
+
|
116 |
+
|
117 |
+
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.

    Exits the process via sys.exit(1) when a requested device is not
    compiled into the installed paddle build. Probe failures (missing
    is_compiled_with_* helpers on some paddle versions) are deliberately
    swallowed so startup can continue; SystemExit is not caught by the
    ``except Exception`` below, so the exits still take effect.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be ture.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
        if use_npu:
            # paddle <= 2.4 exposes is_compiled_with_npu(); newer
            # releases moved NPU behind the custom-device API.
            if int(paddle.version.major) != 0 and int(
                    paddle.version.major) <= 2 and int(
                        paddle.version.minor) <= 4:
                if not paddle.device.is_compiled_with_npu():
                    print(err.format("use_npu", "npu", "npu", "use_npu"))
                    sys.exit(1)
            # is_compiled_with_npu() has been updated after paddle-2.4
            else:
                if not paddle.device.is_compiled_with_custom_device("npu"):
                    print(err.format("use_npu", "npu", "npu", "use_npu"))
                    sys.exit(1)
        if use_mlu and not paddle.device.is_compiled_with_mlu():
            print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
            sys.exit(1)
    except Exception as e:
        # Best-effort probe: older/newer paddle builds may lack some of
        # the helpers used above; such failures must not abort startup.
        pass
|
154 |
+
|
155 |
+
|
156 |
+
def to_float32(preds):
    """Recursively cast every paddle.Tensor found in ``preds`` to float32.

    Dicts and lists are updated in place (and also returned); a bare
    tensor argument is returned as a new float32 tensor. Non-tensor
    leaves are left untouched.
    """
    if isinstance(preds, paddle.Tensor):
        return preds.astype(paddle.float32)
    if isinstance(preds, dict):
        for key, value in preds.items():
            if isinstance(value, (dict, list)):
                preds[key] = to_float32(value)
            elif isinstance(value, paddle.Tensor):
                preds[key] = value.astype(paddle.float32)
    elif isinstance(preds, list):
        for idx, value in enumerate(preds):
            if isinstance(value, (dict, list)):
                preds[idx] = to_float32(value)
            elif isinstance(value, paddle.Tensor):
                preds[idx] = value.astype(paddle.float32)
    return preds
|
174 |
+
|
175 |
+
|
176 |
+
def train(config,
|
177 |
+
train_dataloader,
|
178 |
+
valid_dataloader,
|
179 |
+
device,
|
180 |
+
model,
|
181 |
+
loss_class,
|
182 |
+
optimizer,
|
183 |
+
lr_scheduler,
|
184 |
+
post_process_class,
|
185 |
+
eval_class,
|
186 |
+
pre_best_model_dict,
|
187 |
+
logger,
|
188 |
+
log_writer=None,
|
189 |
+
scaler=None,
|
190 |
+
amp_level='O2',
|
191 |
+
amp_custom_black_list=[]):
|
192 |
+
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
|
193 |
+
False)
|
194 |
+
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
|
195 |
+
log_smooth_window = config['Global']['log_smooth_window']
|
196 |
+
epoch_num = config['Global']['epoch_num']
|
197 |
+
print_batch_step = config['Global']['print_batch_step']
|
198 |
+
eval_batch_step = config['Global']['eval_batch_step']
|
199 |
+
profiler_options = config['profiler_options']
|
200 |
+
|
201 |
+
global_step = 0
|
202 |
+
if 'global_step' in pre_best_model_dict:
|
203 |
+
global_step = pre_best_model_dict['global_step']
|
204 |
+
start_eval_step = 0
|
205 |
+
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
206 |
+
start_eval_step = eval_batch_step[0]
|
207 |
+
eval_batch_step = eval_batch_step[1]
|
208 |
+
if len(valid_dataloader) == 0:
|
209 |
+
logger.info(
|
210 |
+
'No Images in eval dataset, evaluation during training ' \
|
211 |
+
'will be disabled'
|
212 |
+
)
|
213 |
+
start_eval_step = 1e111
|
214 |
+
logger.info(
|
215 |
+
"During the training process, after the {}th iteration, " \
|
216 |
+
"an evaluation is run every {} iterations".
|
217 |
+
format(start_eval_step, eval_batch_step))
|
218 |
+
save_epoch_step = config['Global']['save_epoch_step']
|
219 |
+
save_model_dir = config['Global']['save_model_dir']
|
220 |
+
if not os.path.exists(save_model_dir):
|
221 |
+
os.makedirs(save_model_dir)
|
222 |
+
main_indicator = eval_class.main_indicator
|
223 |
+
best_model_dict = {main_indicator: 0}
|
224 |
+
best_model_dict.update(pre_best_model_dict)
|
225 |
+
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
226 |
+
model_average = False
|
227 |
+
model.train()
|
228 |
+
|
229 |
+
use_srn = config['Architecture']['algorithm'] == "SRN"
|
230 |
+
extra_input_models = [
|
231 |
+
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
|
232 |
+
"RobustScanner", "RFL", 'DRRG'
|
233 |
+
]
|
234 |
+
extra_input = False
|
235 |
+
if config['Architecture']['algorithm'] == 'Distillation':
|
236 |
+
for key in config['Architecture']["Models"]:
|
237 |
+
extra_input = extra_input or config['Architecture']['Models'][key][
|
238 |
+
'algorithm'] in extra_input_models
|
239 |
+
else:
|
240 |
+
extra_input = config['Architecture']['algorithm'] in extra_input_models
|
241 |
+
try:
|
242 |
+
model_type = config['Architecture']['model_type']
|
243 |
+
except:
|
244 |
+
model_type = None
|
245 |
+
|
246 |
+
algorithm = config['Architecture']['algorithm']
|
247 |
+
|
248 |
+
start_epoch = best_model_dict[
|
249 |
+
'start_epoch'] if 'start_epoch' in best_model_dict else 1
|
250 |
+
|
251 |
+
total_samples = 0
|
252 |
+
train_reader_cost = 0.0
|
253 |
+
train_batch_cost = 0.0
|
254 |
+
reader_start = time.time()
|
255 |
+
eta_meter = AverageMeter()
|
256 |
+
|
257 |
+
max_iter = len(train_dataloader) - 1 if platform.system(
|
258 |
+
) == "Windows" else len(train_dataloader)
|
259 |
+
|
260 |
+
for epoch in range(start_epoch, epoch_num + 1):
|
261 |
+
if train_dataloader.dataset.need_reset:
|
262 |
+
train_dataloader = build_dataloader(
|
263 |
+
config, 'Train', device, logger, seed=epoch)
|
264 |
+
max_iter = len(train_dataloader) - 1 if platform.system(
|
265 |
+
) == "Windows" else len(train_dataloader)
|
266 |
+
|
267 |
+
for idx, batch in enumerate(train_dataloader):
|
268 |
+
profiler.add_profiler_step(profiler_options)
|
269 |
+
train_reader_cost += time.time() - reader_start
|
270 |
+
if idx >= max_iter:
|
271 |
+
break
|
272 |
+
lr = optimizer.get_lr()
|
273 |
+
images = batch[0]
|
274 |
+
if use_srn:
|
275 |
+
model_average = True
|
276 |
+
# use amp
|
277 |
+
if scaler:
|
278 |
+
with paddle.amp.auto_cast(
|
279 |
+
level=amp_level,
|
280 |
+
custom_black_list=amp_custom_black_list):
|
281 |
+
if model_type == 'table' or extra_input:
|
282 |
+
preds = model(images, data=batch[1:])
|
283 |
+
elif model_type in ["kie"]:
|
284 |
+
preds = model(batch)
|
285 |
+
elif algorithm in ['CAN']:
|
286 |
+
preds = model(batch[:3])
|
287 |
+
else:
|
288 |
+
preds = model(images)
|
289 |
+
preds = to_float32(preds)
|
290 |
+
loss = loss_class(preds, batch)
|
291 |
+
avg_loss = loss['loss']
|
292 |
+
scaled_avg_loss = scaler.scale(avg_loss)
|
293 |
+
scaled_avg_loss.backward()
|
294 |
+
scaler.minimize(optimizer, scaled_avg_loss)
|
295 |
+
else:
|
296 |
+
if model_type == 'table' or extra_input:
|
297 |
+
preds = model(images, data=batch[1:])
|
298 |
+
elif model_type in ["kie", 'sr']:
|
299 |
+
preds = model(batch)
|
300 |
+
elif algorithm in ['CAN']:
|
301 |
+
preds = model(batch[:3])
|
302 |
+
else:
|
303 |
+
preds = model(images)
|
304 |
+
loss = loss_class(preds, batch)
|
305 |
+
avg_loss = loss['loss']
|
306 |
+
avg_loss.backward()
|
307 |
+
optimizer.step()
|
308 |
+
|
309 |
+
optimizer.clear_grad()
|
310 |
+
|
311 |
+
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
|
312 |
+
batch = [item.numpy() for item in batch]
|
313 |
+
if model_type in ['kie', 'sr']:
|
314 |
+
eval_class(preds, batch)
|
315 |
+
elif model_type in ['table']:
|
316 |
+
post_result = post_process_class(preds, batch)
|
317 |
+
eval_class(post_result, batch)
|
318 |
+
elif algorithm in ['CAN']:
|
319 |
+
model_type = 'can'
|
320 |
+
eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
|
321 |
+
else:
|
322 |
+
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
|
323 |
+
]: # for multi head loss
|
324 |
+
post_result = post_process_class(
|
325 |
+
preds['ctc'], batch[1]) # for CTC head out
|
326 |
+
elif config['Loss']['name'] in ['VLLoss']:
|
327 |
+
post_result = post_process_class(preds, batch[1],
|
328 |
+
batch[-1])
|
329 |
+
else:
|
330 |
+
post_result = post_process_class(preds, batch[1])
|
331 |
+
eval_class(post_result, batch)
|
332 |
+
metric = eval_class.get_metric()
|
333 |
+
train_stats.update(metric)
|
334 |
+
|
335 |
+
train_batch_time = time.time() - reader_start
|
336 |
+
train_batch_cost += train_batch_time
|
337 |
+
eta_meter.update(train_batch_time)
|
338 |
+
global_step += 1
|
339 |
+
total_samples += len(images)
|
340 |
+
|
341 |
+
if not isinstance(lr_scheduler, float):
|
342 |
+
lr_scheduler.step()
|
343 |
+
|
344 |
+
# logger and visualdl
|
345 |
+
stats = {k: v.numpy().mean() for k, v in loss.items()}
|
346 |
+
stats['lr'] = lr
|
347 |
+
train_stats.update(stats)
|
348 |
+
|
349 |
+
if log_writer is not None and dist.get_rank() == 0:
|
350 |
+
log_writer.log_metrics(
|
351 |
+
metrics=train_stats.get(), prefix="TRAIN", step=global_step)
|
352 |
+
|
353 |
+
if dist.get_rank() == 0 and (
|
354 |
+
(global_step > 0 and global_step % print_batch_step == 0) or
|
355 |
+
(idx >= len(train_dataloader) - 1)):
|
356 |
+
logs = train_stats.log()
|
357 |
+
|
358 |
+
eta_sec = ((epoch_num + 1 - epoch) * \
|
359 |
+
len(train_dataloader) - idx - 1) * eta_meter.avg
|
360 |
+
eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
|
361 |
+
strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
|
362 |
+
'{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
|
363 |
+
'ips: {:.5f} samples/s, eta: {}'.format(
|
364 |
+
epoch, epoch_num, global_step, logs,
|
365 |
+
train_reader_cost / print_batch_step,
|
366 |
+
train_batch_cost / print_batch_step,
|
367 |
+
total_samples / print_batch_step,
|
368 |
+
total_samples / train_batch_cost, eta_sec_format)
|
369 |
+
logger.info(strs)
|
370 |
+
|
371 |
+
total_samples = 0
|
372 |
+
train_reader_cost = 0.0
|
373 |
+
train_batch_cost = 0.0
|
374 |
+
# eval
|
375 |
+
if global_step > start_eval_step and \
|
376 |
+
(global_step - start_eval_step) % eval_batch_step == 0 \
|
377 |
+
and dist.get_rank() == 0:
|
378 |
+
if model_average:
|
379 |
+
Model_Average = paddle.incubate.optimizer.ModelAverage(
|
380 |
+
0.15,
|
381 |
+
parameters=model.parameters(),
|
382 |
+
min_average_window=10000,
|
383 |
+
max_average_window=15625)
|
384 |
+
Model_Average.apply()
|
385 |
+
cur_metric = eval(
|
386 |
+
model,
|
387 |
+
valid_dataloader,
|
388 |
+
post_process_class,
|
389 |
+
eval_class,
|
390 |
+
model_type,
|
391 |
+
extra_input=extra_input,
|
392 |
+
scaler=scaler,
|
393 |
+
amp_level=amp_level,
|
394 |
+
amp_custom_black_list=amp_custom_black_list)
|
395 |
+
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
396 |
+
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
397 |
+
logger.info(cur_metric_str)
|
398 |
+
|
399 |
+
# logger metric
|
400 |
+
if log_writer is not None:
|
401 |
+
log_writer.log_metrics(
|
402 |
+
metrics=cur_metric, prefix="EVAL", step=global_step)
|
403 |
+
|
404 |
+
if cur_metric[main_indicator] >= best_model_dict[
|
405 |
+
main_indicator]:
|
406 |
+
best_model_dict.update(cur_metric)
|
407 |
+
best_model_dict['best_epoch'] = epoch
|
408 |
+
save_model(
|
409 |
+
model,
|
410 |
+
optimizer,
|
411 |
+
save_model_dir,
|
412 |
+
logger,
|
413 |
+
config,
|
414 |
+
is_best=True,
|
415 |
+
prefix='best_accuracy',
|
416 |
+
best_model_dict=best_model_dict,
|
417 |
+
epoch=epoch,
|
418 |
+
global_step=global_step)
|
419 |
+
best_str = 'best metric, {}'.format(', '.join([
|
420 |
+
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
|
421 |
+
]))
|
422 |
+
logger.info(best_str)
|
423 |
+
# logger best metric
|
424 |
+
if log_writer is not None:
|
425 |
+
log_writer.log_metrics(
|
426 |
+
metrics={
|
427 |
+
"best_{}".format(main_indicator):
|
428 |
+
best_model_dict[main_indicator]
|
429 |
+
},
|
430 |
+
prefix="EVAL",
|
431 |
+
step=global_step)
|
432 |
+
|
433 |
+
log_writer.log_model(
|
434 |
+
is_best=True,
|
435 |
+
prefix="best_accuracy",
|
436 |
+
metadata=best_model_dict)
|
437 |
+
|
438 |
+
reader_start = time.time()
|
439 |
+
if dist.get_rank() == 0:
|
440 |
+
save_model(
|
441 |
+
model,
|
442 |
+
optimizer,
|
443 |
+
save_model_dir,
|
444 |
+
logger,
|
445 |
+
config,
|
446 |
+
is_best=False,
|
447 |
+
prefix='latest',
|
448 |
+
best_model_dict=best_model_dict,
|
449 |
+
epoch=epoch,
|
450 |
+
global_step=global_step)
|
451 |
+
|
452 |
+
if log_writer is not None:
|
453 |
+
log_writer.log_model(is_best=False, prefix="latest")
|
454 |
+
|
455 |
+
if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
|
456 |
+
save_model(
|
457 |
+
model,
|
458 |
+
optimizer,
|
459 |
+
save_model_dir,
|
460 |
+
logger,
|
461 |
+
config,
|
462 |
+
is_best=False,
|
463 |
+
prefix='iter_epoch_{}'.format(epoch),
|
464 |
+
best_model_dict=best_model_dict,
|
465 |
+
epoch=epoch,
|
466 |
+
global_step=global_step)
|
467 |
+
if log_writer is not None:
|
468 |
+
log_writer.log_model(
|
469 |
+
is_best=False, prefix='iter_epoch_{}'.format(epoch))
|
470 |
+
|
471 |
+
best_str = 'best metric, {}'.format(', '.join(
|
472 |
+
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
473 |
+
logger.info(best_str)
|
474 |
+
if dist.get_rank() == 0 and log_writer is not None:
|
475 |
+
log_writer.close()
|
476 |
+
return
|
477 |
+
|
478 |
+
|
479 |
+
def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False,
         scaler=None,
         amp_level='O2',
         amp_custom_black_list=None):
    """Evaluate ``model`` over ``valid_dataloader`` and return its metrics.

    Args:
        model: network in train mode on entry; switched to eval mode for the
            duration of this call and restored to train mode before returning.
        valid_dataloader: iterable of batches; ``batch[0]`` is the image tensor.
        post_process_class: callable decoding raw predictions, or ``None`` for
            table/kie models whose metric consumes raw predictions directly.
        eval_class: metric accumulator; ``get_metric()`` yields the result dict.
        model_type: one of ``'table'``, ``'kie'``, ``'can'``, ``'sr'`` or
            ``None`` — selects how the forward pass is fed and scored.
        extra_input: if True the model also receives ``batch[1:]`` as ``data``.
        scaler: AMP grad scaler; truthy value enables ``paddle.amp.auto_cast``.
        amp_level: AMP level passed to ``auto_cast``.
        amp_custom_black_list: ops excluded from AMP (default: none).
            NOTE: was a mutable default ``[]``; now ``None``-normalized.

    Returns:
        dict of metric values, with an added ``'fps'`` entry.
    """
    if amp_custom_black_list is None:
        amp_custom_black_list = []

    def _forward(batch, images):
        # Single forward dispatch shared by the AMP and non-AMP paths
        # (previously duplicated in both branches).
        if model_type == 'table' or extra_input:
            return model(images, data=batch[1:])
        elif model_type in ['kie', 'sr']:
            return model(batch)
        elif model_type in ['can']:
            return model(batch[:3])
        return model(images)

    model.eval()
    with paddle.no_grad():
        total_frame = 0.0
        total_time = 0.0
        pbar = tqdm(
            total=len(valid_dataloader),
            desc='eval model:',
            position=0,
            leave=True)
        # On Windows the last batch is skipped (mirrors the training loop).
        max_iter = len(valid_dataloader) - 1 if platform.system(
        ) == "Windows" else len(valid_dataloader)
        for idx, batch in enumerate(valid_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            start = time.time()

            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list):
                    preds = _forward(batch, images)
                preds = to_float32(preds)
            else:
                preds = _forward(batch, images)

            # Convert tensors to numpy for the post-processors / metrics.
            batch_numpy = [
                item.numpy() if isinstance(item, paddle.Tensor) else item
                for item in batch
            ]
            total_time += time.time() - start

            # Score the current batch.
            if model_type in ['table', 'kie']:
                if post_process_class is None:
                    eval_class(preds, batch_numpy)
                else:
                    post_result = post_process_class(preds, batch_numpy)
                    eval_class(post_result, batch_numpy)
            elif model_type in ['sr']:
                eval_class(preds, batch_numpy)
            elif model_type in ['can']:
                eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
            else:
                post_result = post_process_class(preds, batch_numpy[1])
                eval_class(post_result, batch_numpy)

            pbar.update(1)
            total_frame += len(images)
        # Get final metric, eg. acc or hmean.
        metric = eval_class.get_metric()

    pbar.close()
    model.train()
    # Guard against division by zero when no batch was evaluated.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
|
571 |
+
|
572 |
+
|
573 |
+
def update_center(char_center, post_result, preds):
    """Fold one batch into the running per-character feature means.

    Only samples whose decoded text matches the ground truth contribute.
    ``char_center`` maps a character id to ``[mean_feature, sample_count]``
    and is updated in place (and also returned).
    """
    rec_texts, gt_texts = post_result
    feats, logits = preds
    pred_ids = paddle.argmax(logits, axis=-1).numpy()
    feat_arr = feats.numpy()

    for sample_idx in range(len(gt_texts)):
        # Skip misrecognized samples.
        if rec_texts[sample_idx][0] != gt_texts[sample_idx][0]:
            continue
        sample_feat = feat_arr[sample_idx]
        sample_ids = pred_ids[sample_idx]
        for step, char_id in enumerate(sample_ids):
            if char_id in char_center:
                mean, count = char_center[char_id]
                # Incremental mean update.
                char_center[char_id][0] = (mean * count + sample_feat[step]
                                           ) / (count + 1)
                char_center[char_id][1] = count + 1
            else:
                char_center[char_id] = [sample_feat[step], 1]
    return char_center
|
594 |
+
|
595 |
+
|
596 |
+
def get_center(model, eval_dataloader, post_process_class):
    """Compute the mean feature vector per decoded character.

    Runs ``model`` over ``eval_dataloader``, accumulates per-character
    running means via ``update_center`` and strips the sample counts
    before returning ``{char_id: mean_feature}``.
    """
    progress = tqdm(total=len(eval_dataloader), desc='get center:')
    # Skip the last batch on Windows (mirrors the train/eval loops).
    limit = len(eval_dataloader)
    if platform.system() == "Windows":
        limit -= 1
    char_center = dict()
    for step, batch in enumerate(eval_dataloader):
        if step >= limit:
            break
        images = batch[0]
        start = time.time()
        preds = model(images)

        batch = [item.numpy() for item in batch]
        # Decode raw predictions into (text, score) pairs.
        post_result = post_process_class(preds, batch[1])

        # Fold this batch into the running per-character means.
        char_center = update_center(char_center, post_result, preds)
        progress.update(1)

    progress.close()
    # Drop the sample counts, keeping only the mean feature per character.
    for key in char_center.keys():
        char_center[key] = char_center[key][0]
    return char_center
|
620 |
+
|
621 |
+
|
622 |
+
def preprocess(is_train=False):
    """Parse CLI args, load/merge the YAML config and set up device + loggers.

    Args:
        is_train: when True, the merged config is dumped to
            ``save_model_dir/config.yml`` and logging also goes to
            ``save_model_dir/train.log``.

    Returns:
        tuple ``(config, device, logger, log_writer)`` where ``log_writer``
        is a ``Loggers`` aggregate (VisualDL and/or W&B) or ``None``.
    """
    FLAGS = ArgsParser().parse_args()
    # NOTE(review): this local is unused; the same value is re-read below.
    profiler_options = FLAGS.profiler_options
    config = load_config(FLAGS.config)
    # CLI -o overrides take precedence over the YAML file.
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global'].get('use_gpu', False)
    use_xpu = config['Global'].get('use_xpu', False)
    use_npu = config['Global'].get('use_npu', False)
    use_mlu = config['Global'].get('use_mlu', False)

    # Fail fast on algorithms this training entry point does not support.
    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN',
        'Telescope'
    ]

    # Resolve the device string; per-process card id comes from env/dist.
    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    elif use_npu:
        device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
    elif use_mlu:
        device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu, use_npu, use_mlu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    loggers = []

    if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        log_writer = VDLLogger(vdl_writer_path)
        loggers.append(log_writer)
    if ('use_wandb' in config['Global'] and
            config['Global']['use_wandb']) or 'wandb' in config:
        save_dir = config['Global']['save_model_dir']
        wandb_writer_path = "{}/wandb".format(save_dir)
        if "wandb" in config:
            wandb_params = config['wandb']
        else:
            wandb_params = dict()
        wandb_params.update({'save_dir': save_dir})
        log_writer = WandbLogger(**wandb_params, config=config)
        loggers.append(log_writer)
    else:
        log_writer = None
    print_dict(config, logger)

    # The aggregate Loggers wrapper supersedes the individual writers above.
    if loggers:
        log_writer = Loggers(loggers)
    else:
        log_writer = None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer
|
tools/test_hubserving.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
17 |
+
sys.path.append(__dir__)
|
18 |
+
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
19 |
+
|
20 |
+
from ppocr.utils.logging import get_logger
|
21 |
+
logger = get_logger()
|
22 |
+
|
23 |
+
import cv2
|
24 |
+
import numpy as np
|
25 |
+
import time
|
26 |
+
from PIL import Image
|
27 |
+
from ppocr.utils.utility import get_image_file_list
|
28 |
+
from tools.infer.utility import draw_ocr, draw_boxes, str2bool
|
29 |
+
from ppstructure.utility import draw_structure_result
|
30 |
+
from ppstructure.predict_system import to_excel
|
31 |
+
|
32 |
+
import requests
|
33 |
+
import json
|
34 |
+
import base64
|
35 |
+
|
36 |
+
|
37 |
+
def cv2_to_base64(image):
    """Encode raw image bytes as a UTF-8 base64 string for the JSON payload."""
    encoded = base64.b64encode(image)
    return encoded.decode('utf8')
|
39 |
+
|
40 |
+
|
41 |
+
def draw_server_result(image_file, res):
    """Visualize a hub-serving response on top of the source image.

    Depending on the keys present in ``res`` this draws nothing (rec-only
    results), only the detected boxes (det results), or boxes plus texts
    (full OCR system results). Returns an RGB ndarray, or ``None`` when
    the result cannot be visualized.
    """
    bgr = cv2.imread(image_file)
    image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    if len(res) == 0:
        # Nothing detected: return the unmodified image.
        return np.array(image)
    keys = res[0].keys()
    if 'text_region' not in keys:  # for ocr_rec, draw function is invalid
        logger.info("draw function is invalid for ocr_rec!")
        return None
    if 'text' not in keys:  # for ocr_det
        logger.info("draw text boxes only!")
        boxes = np.array([item['text_region'] for item in res])
        return draw_boxes(image, boxes)
    # for ocr_system
    logger.info("draw boxes and texts!")
    boxes = np.array([item['text_region'] for item in res])
    texts = [item['text'] for item in res]
    scores = np.array([item['confidence'] for item in res])
    return draw_ocr(image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
|
72 |
+
|
73 |
+
|
74 |
+
def save_structure_res(res, save_folder, image_file):
    """Persist a structure-analysis server result for one image.

    Creates ``save_folder/<image basename>/`` and writes:
    Table regions to per-bbox ``.xlsx`` files, Figure regions as cropped
    ``.jpg`` files, and every other region's text results as JSON lines
    in ``res.txt``.

    Fix: removed a leftover debug ``print(region['bbox'])`` that polluted
    stdout for every Figure region.
    """
    img = cv2.imread(image_file)
    excel_save_folder = os.path.join(save_folder, os.path.basename(image_file))
    os.makedirs(excel_save_folder, exist_ok=True)
    # save res
    with open(
            os.path.join(excel_save_folder, 'res.txt'), 'w',
            encoding='utf8') as f:
        for region in res:
            if region['type'] == 'Table':
                excel_path = os.path.join(excel_save_folder,
                                          '{}.xlsx'.format(region['bbox']))
                to_excel(region['res'], excel_path)
            elif region['type'] == 'Figure':
                # Crop the figure region out of the original image.
                x1, y1, x2, y2 = region['bbox']
                roi_img = img[y1:y2, x1:x2, :]
                img_path = os.path.join(excel_save_folder,
                                        '{}.jpg'.format(region['bbox']))
                cv2.imwrite(img_path, roi_img)
            else:
                for text_result in region['res']:
                    f.write('{}\n'.format(json.dumps(text_result)))
|
97 |
+
|
98 |
+
|
99 |
+
def main(args):
    """Send every image under ``args.image_dir`` to a hub-serving endpoint.

    Posts each image (base64-encoded in a JSON body) to ``args.server_url``,
    logs the per-image latency and the response, optionally visualizes or
    saves the result depending on the endpoint type, and finally logs the
    average request time.

    Fixes: close the image file handle (was leaked via ``open().read()``),
    detect unreadable images via truthiness (``.read()`` never returns
    ``None``), and avoid a ZeroDivisionError when no image was processed.
    """
    image_file_list = get_image_file_list(args.image_dir)
    headers = {"Content-type": "application/json"}
    cnt = 0
    total_time = 0
    for image_file in image_file_list:
        with open(image_file, 'rb') as fp:
            img = fp.read()
        if not img:
            logger.info("error in loading image:{}".format(image_file))
            continue
        img_name = os.path.basename(image_file)
        # send http request
        starttime = time.time()
        data = {'images': [cv2_to_base64(img)]}
        r = requests.post(
            url=args.server_url, headers=headers, data=json.dumps(data))
        elapse = time.time() - starttime
        total_time += elapse
        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
        res = r.json()["results"][0]
        logger.info(res)

        if args.visualize:
            draw_img = None
            # Endpoint type is inferred from the URL path.
            if 'structure_table' in args.server_url:
                to_excel(res['html'], './{}.xlsx'.format(img_name))
            elif 'structure_system' in args.server_url:
                save_structure_res(res['regions'], args.output, image_file)
            else:
                draw_img = draw_server_result(image_file, res)
            if draw_img is not None:
                if not os.path.exists(args.output):
                    os.makedirs(args.output)
                # draw_server_result returns RGB; flip to BGR for cv2.imwrite.
                cv2.imwrite(
                    os.path.join(args.output, os.path.basename(image_file)),
                    draw_img[:, :, ::-1])
                logger.info("The visualized image saved in {}".format(
                    os.path.join(args.output, os.path.basename(image_file))))
        cnt += 1
        if cnt % 100 == 0:
            logger.info("{} processed".format(cnt))
    if cnt > 0:
        logger.info("avg time cost: {}".format(float(total_time) / cnt))
    else:
        logger.info("no image was processed")
|
142 |
+
|
143 |
+
|
144 |
+
def parse_args():
    """Define and parse the command-line flags for the hub-serving test."""
    import argparse
    cli = argparse.ArgumentParser(description="args for hub serving")
    # Endpoint of the deployed hub-serving module.
    cli.add_argument("--server_url", type=str, required=True)
    # Image file or directory of images to send.
    cli.add_argument("--image_dir", type=str, required=True)
    cli.add_argument("--visualize", type=str2bool, default=False)
    cli.add_argument("--output", type=str, default='./hubserving_result')
    return cli.parse_args()
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == '__main__':
    # Script entry point: parse CLI flags, then run the hub-serving test loop.
    args = parse_args()
    main(args)
|
tools/train.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import os
|
20 |
+
import sys
|
21 |
+
|
22 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
23 |
+
sys.path.append(__dir__)
|
24 |
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
25 |
+
|
26 |
+
import yaml
|
27 |
+
import paddle
|
28 |
+
import paddle.distributed as dist
|
29 |
+
|
30 |
+
from ppocr.data import build_dataloader
|
31 |
+
from ppocr.modeling.architectures import build_model
|
32 |
+
from ppocr.losses import build_loss
|
33 |
+
from ppocr.optimizer import build_optimizer
|
34 |
+
from ppocr.postprocess import build_post_process
|
35 |
+
from ppocr.metrics import build_metric
|
36 |
+
from ppocr.utils.save_load import load_model
|
37 |
+
from ppocr.utils.utility import set_seed
|
38 |
+
from ppocr.modeling.architectures import apply_to_static
|
39 |
+
import tools.program as program
|
40 |
+
|
41 |
+
dist.get_world_size()
|
42 |
+
|
43 |
+
|
44 |
+
def main(config, device, logger, vdl_writer):
    """Build all training components from ``config`` and start training.

    Constructs the dataloaders, model, loss, optimizer, metric and optional
    AMP scaler, loads any pretrained weights, wraps the model for
    distributed training, and hands everything to ``program.train``.
    Mutates ``config`` in place to wire recognition-head output channels.
    """
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n" +
            "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            +
            "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm: derive head output channels from the charset size
    # of the post-processor, patching the config before model construction.
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    # update SARLoss params
                    assert list(config['Loss']['loss_config_list'][-1].keys())[
                        0] == 'DistillationSARLoss'
                    config['Loss']['loss_config_list'][-1][
                        'DistillationSARLoss']['ignore_index'] = char_num + 1
                    out_channels_list = {}
                    out_channels_list['CTCLabelDecode'] = char_num
                    # SAR head needs two extra symbols beyond the charset.
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            # update SARLoss params
            assert list(config['Loss']['loss_config_list'][1].keys())[
                0] == 'SARLoss'
            if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
                config['Loss']['loss_config_list'][1]['SARLoss'] = {
                    'ignore_index': char_num + 1
                }
            else:
                config['Loss']['loss_config_list'][1]['SARLoss'][
                    'ignore_index'] = char_num + 1
            out_channels_list = {}
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

        if config['PostProcess']['name'] == 'SARLabelDecode':  # for SAR model
            config['Loss']['ignore_index'] = char_num - 1

    model = build_model(config['Architecture'])

    use_sync_bn = config["Global"].get("use_sync_bn", False)
    if use_sync_bn:
        model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info('convert_sync_batchnorm')

    model = apply_to_static(model, config, logger)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        model=model)

    # build metric
    eval_class = build_metric(config['Metric'])

    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    # Optional automatic mixed precision setup.
    use_amp = config["Global"].get("use_amp", False)
    amp_level = config["Global"].get("amp_level", 'O2')
    amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
        if paddle.is_compiled_with_cuda():
            AMP_RELATED_FLAGS_SETTING.update({
                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
            })
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
        if amp_level == "O2":
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level=amp_level,
                master_weight=True)
    else:
        scaler = None

    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer,
                                     config['Architecture']["model_type"])

    if config['Global']['distributed']:
        model = paddle.DataParallel(model)
    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
                  amp_level, amp_custom_black_list)
|
184 |
+
|
185 |
+
|
186 |
+
def test_reader(config, device, logger):
    """Smoke-test the train dataloader: iterate it once and log throughput."""
    data_loader = build_dataloader(config, 'Train', device, logger)
    import time
    tick = time.time()
    seen = 0
    try:
        for batch in data_loader():
            seen += 1
            if seen % 1 == 0:
                elapsed = time.time() - tick
                tick = time.time()
                logger.info("reader: {}, {}, {}".format(
                    seen, len(batch[0]), elapsed))
    except Exception as e:
        # Report the failure but still log how far we got.
        logger.info(e)
    logger.info("finish reader: {}, Success!".format(seen))
|
202 |
+
|
203 |
+
|
204 |
+
if __name__ == '__main__':
    # Entry point: parse CLI/config, seed RNGs, then launch training.
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
    set_seed(seed)
    main(config, device, logger, vdl_writer)
    # test_reader(config, device, logger)
|