SWHL committed on
Commit 38a5dfd · 1 Parent(s): 809becd

First commit

.gitignore ADDED
@@ -0,0 +1,5 @@
+ .vscode
+
+ *.pyc
+
+ __pycache__/
app.py ADDED
@@ -0,0 +1,43 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ import os
+ os.system('pip install -r requirements.txt')
+
+ import cv2
+ import gradio as gr
+
+ from ctrnet_infer import CTRNetInfer
+
+
+ def inference(img):
+     img_path = img.name
+     img = cv2.imread(img_path)
+     pred = ctrnet(img)
+     pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
+     return pred
+
+
+ model_path = 'models/CTRNet_G.onnx'
+ ctrnet = CTRNetInfer(model_path)
+
+ title = 'CTRNet Demo'
+ description = '''This is the demo for the paper “Don't Forget Me: Accurate Background Recovery for Text Removal via Modeling Local-Global Context”. Github Repo: https://github.com/lcy0604/CTRNet'''
+ css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
+ examples = [['images/1.jpg'], ['images/2.jpg'], ['images/4.jpg']]
+
+ gr.Interface(
+     inference,
+     inputs=[
+         gr.inputs.Image(type='file', label='Input'),
+     ],
+     outputs=[
+         gr.outputs.Image(type='file', label='Output_image'),
+     ],
+     title=title,
+     description=description,
+     examples=examples,
+     css=css,
+     allow_flagging='never',
+     enable_queue=True
+ ).launch(debug=True, enable_queue=True)
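
Note: inference above only reads the uploaded file's .name attribute (Gradio's type='file' input hands it a tempfile-like object), so it can be exercised without the UI. A minimal sketch — the SimpleNamespace stand-in is purely illustrative, not something Gradio provides:

    # Hypothetical stand-in for the tempfile-like object Gradio passes
    # to inference() when the input uses type='file'; only .name is consumed.
    from types import SimpleNamespace

    fake_upload = SimpleNamespace(name='images/1.jpg')
    result = inference(fake_upload)  # RGB ndarray, ready for display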
ctrnet_infer.py ADDED
@@ -0,0 +1,201 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ import copy
+ import time
+
+ import cv2
+ import numpy as np
+ import pyclipper
+ from onnxruntime import InferenceSession
+ from shapely.geometry import Polygon
+
+ from rapid_ch_det import TextDetector
+
+
+ class SimpleDataset():
+     def __call__(self, img: np.ndarray, bboxes: np.ndarray):
+         '''
+         bboxes: (N, 4, 2)
+         '''
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         gt_instance = np.zeros(img.shape[:2], dtype='uint8')
+         for i in range(len(bboxes)):
+             cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1)
+         gt_text = gt_instance.copy()
+         gt_text[gt_text > 0] = 1
+         gt_text = gt_text[None, None, ...].astype(np.float32)
+
+         canvas, shrink_mask, mask_ori = self.get_seg_map(img, bboxes)
+         soft_mask = canvas + mask_ori
+         index_mask = np.where(soft_mask > 1)
+         soft_mask[index_mask] = 1
+         soft_mask = soft_mask[None, None, ...].astype(np.float32)
+
+         img = np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0
+         img = img[None, ...]
+         structure_im = copy.deepcopy(img)
+         return img, structure_im, gt_text, soft_mask
+
+     def draw_border_map(self, polygon, canvas, mask_ori, mask):
+         polygon = np.array(polygon)
+         assert polygon.ndim == 2
+         assert polygon.shape[1] == 2
+
+         ### shrink box ###
+         polygon_shape = Polygon(polygon)
+         distance = polygon_shape.area * \
+             (1 - np.power(0.95, 2)) / polygon_shape.length
+         subject = [tuple(l) for l in polygon]
+         padding = pyclipper.PyclipperOffset()
+         padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+         padded_polygon = np.array(padding.Execute(-distance)[0])
+         cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+         ### shrink box ###
+
+         cv2.fillPoly(mask_ori, [polygon.astype(np.int32)], 1.0)
+
+         polygon = padded_polygon
+         polygon_shape = Polygon(padded_polygon)
+         distance = polygon_shape.area * \
+             (1 - np.power(0.4, 2)) / polygon_shape.length
+         subject = [tuple(l) for l in polygon]
+         padding = pyclipper.PyclipperOffset()
+         padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+         padded_polygon = np.array(padding.Execute(distance)[0])
+
+         xmin = padded_polygon[:, 0].min()
+         xmax = padded_polygon[:, 0].max()
+         ymin = padded_polygon[:, 1].min()
+         ymax = padded_polygon[:, 1].max()
+         width = xmax - xmin + 1
+         height = ymax - ymin + 1
+
+         polygon[:, 0] = polygon[:, 0] - xmin
+         polygon[:, 1] = polygon[:, 1] - ymin
+
+         xs = np.broadcast_to(
+             np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
+         ys = np.broadcast_to(
+             np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
+
+         distance_map = np.zeros(
+             (polygon.shape[0], height, width), dtype=np.float32)
+         for i in range(polygon.shape[0]):
+             j = (i + 1) % polygon.shape[0]
+             absolute_distance = self.compute_distance(xs, ys, polygon[i], polygon[j])
+             distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
+         distance_map = distance_map.min(axis=0)
+
+         xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
+         xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
+         ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
+         ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
+         canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
+             1 - distance_map[
+                 ymin_valid-ymin:ymax_valid-ymax+height,
+                 xmin_valid-xmin:xmax_valid-xmax+width],
+             canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
+
+     @staticmethod
+     def compute_distance(xs, ys, point_1, point_2):
+         '''
+         compute the distance from a point to a line segment
+         ys: coordinates in the first axis
+         xs: coordinates in the second axis
+         point_1, point_2: (x, y), the ends of the line segment
+         '''
+         square_distance_1 = np.square(
+             xs - point_1[0]) + np.square(ys - point_1[1])
+         square_distance_2 = np.square(
+             xs - point_2[0]) + np.square(ys - point_2[1])
+         square_distance = np.square(
+             point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])
+
+         cosin = (square_distance - square_distance_1 - square_distance_2) / \
+             (2 * np.sqrt(square_distance_1 * square_distance_2) + 1e-50)
+         square_sin = 1 - np.square(cosin)
+         square_sin = np.nan_to_num(square_sin)
+         result = np.sqrt(square_distance_1 * square_distance_2 *
+                          square_sin / square_distance)
+
+         result[cosin < 0] = np.sqrt(np.fmin(
+             square_distance_1, square_distance_2))[cosin < 0]
+         return result
+
+     def get_seg_map(self, img, label):
+         canvas = np.zeros(img.shape[:2], dtype=np.float32)
+         mask = np.zeros(img.shape[:2], dtype=np.float32)
+         mask_ori = np.zeros(img.shape[:2], dtype=np.float32)
+         polygons = label
+
+         for i in range(len(polygons)):
+             self.draw_border_map(polygons[i], canvas, mask_ori, mask=mask)
+         return canvas, mask, mask_ori
+
+
+ class CTRNetInfer():
+     def __init__(self, model_path) -> None:
+         self.session = InferenceSession(model_path,
+                                         providers=['CPUExecutionProvider'])
+         self.dataset = SimpleDataset()
+         self.text_det = TextDetector()
+         self.input_shape = (512, 512)
+
+     def __call__(self, ori_img):
+         ori_img_shape = ori_img.shape[:2]
+         bboxes = self.text_det(ori_img)[0].astype(np.int64)
+
+         # resize the image to 512x512
+         resize_img = cv2.resize(ori_img, self.input_shape,
+                                 interpolation=cv2.INTER_LINEAR)
+         resize_bboxes = self.get_resized_points(bboxes,
+                                                 ori_img_shape,
+                                                 self.input_shape)
+
+         img, structure_im, gt_text, soft_mask = self.dataset(
+             resize_img, resize_bboxes)
+         input_dict = {
+             'input': img,
+             'gt_text': gt_text,
+             'soft_mask': soft_mask,
+             'structure_im': structure_im
+         }
+         prediction = self.session.run(None, input_dict)[3]
+
+         withMask_prediction = prediction * soft_mask + img * (1 - soft_mask)
+         withMask_prediction = np.transpose(withMask_prediction, (0, 2, 3, 1)) * 255
+         withMask_prediction = withMask_prediction.squeeze().astype(np.uint8)
+         withMask_prediction = cv2.cvtColor(withMask_prediction,
+                                            cv2.COLOR_BGR2RGB)
+         ori_pred = cv2.resize(withMask_prediction, ori_img_shape[::-1],
+                               interpolation=cv2.INTER_LINEAR)
+         return ori_pred
+
+     @staticmethod
+     def get_resized_points(cur_points, cur_shape, new_shape):
+         cur_points = np.array(cur_points)
+
+         # x coordinates scale with the width ratio, y with the height ratio
+         ratio_x = cur_shape[1] / new_shape[1]
+         ratio_y = cur_shape[0] / new_shape[0]
+         cur_points[:, :, 0] = cur_points[:, :, 0] / ratio_x
+         cur_points[:, :, 1] = cur_points[:, :, 1] / ratio_y
+         return cur_points.astype(np.int64)
+
+
+ if __name__ == '__main__':
+     model_path = 'CTRNet_G.onnx'
+     ctrnet = CTRNetInfer(model_path)
+
+     img_path = 'images/1.jpg'
+     ori_img = cv2.imread(img_path)
+
+     s = time.time()
+     pred = ctrnet(ori_img)
+     print(f'elapsed: {time.time() - s}')
+
+     cv2.imwrite('pred_result.jpg', pred)
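
Note: the final composition in CTRNetInfer.__call__ keeps original pixels wherever soft_mask is 0 and takes network output wherever it is 1. A toy illustration of that blend, with the NCHW shapes shrunk to 1x1x2x2 for readability:

    import numpy as np

    prediction = np.full((1, 1, 2, 2), 0.9, dtype=np.float32)  # network output
    img = np.full((1, 1, 2, 2), 0.2, dtype=np.float32)         # original pixels
    soft_mask = np.array([[[[1, 0], [0, 1]]]], dtype=np.float32)

    blended = prediction * soft_mask + img * (1 - soft_mask)
    print(blended.squeeze())  # [[0.9 0.2]
                              #  [0.2 0.9]]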
images/1.jpg ADDED
images/2.jpg ADDED
images/4.jpg ADDED
models/CTRNet_G.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15d46cec531574c5afef5f27f287f0ccf62a911749089f7cfcbf760226a3eda8
+ size 842447752
rapid_ch_det/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ from .text_detect import TextDetector
rapid_ch_det/config.yaml ADDED
@@ -0,0 +1,29 @@
+ model_path: models/ch_PP-OCRv3_det_infer.onnx
+
+ use_cuda: false
+ CUDAExecutionProvider:
+   device_id: 0
+   arena_extend_strategy: kNextPowerOfTwo
+   cudnn_conv_algo_search: EXHAUSTIVE
+   do_copy_in_default_stream: true
+
+ pre_process:
+   DetResizeForTest:
+     limit_side_len: 736
+     limit_type: min
+   NormalizeImage:
+     std: [0.229, 0.224, 0.225]
+     mean: [0.485, 0.456, 0.406]
+     scale: 1./255.
+     order: hwc
+   ToCHWImage:
+   KeepKeys:
+     keep_keys: ['image', 'shape']
+
+ post_process:
+   thresh: 0.3
+   box_thresh: 0.5
+   max_candidates: 1000
+   unclip_ratio: 1.6
+   use_dilation: true
+   score_mode: "fast"
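
Note: this file is consumed by rapid_ch_det/text_detect.py: read_yaml loads it, create_operators instantiates each pre_process entry by class name, and the post_process block is unpacked into DBPostProcess. A minimal sketch of that flow, assuming the committed layout:

    from rapid_ch_det.utils import read_yaml, create_operators, DBPostProcess

    config = read_yaml('rapid_ch_det/config.yaml')
    pre_ops = create_operators(config['pre_process'])  # DetResizeForTest, NormalizeImage, ToCHWImage, KeepKeys
    post_op = DBPostProcess(**config['post_process'])  # thresh=0.3, box_thresh=0.5, ...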
rapid_ch_det/models/ch_PP-OCRv3_det_infer.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3439588c030faea393a54515f51e983d8e155b19a2e8aba7891934c1cf0de526
+ size 2432880
rapid_ch_det/text_detect.py ADDED
@@ -0,0 +1,134 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ import argparse
+ import time
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+
+ try:
+     from .utils import (DBPostProcess, create_operators,
+                         transform, read_yaml, OrtInferSession)
+ except ImportError:
+     from utils import (DBPostProcess, create_operators,
+                        transform, read_yaml, OrtInferSession)
+
+ root_dir = Path(__file__).resolve().parent
+
+
+ class TextDetector():
+     def __init__(self, config=str(root_dir / 'config.yaml')):
+         if isinstance(config, str):
+             config = read_yaml(config)
+         config['model_path'] = str(root_dir / config['model_path'])
+
+         self.preprocess_op = create_operators(config['pre_process'])
+         self.postprocess_op = DBPostProcess(**config['post_process'])
+
+         session_instance = OrtInferSession(config)
+         self.session = session_instance.session
+         self.input_name = session_instance.get_input_name()
+
+     def __call__(self, img):
+         if img is None:
+             raise ValueError('img is None')
+
+         ori_im_shape = img.shape[:2]
+
+         data = {'image': img}
+         data = transform(data, self.preprocess_op)
+         img, shape_list = data
+         if img is None:
+             return None, 0
+
+         img = np.expand_dims(img, axis=0).astype(np.float32)
+         shape_list = np.expand_dims(shape_list, axis=0)
+
+         starttime = time.time()
+         preds = self.session.run(None, {self.input_name: img})
+
+         post_result = self.postprocess_op(preds[0], shape_list)
+
+         dt_boxes = post_result[0]['points']
+         dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape)
+         elapse = time.time() - starttime
+         return dt_boxes, elapse
+
+     def order_points_clockwise(self, pts):
+         """
+         reference from:
+         https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
+         sort the points based on their x-coordinates
+         """
+         xSorted = pts[np.argsort(pts[:, 0]), :]
+
+         # grab the left-most and right-most points from the sorted
+         # x-coordinate points
+         leftMost = xSorted[:2, :]
+         rightMost = xSorted[2:, :]
+
+         # now, sort the left-most coordinates according to their
+         # y-coordinates so we can grab the top-left and bottom-left
+         # points, respectively
+         leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
+         (tl, bl) = leftMost
+
+         rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
+         (tr, br) = rightMost
+
+         rect = np.array([tl, tr, br, bl], dtype="float32")
+         return rect
+
+     def clip_det_res(self, points, img_height, img_width):
+         for pno in range(points.shape[0]):
+             points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+             points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+         return points
+
+     def filter_tag_det_res(self, dt_boxes, image_shape):
+         img_height, img_width = image_shape[:2]
+         dt_boxes_new = []
+         for box in dt_boxes:
+             box = self.order_points_clockwise(box)
+             box = self.clip_det_res(box, img_height, img_width)
+             rect_width = int(np.linalg.norm(box[0] - box[1]))
+             rect_height = int(np.linalg.norm(box[0] - box[3]))
+             if rect_width <= 3 or rect_height <= 3:
+                 continue
+             dt_boxes_new.append(box)
+         dt_boxes = np.array(dt_boxes_new)
+         return dt_boxes
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config_path', type=str, default='config.yaml')
+     parser.add_argument('--image_path', type=str, default=None)
+     args = parser.parse_args()
+
+     config = read_yaml(args.config_path)
+
+     text_detector = TextDetector(config)
+
+     img = cv2.imread(args.image_path)
+     dt_boxes, elapse = text_detector(img)
+
+     from utils import draw_text_det_res
+     src_im = draw_text_det_res(dt_boxes, args.image_path)
+     cv2.imwrite('det_results.jpg', src_im)
+     print('The det_results.jpg has been saved in the current directory.')
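
Note: a quick illustration of order_points_clockwise above with a toy quadrilateral (arbitrary coordinates; constructing TextDetector() loads the committed ONNX model, so the model files must be present):

    import numpy as np

    det = TextDetector()  # uses rapid_ch_det/config.yaml by default
    pts = np.array([[10., 80.], [90., 10.], [10., 10.], [90., 80.]])
    print(det.order_points_clockwise(pts))
    # [[10. 10.]   top-left
    #  [90. 10.]   top-right
    #  [90. 80.]   bottom-right
    #  [10. 80.]]  bottom-left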
rapid_ch_det/utils.py ADDED
@@ -0,0 +1,461 @@
+ """
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ # -*- encoding: utf-8 -*-
+ # @Author: SWHL
+ # @Contact: liekkaskono@163.com
+ import sys
+ import warnings
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ import pyclipper
+ import six
+ import yaml
+ from onnxruntime import (GraphOptimizationLevel, InferenceSession,
+                          SessionOptions, get_available_providers, get_device)
+ from shapely.geometry import Polygon
+
+ root_dir = Path(__file__).resolve().parent.parent
+
+
+ class OrtInferSession():
+     def __init__(self, config):
+         sess_opt = SessionOptions()
+         sess_opt.log_severity_level = 4
+         sess_opt.enable_cpu_mem_arena = False
+         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+         cuda_ep = 'CUDAExecutionProvider'
+         cpu_ep = 'CPUExecutionProvider'
+         cpu_provider_options = {
+             "arena_extend_strategy": "kSameAsRequested",
+         }
+
+         EP_list = []
+         if config['use_cuda'] and get_device() == 'GPU' \
+                 and cuda_ep in get_available_providers():
+             EP_list = [(cuda_ep, config[cuda_ep])]
+         EP_list.append((cpu_ep, cpu_provider_options))
+
+         config['model_path'] = str(root_dir / config['model_path'])
+         self._verify_model(config['model_path'])
+         self.session = InferenceSession(config['model_path'],
+                                         sess_options=sess_opt,
+                                         providers=EP_list)
+
+         if config['use_cuda'] and cuda_ep not in self.session.get_providers():
+             warnings.warn(f'{cuda_ep} is not available in the current env, so inference automatically falls back to {cpu_ep}.\n'
+                           'Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; '
+                           'you can check their relations on the official web site: '
+                           'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
+                           RuntimeWarning)
+
+     def get_input_name(self, input_idx=0):
+         return self.session.get_inputs()[input_idx].name
+
+     def get_output_name(self, output_idx=0):
+         return self.session.get_outputs()[output_idx].name
+
+     @staticmethod
+     def _verify_model(model_path):
+         model_path = Path(model_path)
+         if not model_path.exists():
+             raise FileNotFoundError(f'{model_path} does not exist.')
+         if not model_path.is_file():
+             raise FileExistsError(f'{model_path} is not a file.')
+
+
+ def read_yaml(yaml_path):
+     with open(yaml_path, 'rb') as f:
+         data = yaml.load(f, Loader=yaml.Loader)
+     return data
+
+
+ class DecodeImage():
+     """decode image"""
+
+     def __init__(self, img_mode='RGB', channel_first=False):
+         self.img_mode = img_mode
+         self.channel_first = channel_first
+
+     def __call__(self, data):
+         img = data['image']
+         if six.PY2:
+             assert type(img) is str and len(img) > 0, "invalid input 'img' in DecodeImage"
+         else:
+             assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage"
+
+         img = np.frombuffer(img, dtype='uint8')
+         img = cv2.imdecode(img, 1)
+         if img is None:
+             return None
+
+         if self.img_mode == 'GRAY':
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+         elif self.img_mode == 'RGB':
+             assert img.shape[2] == 3, f'invalid shape of image[{img.shape}]'
+             img = img[:, :, ::-1]
+
+         if self.channel_first:
+             img = img.transpose((2, 0, 1))
+         data['image'] = img
+         return data
+
+
+ class NormalizeImage():
+     """normalize image, e.g. subtract mean, divide by std"""
+
+     def __init__(self, scale=None, mean=None, std=None, order='chw'):
+         if isinstance(scale, str):
+             scale = eval(scale)
+         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+         mean = mean if mean is not None else [0.485, 0.456, 0.406]
+         std = std if std is not None else [0.229, 0.224, 0.225]
+
+         shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+         self.mean = np.array(mean).reshape(shape).astype('float32')
+         self.std = np.array(std).reshape(shape).astype('float32')
+
+     def __call__(self, data):
+         img = np.array(data['image']).astype(np.float32)
+         data['image'] = (img * self.scale - self.mean) / self.std
+         return data
+
+
+ class ToCHWImage():
+     """convert hwc image to chw image"""
+     def __init__(self):
+         pass
+
+     def __call__(self, data):
+         img = np.array(data['image'])
+         data['image'] = img.transpose((2, 0, 1))
+         return data
+
+
+ class KeepKeys():
+     def __init__(self, keep_keys):
+         self.keep_keys = keep_keys
+
+     def __call__(self, data):
+         data_list = []
+         for key in self.keep_keys:
+             data_list.append(data[key])
+         return data_list
+
+
+ class DetResizeForTest():
+     def __init__(self, **kwargs):
+         super(DetResizeForTest, self).__init__()
+         self.resize_type = 0
+         if 'image_shape' in kwargs:
+             self.image_shape = kwargs['image_shape']
+             self.resize_type = 1
+         elif 'limit_side_len' in kwargs:
+             self.limit_side_len = kwargs.get('limit_side_len', 736)
+             self.limit_type = kwargs.get('limit_type', 'min')
+
+         if 'resize_long' in kwargs:
+             self.resize_type = 2
+             self.resize_long = kwargs.get('resize_long', 960)
+         else:
+             self.limit_side_len = kwargs.get('limit_side_len', 736)
+             self.limit_type = kwargs.get('limit_type', 'min')
+
+     def __call__(self, data):
+         img = data['image']
+         src_h, src_w = img.shape[:2]
+
+         if self.resize_type == 0:
+             img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+         elif self.resize_type == 2:
+             img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+         else:
+             img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+         data['image'] = img
+         data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+         return data
+
+     def resize_image_type1(self, img):
+         resize_h, resize_w = self.image_shape
+         ori_h, ori_w = img.shape[:2]  # (h, w, c)
+         ratio_h = float(resize_h) / ori_h
+         ratio_w = float(resize_w) / ori_w
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         return img, [ratio_h, ratio_w]
+
+     def resize_image_type0(self, img):
+         """
+         resize image to a size multiple of 32, which is required by the network
+         args:
+             img(array): array with shape [h, w, c]
+         return(tuple):
+             img, (ratio_h, ratio_w)
+         """
+         limit_side_len = self.limit_side_len
+         h, w = img.shape[:2]
+
+         # limit the max side
+         if self.limit_type == 'max':
+             if max(h, w) > limit_side_len:
+                 if h > w:
+                     ratio = float(limit_side_len) / h
+                 else:
+                     ratio = float(limit_side_len) / w
+             else:
+                 ratio = 1.
+         else:
+             if min(h, w) < limit_side_len:
+                 if h < w:
+                     ratio = float(limit_side_len) / h
+                 else:
+                     ratio = float(limit_side_len) / w
+             else:
+                 ratio = 1.
+         resize_h = int(h * ratio)
+         resize_w = int(w * ratio)
+
+         resize_h = int(round(resize_h / 32) * 32)
+         resize_w = int(round(resize_w / 32) * 32)
+
+         try:
+             if int(resize_w) <= 0 or int(resize_h) <= 0:
+                 return None, (None, None)
+             img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         except cv2.error:
+             print(img.shape, resize_w, resize_h)
+             sys.exit(0)
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+         return img, [ratio_h, ratio_w]
+
+     def resize_image_type2(self, img):
+         h, w = img.shape[:2]
+
+         resize_w = w
+         resize_h = h
+
+         # Fix the longer side
+         if resize_h > resize_w:
+             ratio = float(self.resize_long) / resize_h
+         else:
+             ratio = float(self.resize_long) / resize_w
+
+         resize_h = int(resize_h * ratio)
+         resize_w = int(resize_w * ratio)
+
+         max_stride = 128
+         resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+         resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+         img = cv2.resize(img, (int(resize_w), int(resize_h)))
+         ratio_h = resize_h / float(h)
+         ratio_w = resize_w / float(w)
+
+         return img, [ratio_h, ratio_w]
+
+
+ def transform(data, ops=None):
+     """transform"""
+     if ops is None:
+         ops = []
+
+     for op in ops:
+         data = op(data)
+         if data is None:
+             return None
+     return data
+
+
+ def create_operators(op_param_dict):
+     """
+     create operators based on the config
+     """
+     ops = []
+     for op_name, param in op_param_dict.items():
+         if param is None:
+             param = {}
+         op = eval(op_name)(**param)
+         ops.append(op)
+     return ops
+
+
+ def draw_text_det_res(dt_boxes, img_path):
+     src_im = cv2.imread(img_path)
+     for box in dt_boxes:
+         box = np.array(box).astype(np.int32).reshape(-1, 2)
+         cv2.polylines(src_im, [box], True,
+                       color=(255, 255, 0), thickness=2)
+     return src_im
+
+
+ class DBPostProcess():
+     """The post process for Differentiable Binarization (DB)."""
+
+     def __init__(self,
+                  thresh=0.3,
+                  box_thresh=0.7,
+                  max_candidates=1000,
+                  unclip_ratio=2.0,
+                  score_mode="fast",
+                  use_dilation=False):
+         self.thresh = thresh
+         self.box_thresh = box_thresh
+         self.max_candidates = max_candidates
+         self.unclip_ratio = unclip_ratio
+         self.min_size = 3
+         self.score_mode = score_mode
+
+         if use_dilation:
+             self.dilation_kernel = np.array([[1, 1], [1, 1]])
+         else:
+             self.dilation_kernel = None
+
+     def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+         '''
+         _bitmap: single map with shape (1, H, W),
+             whose values are binarized as {0, 1}
+         '''
+         bitmap = _bitmap
+         height, width = bitmap.shape
+
+         outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
+                                 cv2.CHAIN_APPROX_SIMPLE)
+         if len(outs) == 3:
+             img, contours, _ = outs[0], outs[1], outs[2]
+         elif len(outs) == 2:
+             contours, _ = outs[0], outs[1]
+
+         num_contours = min(len(contours), self.max_candidates)
+
+         boxes = []
+         scores = []
+         for index in range(num_contours):
+             contour = contours[index]
+             points, sside = self.get_mini_boxes(contour)
+             if sside < self.min_size:
+                 continue
+             points = np.array(points)
+             if self.score_mode == "fast":
+                 score = self.box_score_fast(pred, points.reshape(-1, 2))
+             else:
+                 score = self.box_score_slow(pred, contour)
+             if self.box_thresh > score:
+                 continue
+
+             box = self.unclip(points).reshape(-1, 1, 2)
+             box, sside = self.get_mini_boxes(box)
+             if sside < self.min_size + 2:
+                 continue
+             box = np.array(box)
+
+             box[:, 0] = np.clip(
+                 np.round(box[:, 0] / width * dest_width), 0, dest_width)
+             box[:, 1] = np.clip(
+                 np.round(box[:, 1] / height * dest_height), 0, dest_height)
+             boxes.append(box.astype(np.int16))
+             scores.append(score)
+         return np.array(boxes, dtype=np.int16), scores
+
+     def unclip(self, box):
+         unclip_ratio = self.unclip_ratio
+         poly = Polygon(box)
+         distance = poly.area * unclip_ratio / poly.length
+         offset = pyclipper.PyclipperOffset()
+         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+         expanded = np.array(offset.Execute(distance))
+         return expanded
+
+     def get_mini_boxes(self, contour):
+         bounding_box = cv2.minAreaRect(contour)
+         points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+         index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+         if points[1][1] > points[0][1]:
+             index_1 = 0
+             index_4 = 1
+         else:
+             index_1 = 1
+             index_4 = 0
+         if points[3][1] > points[2][1]:
+             index_2 = 2
+             index_3 = 3
+         else:
+             index_2 = 3
+             index_3 = 2
+
+         box = [
+             points[index_1], points[index_2], points[index_3], points[index_4]
+         ]
+         return box, min(bounding_box[1])
+
+     def box_score_fast(self, bitmap, _box):
+         h, w = bitmap.shape[:2]
+         box = _box.copy()
+         xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
+         xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
+         ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
+         ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
+
+         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+         box[:, 0] = box[:, 0] - xmin
+         box[:, 1] = box[:, 1] - ymin
+         cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+     def box_score_slow(self, bitmap, contour):
+         '''
+         box_score_slow: use the polygon mean score as the box score
+         '''
+         h, w = bitmap.shape[:2]
+         contour = contour.copy()
+         contour = np.reshape(contour, (-1, 2))
+
+         xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+         xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+         ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+         ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+
+         contour[:, 0] = contour[:, 0] - xmin
+         contour[:, 1] = contour[:, 1] - ymin
+
+         cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
+         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
+     def __call__(self, pred, shape_list):
+         pred = pred[:, 0, :, :]
+         segmentation = pred > self.thresh
+
+         boxes_batch = []
+         for batch_index in range(pred.shape[0]):
+             src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
+             if self.dilation_kernel is not None:
+                 mask = cv2.dilate(
+                     np.array(segmentation[batch_index]).astype(np.uint8),
+                     self.dilation_kernel)
+             else:
+                 mask = segmentation[batch_index]
+             boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
+                                                    src_w, src_h)
+
+             boxes_batch.append({'points': boxes})
+         return boxes_batch
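
Note: the preprocessing operators above can be sanity-checked standalone; a sketch with a synthetic image (values arbitrary). Since create_operators resolves class names inside this module, the same op dict that config.yaml encodes works directly:

    import numpy as np
    from rapid_ch_det.utils import create_operators, transform

    ops = create_operators({
        'DetResizeForTest': {'limit_side_len': 736, 'limit_type': 'min'},
        'NormalizeImage': {'std': [0.229, 0.224, 0.225],
                           'mean': [0.485, 0.456, 0.406],
                           'scale': '1./255.', 'order': 'hwc'},
        'ToCHWImage': None,
        'KeepKeys': {'keep_keys': ['image', 'shape']},
    })
    img, shape = transform({'image': np.zeros((100, 200, 3), dtype=np.uint8)}, ops)
    print(img.shape)  # (3, 736, 1472): min side scaled to 736, both sides rounded to a multiple of 32
    print(shape)      # [src_h, src_w, ratio_h, ratio_w] = [100, 200, 7.36, 7.36]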
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ numpy==1.21.6
+ onnxruntime>=1.10.0
+ opencv_python
+ pyclipper>=1.2.1
+ Shapely
+ six