mscsasem3 committed on
Commit
7af8c25
1 Parent(s): 86388fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +309 -4
app.py CHANGED
@@ -1,7 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
1
+ # import gradio as gr
2
+
3
+ # def greet(name):
4
+ # return "Hello " + name + "!!"
5
+
6
+ # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ # iface.launch()
8
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
9
+ from PIL import Image
10
+ import requests
11
+ import warnings
12
+ from skimage.io import imread
13
+ from skimage.color import rgb2gray
14
+ import matplotlib.pyplot as plt
15
+ from skimage.filters import sobel
16
+ import numpy as np
17
+ from heapq import *
18
  import gradio as gr
19
+ from skimage.filters import threshold_otsu
20
+ from skimage.util import invert
21
+ import cv2,imageio
22
# Load the TrOCR handwritten-text recognizer once at module import
# (downloads the weights from the HuggingFace hub on first run).
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
# Non-interactive matplotlib backend: the app runs headless on the server.
plt.switch_backend('Agg')
25
def horizontal_projections(sobel_image):
    """Return the horizontal projection profile: the sum of each image row."""
    row_sums = np.sum(sobel_image, axis=1)
    return row_sums
27
+
28
+
29
def find_peak_regions(hpp, divider=4):
    """Return [index, value] pairs for projection rows below the threshold.

    The threshold sits at 1/divider of the profile's dynamic range; rows
    under it are candidate inter-line gap rows.
    """
    threshold = (np.max(hpp) - np.min(hpp)) / divider
    return [[row, value] for row, value in enumerate(hpp) if value < threshold]
37
+
38
def heuristic(a, b):
    """Squared Euclidean distance between grid points a and b."""
    dr = b[0] - a[0]
    dc = b[1] - a[1]
    return dr * dr + dc * dc
40
+
41
def get_hpp_walking_regions(peaks_index):
    """Group consecutive row indices into clusters.

    Each cluster is a maximal run of indices whose successive difference is 1;
    these runs are the candidate inter-line "walking" regions.
    """
    clusters = []
    current = []
    last = len(peaks_index) - 1
    for pos, row in enumerate(peaks_index):
        current.append(row)
        # A run ends at the final index or just before a gap larger than 1.
        if pos == last or peaks_index[pos + 1] - row > 1:
            clusters.append(current)
            current = []
    return clusters
57
+
58
def astar(array, start, goal):
    """A* search on a binary grid where cells equal to 1 are obstacles.

    array: 2-D numpy array (the world map); start/goal: (row, col) tuples.
    Returns the path as a list of (row, col) tuples ordered from goal back to
    start (start itself excluded), or [] when the goal is unreachable.
    """
    # 8-connected neighborhood (orthogonal + diagonal moves).
    neighbors = [(0, 1), (0, -1), (1, 0), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]
    close_set = set()
    came_from = {}
    gscore = {start: 0}
    fscore = {start: heuristic(start, goal)}
    oheap = []

    heappush(oheap, (fscore[start], start))

    while oheap:
        current = heappop(oheap)[1]

        if current == goal:
            # Reconstruct the route by walking the came_from chain.
            data = []
            while current in came_from:
                data.append(current)
                current = came_from[current]
            return data

        close_set.add(current)
        for i, j in neighbors:
            neighbor = current[0] + i, current[1] + j
            # Skip cells outside the grid and obstacle cells.
            if not (0 <= neighbor[0] < array.shape[0] and 0 <= neighbor[1] < array.shape[1]):
                continue
            if array[neighbor[0]][neighbor[1]] == 1:
                continue

            tentative_g_score = gscore[current] + heuristic(current, neighbor)

            if neighbor in close_set and tentative_g_score >= gscore.get(neighbor, 0):
                continue

            # Perf fix: the original scanned the whole heap per neighbor
            # (`neighbor not in [i[1] for i in oheap]`, O(n) each time).  A
            # node has a gscore entry iff it has been pushed, and non-closed
            # pushed nodes are exactly the heap's members; for closed nodes
            # the first disjunct decides.  So `neighbor not in gscore` is an
            # O(1), behavior-identical replacement.
            if tentative_g_score < gscore.get(neighbor, 0) or neighbor not in gscore:
                came_from[neighbor] = current
                gscore[neighbor] = tentative_g_score
                fscore[neighbor] = tentative_g_score + heuristic(neighbor, goal)
                heappush(oheap, (fscore[neighbor], neighbor))

    return []
105
+
106
def get_binary(img):
    """Binarize a grayscale image with Otsu's threshold (ink -> 1, rest -> 0).

    Images whose mean is exactly 0.0 or 1.0 (uniformly black/white) are
    returned unchanged, since thresholding them is meaningless.
    """
    mean = np.mean(img)
    if mean == 0.0 or mean == 1.0:
        return img

    otsu = threshold_otsu(img)
    return (img <= otsu) * 1
115
+
116
def path_exists(window_image):
    """Return True when a left-to-right separating path crosses the window.

    Cheap check first: any fully blank row (horizontal projection of 0) is a
    trivial path.  Otherwise run A* from the middle of the left edge to the
    middle of the right edge of the window.
    """
    if 0 in horizontal_projections(window_image):
        return True

    # Pad one blank column on each side so the start and goal cells are free.
    padded_window = np.zeros((window_image.shape[0], 1))
    world_map = np.hstack((padded_window, np.hstack((window_image, padded_window))))
    # Bug fix: the goal column must be the last valid index (shape[1] - 1).
    # The original used shape[1], which astar's bounds check can never reach,
    # so the A* fallback always reported "no path".  This now matches the
    # astar() call in extract().
    path = astar(world_map,
                 (int(world_map.shape[0] / 2), 0),
                 (int(world_map.shape[0] / 2), world_map.shape[1] - 1))
    return len(path) > 0
128
+
129
def get_road_block_regions(nmap, window=20):
    """Return the columns of *nmap* whose strip has no separating path.

    For each column, a strip of `window` columns starting there is tested
    with path_exists(); columns whose strip is blocked are collected.  The
    scan stops once the strip is clamped to the image's right edge.

    window: strip width in columns (default 20, the original hard-coded
    value, kept for backward compatibility).
    """
    road_blocks = []
    for col in range(nmap.shape[1]):
        end = col + window
        # Clamp the final strip to the image edge and stop after testing it.
        last_window = end > nmap.shape[1] - 1
        if last_window:
            end = nmap.shape[1] - 1

        if not path_exists(nmap[:, col:end]):
            road_blocks.append(col)

        if last_window:
            break

    return road_blocks
147
+
148
def group_the_road_blocks(road_blocks):
    """Collapse runs of consecutive blocked columns into [first, last] pairs."""
    groups = []
    run = []
    last = len(road_blocks) - 1
    for pos, col in enumerate(road_blocks):
        run.append(col)
        # A run ends at the final element or just before a gap larger than 1.
        if pos == last or road_blocks[pos + 1] - col > 1:
            groups.append([run[0], run[-1]])
            run = []
    return groups
164
+
165
def extract_line_from_image(image, lower_line, upper_line):
    """Cut one text line out of *image* between two separating paths.

    lower_line / upper_line are arrays of (row, col) points.  Column i is
    masked to 0 above lower_line[i, 0] and from upper_line[i, 0] downward,
    then the band between the two paths' topmost rows is returned.
    """
    top = np.min(lower_line[:, 0])
    bottom = np.min(upper_line[:, 0])
    masked = np.copy(image)
    rows, cols = masked.shape
    # NOTE(review): the last column (index cols-1) is left unmasked —
    # confirm whether that is intentional.
    for col in range(cols - 1):
        masked[0:lower_line[col, 0], col] = 0
        masked[upper_line[col, 0]:rows, col] = 0

    return masked[top:bottom, :]
175
+
176
def extract(image):
    """Segment a handwritten page image into text lines and OCR each line.

    Pipeline: grayscale -> sobel edges -> horizontal projection profile ->
    gap-row clustering -> A* line-separation paths -> per-line crops ->
    TrOCR (module-level processor/model) -> concatenated text.

    Returns the recognized text of all lines joined by spaces.
    Side effects: repeatedly overwrites 'demo.jpeg' in the working directory
    and prints each line's recognized text.
    """
    # NOTE(review): rgb2gray assumes an RGB input — confirm gradio never
    # delivers an already-grayscale array here.
    img = rgb2gray(image)

    #img = rgb2gray(imread("Penwritten_2048x.jpeg"))
    #img = rgb2gray(imread("test.jpg"))
    #img = rgb2gray(imread(""))

    # Row-wise sum of edge magnitudes: low rows indicate inter-line gaps.
    sobel_image = sobel(img)
    hpp = horizontal_projections(sobel_image)

    warnings.filterwarnings("ignore")
    #find the midway where we can make a threshold and extract the peaks regions
    #divider parameter value is used to threshold the peak values from non peak values.

    peaks = find_peak_regions(hpp)

    peaks_index = np.array(peaks)[:,0].astype(int)
    #print(peaks_index.shape)
    # Debug/visualization copy with gap rows blacked out; not used further on.
    segmented_img = np.copy(img)
    r= segmented_img.shape
    for ri in range(r[0]):
        if ri in peaks_index:
            segmented_img[ri, :] = 0

    #group the peaks into walking windows
    hpp_clusters = get_hpp_walking_regions(peaks_index)
    #a star path planning algorithm

    #Scan the paths to see if there are any blockers.

    binary_image = get_binary(img)

    # Open "doorways": wherever a gap cluster has a blocked strip, zero the
    # middle row of that strip in the binary image so A* can pass through.
    for cluster_of_interest in hpp_clusters:
        nmap = binary_image[cluster_of_interest[0]:cluster_of_interest[len(cluster_of_interest)-1],:]
        road_blocks = get_road_block_regions(nmap)
        road_blocks_cluster_groups = group_the_road_blocks(road_blocks)
        #create the doorways
        # NOTE(review): the loop variable shadows the outer `road_blocks` list.
        for index, road_blocks in enumerate(road_blocks_cluster_groups):
            window_image = nmap[:, road_blocks[0]: road_blocks[1]+10]
            binary_image[cluster_of_interest[0]:cluster_of_interest[len(cluster_of_interest)-1],:][:, road_blocks[0]: road_blocks[1]+10][int(window_image.shape[0]/2),:] *= 0

    #now that everything is cleaner, its time to segment all the lines using the A* algorithm
    line_segments = []
    #print(len(hpp_clusters))
    #print(hpp_clusters)
    # One separating path per gap cluster; clusters A* cannot cross are
    # silently dropped.
    for i, cluster_of_interest in enumerate(hpp_clusters):
        nmap = binary_image[cluster_of_interest[0]:cluster_of_interest[len(cluster_of_interest)-1],:]
        path = np.array(astar(nmap, (int(nmap.shape[0]/2), 0), (int(nmap.shape[0]/2),nmap.shape[1]-1)))
        #print(path.shape)
        if path.shape[0]!=0:
            #break
            # Paths are computed in cluster-local rows; shift back to
            # full-image coordinates.
            offset_from_top = cluster_of_interest[0]
            #print(offset_from_top)
            path[:,0] += offset_from_top
            #print(path)
            line_segments.append(path)
        #print(i)

    # NOTE(review): the block below recomputes a path for hpp_clusters[1]
    # but never uses the results (leftover debug code); it will raise
    # IndexError when fewer than two clusters exist — confirm before removing.
    cluster_of_interest = hpp_clusters[1]
    offset_from_top = cluster_of_interest[0]
    nmap = binary_image[cluster_of_interest[0]:cluster_of_interest[len(cluster_of_interest)-1],:]
    #plt.figure(figsize=(20,20))
    #plt.imshow(invert(nmap), cmap="gray")

    path = np.array(astar(nmap, (int(nmap.shape[0]/2), 0), (int(nmap.shape[0]/2),nmap.shape[1]-1)))
    #plt.plot(path[:,1], path[:,0])

    offset_from_top = cluster_of_interest[0]

    ## add an extra line to the line segments array which represents the last bottom row on the image
    last_bottom_row = np.flip(np.column_stack(((np.ones((img.shape[1],))*img.shape[0]), np.arange(img.shape[1]))).astype(int), axis=0)
    line_segments.append(last_bottom_row)

    line_images = []

    line_count = len(line_segments)
    # NOTE(review): fig/ax are created but never drawn on or closed — each
    # call leaks a matplotlib figure.
    fig, ax = plt.subplots(figsize=(10,10), nrows=line_count-1)
    output = []

    # Crop each line between consecutive separating paths and OCR it.
    for line_index in range(line_count-1):
        line_image = extract_line_from_image(img, line_segments[line_index], line_segments[line_index+1])
        line_images.append(line_image)
        #print(line_image)
        #cv2.imwrite('/Users/vatsalya/Desktop/demo.jpeg',line_image)

        #im=Image.fromarray(line_image)
        #im=im.convert("L")
        #im.save("/Users/vatsalya/Desktop/demo.jpeg")
        #print("#### Image Saved #######")
        # Round-trip through a JPEG file to hand the crop to PIL/TrOCR.
        imageio.imwrite('demo.jpeg',line_image)

        image = Image.open("demo.jpeg").convert("RGB")
        #print("Started Processing")

        pixel_values = processor(images=image, return_tensors="pt").pixel_values

        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        print(generated_text)
        output.append(generated_text)
        #ax[line_index].imshow(line_image, cmap="gray")
    # Join the per-line transcriptions with trailing spaces.
    result=""
    for o in output:
        result=result+o
        result=result+" "
    return result
307
 
308
# Gradio UI wiring the extract() pipeline to four image inputs.
# NOTE(review): extract() accepts a single image argument while four inputs
# are declared here — gradio will call fn with four values and fail at call
# time.  Confirm whether extract is meant to take four images.
iface = gr.Interface(
    fn=extract,
    inputs=[
        gr.inputs.Image(type='file', label='Ideal Answer'),
        gr.inputs.Image(type='file', label='Ideal Answer Diagram'),
        gr.inputs.Image(type='file', label='Submitted Answer'),
        gr.inputs.Image(type='file', label='Submitted Answer Diagram'),
    ],  # bug fix: this comma was missing, which was a SyntaxError
    outputs=gr.outputs.Textbox(),
)

iface.launch(enable_queue=True)