MikkoLipsanen committed on
Commit
f731714
·
verified ·
1 Parent(s): aca3760

Create onnx_text_recognition.py

Browse files
Files changed (1) hide show
  1. onnx_text_recognition.py +113 -0
onnx_text_recognition.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from optimum.onnxruntime import ORTModelForVision2Seq
2
+ from transformers import TrOCRProcessor
3
+ import numpy as np
4
+ import onnxruntime
5
+ import math
6
+ import cv2
7
+ import os
8
+
9
class TextRecognition:
    """Recognise the text of cropped line images with an ONNX TrOCR model.

    The processor and the model are loaded once in ``__init__``.  Loading
    is best-effort: if either fails, the error is printed and the matching
    attribute (``processor`` / ``recognition_model``) is left as ``None``.
    """

    def __init__(self,
                 processor_path,
                 model_path,
                 device='cpu',
                 half_precision=False,
                 line_threshold=120):
        """Store the configuration and load the processor and the model.

        Args:
            processor_path: Location passed to ``TrOCRProcessor.from_pretrained``.
            model_path: Location passed to ``ORTModelForVision2Seq.from_pretrained``.
            device: Device the pixel values are moved to before generation.
            half_precision: Stored on the instance; not read by this class.
            line_threshold: Maximum number of line images sent to the model
                in a single ``generate`` call.
        """
        self.device = device
        self.half_precision = half_precision
        self.line_threshold = line_threshold
        self.processor_path = processor_path
        self.model_path = model_path
        self.processor = self.init_processor()
        self.recognition_model = self.init_recognition_model()

    def init_processor(self):
        """Load the TrOCR processor.

        Returns:
            The processor, or ``None`` if loading failed (the error is printed).
        """
        try:
            return TrOCRProcessor.from_pretrained(self.processor_path)
        except Exception as e:
            print('Failed to initialize processor: %s' % e)
            return None

    def init_recognition_model(self):
        """Load the ONNX text recognition model.

        Returns:
            The model, or ``None`` if loading failed (the error is printed).
        """
        # NOTE(review): the model is loaded with the library defaults; no
        # custom session options or execution provider are passed.
        try:
            return ORTModelForVision2Seq.from_pretrained(self.model_path)
        except Exception as e:
            print('Failed to load the text recognition model: %s' % e)
            return None

    def crop_line(self, image, polygon, height, width):
        """Crop one text line and paint everything outside the polygon white.

        Args:
            image: Page image as an array indexable as ``image[y, x]``.
            polygon: Iterable of ``(x, y)`` points outlining the line.
            height: Height of the mask to draw the polygon on.
            width: Width of the mask to draw the polygon on.

        Returns:
            The bounding-rectangle crop of ``image`` in which pixels outside
            the polygon are set to 255.
        """
        # OpenCV contour functions work on 32-bit integer points.
        poly = np.array([[int(pt[0]), int(pt[1])] for pt in polygon],
                        dtype=np.int32)
        mask = np.zeros([height, width], dtype=np.uint8)
        cv2.drawContours(mask, [poly], -1, (255, 255, 255), -1, cv2.LINE_AA)

        x, y, w, h = cv2.boundingRect(poly)
        cropped = image[y: y + h, x: x + w]
        mask_crop = mask[y: y + h, x: x + w]

        # Image content inside the polygon, zero elsewhere.
        inside = cv2.bitwise_and(cropped, cropped, mask=mask_crop)

        # White canvas that is zeroed inside the polygon, so the sum below
        # never overflows: each pixel comes from exactly one of the two.
        background = np.ones_like(cropped, np.uint8) * 255
        cv2.bitwise_not(background, background, mask=mask_crop)
        return background + inside

    def crop_lines(self, polygons, image, height, width):
        """Return one cropped line image per polygon, in the same order."""
        return [self.crop_line(image, polygon, height, width)
                for polygon in polygons]

    def get_scores(self, lgscores):
        """Convert log scores to plain scores by exponentiation."""
        return [math.exp(lgscore) for lgscore in lgscores]

    def predict_text(self, cropped_lines):
        """Predict the text content of a batch of cropped line images.

        Returns:
            A ``(scores, generated_text)`` tuple with one entry per image.
        """
        pixel_values = self.processor(cropped_lines, return_tensors="pt").pixel_values
        generated_dict = self.recognition_model.generate(
            pixel_values.to(self.device),
            max_new_tokens=128,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # NOTE(review): 'sequences_scores' is presumably only present when the
        # model's generation config uses beam search -- confirm.
        generated_ids = generated_dict['sequences']
        lgscores = generated_dict['sequences_scores']
        scores = self.get_scores(lgscores.tolist())
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return scores, generated_text

    def get_text_lines(self, cropped_lines):
        """Recognise all line images, at most ``line_threshold`` per batch.

        Args:
            cropped_lines: Sequence of cropped line images.

        Returns:
            A ``(scores, generated_text)`` tuple of lists, in input order.
            Both lists are empty when ``cropped_lines`` is empty.

        Raises:
            ValueError: If ``line_threshold`` is smaller than 1.
        """
        scores, generated_text = [], []
        total = len(cropped_lines)
        if total == 0:
            # Nothing to recognise; do not call the model with an empty batch.
            return scores, generated_text

        batch_size = int(self.line_threshold)
        if batch_size < 1:
            raise ValueError(
                'line_threshold must be at least 1, got %r' % (self.line_threshold,)
            )

        for start in range(0, total, batch_size):
            batch_scores, batch_text = self.predict_text(
                cropped_lines[start:start + batch_size]
            )
            scores.extend(batch_scores)
            generated_text.extend(batch_text)
        return scores, generated_text

    def get_res_dict(self, polygons, generated_text, height, width, image_name, line_confs, scores):
        """Combine the per-line results into a single dictionary.

        Returns:
            A dict with the image name, its size and a ``'text_lines'`` list
            holding one dict per recognised line.
        """
        line_dicts = [
            {'polygon': polygons[i], 'text': text, 'conf': line_confs[i], 'text_conf': scores[i]}
            for i, text in enumerate(generated_text)
        ]
        return {'img_name': image_name, 'height': height, 'width': width, 'text_lines': line_dicts}

    def process_lines(self, polygons, image, height, width):
        """Crop every polygon from ``image`` and return the recognised texts.

        The recognition scores are computed but not returned.
        """
        # Crop line images
        cropped_lines = self.crop_lines(polygons, image, height, width)
        # Get text predictions
        _scores, generated_text = self.get_text_lines(cropped_lines)
        return generated_text