KennethTM commited on
Commit
f94a9ea
·
verified ·
1 Parent(s): 9d81a26

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
5
+ import cv2
6
+
7
+ image_size = 224
8
+
9
+ def normalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
10
+ image = (image/255.0).astype("float32")
11
+
12
+ image[:, :, 0] = (image[:, :, 0] - mean[0]) / std[0]
13
+ image[:, :, 1] = (image[:, :, 1] - mean[1]) / std[1]
14
+ image[:, :, 2] = (image[:, :, 2] - mean[2]) / std[2]
15
+
16
+ return image
17
+
18
+ def resize_longest_max_size(image, max_size=224):
19
+
20
+ height, width = image.shape[:2]
21
+
22
+ if width > height:
23
+ ratio = max_size / width
24
+ else:
25
+ ratio = max_size / height
26
+
27
+ new_width = int(width * ratio)
28
+ new_height = int(height * ratio)
29
+
30
+ resized_image = cv2.resize(image, (new_width, new_height), cv2.INTER_LINEAR)
31
+
32
+ return resized_image
33
+
34
+ def pad_if_needed(image, target_size):
35
+ height, width, _ = image.shape
36
+
37
+ y0 = abs((height-target_size)//2)
38
+ x0 = abs((width-target_size)//2)
39
+
40
+ background = np.zeros((target_size, target_size, 3), dtype="uint8")
41
+ background[y0:(y0+height), x0:(x0+width), :] = image
42
+
43
+ return(background)
44
+
45
+ def heatmap2keypoints(heatmap: np.ndarray, image_size: int = 224) -> list:
46
+ "Function to convert heatmap to keypoint x, y tensor"
47
+
48
+ indx = heatmap.reshape(-1, image_size*image_size).argmax(axis=1)
49
+ row = indx // image_size
50
+ col = indx % image_size
51
+
52
+ keypoints_array = np.stack((col, row), axis=1)
53
+ keypoints_list = keypoints_array.tolist()
54
+
55
+ return keypoints_list
56
+
57
+ def centercrop_keypoints(keypoints, crop_height, crop_width, image_height, image_width):
58
+ y_diff = (image_height-crop_height)//2
59
+ x_diff = (image_width-crop_width)//2
60
+
61
+ keypoints_crop = [[x-x_diff, y-y_diff] for x, y in keypoints]
62
+ return(keypoints_crop)
63
+
64
+ def resize_keypoints(keypoints, current_height, current_width, target_height, target_width):
65
+ keypoints_resize = []
66
+ for x, y in keypoints:
67
+ x_resize = (x/current_width)*target_width
68
+ y_resize = (y/current_height)*target_height
69
+ keypoints_resize.append([int(x_resize), int(y_resize)])
70
+ return(keypoints_resize)
71
+
72
+ def draw_keypoints(image, keypoints):
73
+ draw = ImageDraw.Draw(image)
74
+ w, h = image.size
75
+ for keypoint in keypoints:
76
+ x, y = keypoint
77
+ # Draw a small circle at each keypoint
78
+ radius = int(min(w, h) * 0.01)
79
+ draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red', outline='red')
80
+ return image
81
+
82
+ def point_dist(p0, p1):
83
+ x0, y0 = p0
84
+ x1, y1 = p1
85
+
86
+ dist = ((x0-x1)**2 + (y0-y1)**2)**0.5
87
+
88
+ return dist
89
+
90
+ def receipt_asp_ratio(keypoints, mode = "mean"):
91
+
92
+ h0 = point_dist(keypoints[0], keypoints[3])
93
+ h1 = point_dist(keypoints[1], keypoints[2])
94
+
95
+ w0 = point_dist(keypoints[0], keypoints[1])
96
+ w1 = point_dist(keypoints[2], keypoints[3])
97
+
98
+ if mode == "max":
99
+ h = max(h0, h1)
100
+ w = max(w0, w1)
101
+ elif mode == "mean":
102
+ h = (h0+h1)/2
103
+ w = (w0+w1)/2
104
+ else:
105
+ return("UNKNOWN MODE")
106
+
107
+ return w/h
108
+
109
+ # Load the ONNX model
110
+ session = ort.InferenceSession("models/timm-mobilenetv3_small_100.onnx")
111
+
112
+ input_name = session.get_inputs()[0].name
113
+ output_name = session.get_outputs()[0].name
114
+
115
+ # Main function to handle the image input, apply preprocessing, run the model, and apply postprocessing
116
+ def process_image(input_image):
117
+
118
+ # Convert PIL image to OpenCV image
119
+ image = np.array(input_image.convert("RGB"))
120
+
121
+ h, w, _ = image.shape
122
+
123
+ # Preprocess the image
124
+ image_resize = resize_longest_max_size(image)
125
+
126
+ h_small, w_small, _ = image_resize.shape
127
+
128
+ image_pad = pad_if_needed(image_resize, target_size=image_size)
129
+
130
+ image_norm = normalize_image(image_pad)
131
+
132
+ image_array = np.transpose(image_norm, (2, 0, 1))
133
+
134
+ image_array = np.expand_dims(image_array, axis=0)
135
+
136
+ # Run model inference
137
+ output = session.run([output_name], {input_name: image_array})
138
+
139
+ output_keypoints = heatmap2keypoints(output[0].squeeze())
140
+ crop_keypoints = centercrop_keypoints(output_keypoints, h_small, w_small, image_size, image_size)
141
+ large_keypoints = resize_keypoints(crop_keypoints, h_small, w_small, h, w)
142
+
143
+ # Draw keypoints on the image
144
+ image_with_keypoints = draw_keypoints(input_image, large_keypoints)
145
+
146
+ persp_h = 1024
147
+ persp_asp = receipt_asp_ratio(large_keypoints, mode="max")
148
+ persp_w = int(persp_asp*persp_h)
149
+
150
+ origin_points = np.float32([[x, y] for x, y in large_keypoints])
151
+ target_points = np.float32([[0, 0], [persp_w-1, 0], [persp_w-1, persp_h-1], [0, persp_h-1]])
152
+
153
+ persp_matrix = cv2.getPerspectiveTransform(origin_points, target_points)
154
+ persp_image = cv2.warpPerspective(image, persp_matrix, (persp_w, persp_h), cv2.INTER_LINEAR)
155
+
156
+ output_image = Image.fromarray(persp_image)
157
+
158
+ return image_with_keypoints, output_image
159
+
160
+ demo_images = [
161
+ "demo_images/image_1.jpg",
162
+ "demo_images/image_2.jpg",
163
+ "demo_images/image_3.jpg",
164
+ "demo_images/image_flux_1.png",
165
+ "demo_images/image_flux_2.png",
166
+ ]
167
+
168
+ # Create Gradio interface
169
+ with gr.Blocks() as iface:
170
+ gr.Markdown("# Document corner detection and perspective correction")
171
+ gr.Markdown("Upload an image to detect the corners of a document and correct the perspective.\n\nUses a UNet model to detect corners and OpenCV to correct the perspective.")
172
+
173
+ with gr.Row():
174
+ with gr.Column():
175
+ input_image = gr.Image(type="pil", label="Image", show_label=True, scale=1)
176
+
177
+ with gr.Column():
178
+ output_image1 = gr.Image(type="pil", label="Image with predicted corners", show_label=True, scale=1)
179
+
180
+ with gr.Column():
181
+ output_image2 = gr.Image(type="pil", label="Image with perspective correction", show_label=True, scale=1)
182
+
183
+ with gr.Row():
184
+ examples = gr.Examples(demo_images, input_image, cache_examples=False, label="Exampled documents (CORD dataset and FLUX.1-schnell generated)")
185
+
186
+ input_image.change(fn=process_image, inputs=input_image, outputs=[output_image1, output_image2])
187
+
188
+ gr.Markdown("Created by Kenneth Thorø Martinsen (kenneth2810@gmail.com)")
189
+
190
+ iface.launch()