bachpc committed on
Commit
aaffb60
·
1 Parent(s): 8ed655c

Minor changes

Browse files
Files changed (1) hide show
  1. app.py +106 -61
app.py CHANGED
@@ -29,20 +29,28 @@ structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/struct
29
 
30
  imgsz = 640
31
 
32
- detection_class_names = ['table', 'table rotated']
33
  structure_class_names = [
34
  'table', 'table column', 'table row', 'table column header',
35
  'table projected row header', 'table spanning cell', 'no object'
36
  ]
 
 
37
  structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
 
 
 
 
 
 
38
  structure_class_thresholds = {
39
- "table": 0.42,
40
- "table column": 0.56,
41
- "table row": 0.5,
42
- "table column header": 0.38,
43
- "table projected row header": 0.27,
44
- "table spanning cell": 0.4,
45
- "no object": 10
46
  }
47
 
48
 
@@ -84,6 +92,9 @@ def crop_image(pil_img, detection_result, padding=30):
84
  w = result[2]
85
  h = result[3]
86
 
 
 
 
87
  x1 = int((min_x - w / 2) * width)
88
  y1 = int((min_y - h / 2) * height)
89
  x2 = int((min_x + w / 2) * width)
@@ -97,7 +108,7 @@ def crop_image(pil_img, detection_result, padding=30):
97
 
98
  crop_image = image[y1_pad:y2_pad, x1_pad:x2_pad, :]
99
  crop_image = cv_to_PIL(crop_image)
100
- if class_id == 1: # table rotated
101
  crop_image = crop_image.rotate(270, expand=True)
102
 
103
  crop_images.append(crop_image)
@@ -180,17 +191,49 @@ def convert_stucture(page_tokens, pil_img, structure_result):
180
  return table_structures, cells, confidence_score
181
 
182
 
 
 
 
 
 
 
 
 
 
 
183
  def visualize_ocr(pil_img, ocr_result):
184
- image = PIL_to_cv(pil_img)
185
- for i, res in enumerate(ocr_result):
186
- bbox = res['bbox']
187
- x1 = int(bbox[0])
188
- y1 = int(bbox[1])
189
- x2 = int(bbox[2])
190
- y2 = int(bbox[3])
191
- cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255))
192
- cv2.putText(image, res['text'], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.25, color=(255, 0, 0))
193
- return cv_to_PIL(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
 
196
  def get_bbox_decorations(data_type, label):
@@ -231,6 +274,9 @@ def visualize_structure(pil_img, structure_result):
231
  w = result[2]
232
  h = result[3]
233
 
 
 
 
234
  x1 = int((min_x - w / 2) * width)
235
  y1 = int((min_y - h / 2) * height)
236
  x2 = int((min_x + w / 2) * width)
@@ -238,35 +284,31 @@ def visualize_structure(pil_img, structure_result):
238
  # print(x1, y1, x2, y2)
239
  bbox = [x1, y1, x2, y2]
240
 
241
- if score >= structure_class_thresholds[structure_class_names[class_id]]:
242
- #cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
243
- #cv2.putText(image, str(i)+'-'+str(class_id), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
244
-
245
- color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)
246
- # Fill
247
- rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
248
- linewidth=linewidth, alpha=alpha,
249
- edgecolor='none',facecolor=color,
250
- linestyle=None)
251
- ax.add_patch(rect)
252
- # Hatch
253
- rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
254
- linewidth=1, alpha=0.4,
255
- edgecolor=color, facecolor='none',
256
- linestyle='--',hatch=hatch)
257
- ax.add_patch(rect)
258
- # Edge
259
- rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
260
- linewidth=linewidth,
261
- edgecolor=color, facecolor='none',
262
- linestyle="--")
263
- ax.add_patch(rect)
264
 
265
  plt.xticks([], [])
266
  plt.yticks([], [])
267
 
268
  legend_elements = []
269
- for class_name in structure_class_names:
270
  color, alpha, linewidth, hatch = get_bbox_decorations('recognition', structure_class_map[class_name])
271
  legend_elements.append(
272
  Patch(facecolor=color, edgecolor=color, label=class_name, hatch=hatch, alpha=alpha)
@@ -506,10 +548,7 @@ def cells_to_excel(cells, file_path):
506
  workbook = xlsxwriter.Workbook(file_path)
507
 
508
  cell_format = workbook.add_format(
509
- {
510
- 'align': 'center',
511
- 'valign': 'vcenter',
512
- }
513
  )
514
 
515
  worksheet = workbook.add_worksheet(name='Table')
@@ -573,33 +612,35 @@ def main():
573
  with tabs[1]:
574
  st.header('Table Structure Recognition')
575
 
576
- str_cols = st.columns((len(crop_images), ) * 4)
577
  str_cols[0].subheader('Table image')
578
  str_cols[1].subheader('OCR result')
579
  str_cols[2].subheader('Structure result')
580
  str_cols[3].subheader('Cells result')
581
 
582
  for i, img in enumerate(crop_images):
 
 
 
 
 
583
  ocr_result = ocr(img)
 
 
 
584
  structure_result = table_structure(img)
 
 
 
585
  table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
586
  cells = extract_text_from_cells(cells)
587
- all_cells.append(cells)
588
- html_result = cells_to_html(cells)
589
- #df, csv_result = cells_to_csv(cells)
590
- #print(df)
591
-
592
- vis_ocr_img = visualize_ocr(img, ocr_result)
593
- vis_str_img = visualize_structure(img, structure_result)
594
  vis_cells_img = visualize_cells(img, cells)
595
-
596
- str_cols[0].image(img)
597
- str_cols[1].image(vis_ocr_img)
598
- str_cols[2].image(vis_str_img)
599
  str_cols[3].image(vis_cells_img)
600
 
601
- st.write('\n')
602
- st.markdown(html_result, unsafe_allow_html=True)
 
 
603
 
604
  with tabs[2]:
605
  st.header('Extracted Table(s)')
@@ -621,6 +662,10 @@ def main():
621
  file_name=f'output_{idx}.xlsx',
622
  )
623
 
 
 
 
 
624
 
625
  if __name__ == '__main__':
626
  main()
 
29
 
30
  imgsz = 640
31
 
32
+ detection_class_names = ['table', 'table rotated', 'no object']
33
  structure_class_names = [
34
  'table', 'table column', 'table row', 'table column header',
35
  'table projected row header', 'table spanning cell', 'no object'
36
  ]
37
+
38
+ detection_class_map = {k: v for v, k in enumerate(detection_class_names)}
39
  structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
40
+
41
+ detection_class_thresholds = {
42
+ 'table': 0.5,
43
+ 'table rotated': 0.5,
44
+ 'no object': 10
45
+ }
46
  structure_class_thresholds = {
47
+ 'table': 0.42,
48
+ 'table column': 0.56,
49
+ 'table row': 0.5,
50
+ 'table column header': 0.38,
51
+ 'table projected row header': 0.27,
52
+ 'table spanning cell': 0.4,
53
+ 'no object': 10
54
  }
55
 
56
 
 
92
  w = result[2]
93
  h = result[3]
94
 
95
+ if score < detection_class_thresholds[detection_class_names[class_id]]:
96
+ continue
97
+
98
  x1 = int((min_x - w / 2) * width)
99
  y1 = int((min_y - h / 2) * height)
100
  x2 = int((min_x + w / 2) * width)
 
108
 
109
  crop_image = image[y1_pad:y2_pad, x1_pad:x2_pad, :]
110
  crop_image = cv_to_PIL(crop_image)
111
+ if detection_class_names[class_id] == 'table rotated':
112
  crop_image = crop_image.rotate(270, expand=True)
113
 
114
  crop_images.append(crop_image)
 
191
  return table_structures, cells, confidence_score
192
 
193
 
def visualize_image(pil_img):
    """Render an image through matplotlib and return the rendering as a PIL image.

    The image is drawn with Lanczos interpolation on a 10x10 inch figure
    with the axes hidden, saved at 150 dpi into an in-memory buffer, and
    re-opened with PIL.

    Args:
        pil_img: The image to draw (passed straight to ``plt.imshow``).

    Returns:
        A PIL image opened from the in-memory rendering.
    """
    plt.imshow(pil_img, interpolation='lanczos')

    figure = plt.gcf()
    figure.set_size_inches(10, 10)
    plt.axis('off')

    # The buffer is intentionally left open: the returned PIL image reads from it.
    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', dpi=150)
    plt.close()

    rendered = PIL.Image.open(buffer)
    return rendered
202
+
203
+
def visualize_ocr(pil_img, ocr_result):
    """Draw OCR boxes and their recognized text on top of an image.

    Each entry of ``ocr_result`` is outlined with a red rectangle and
    labelled in blue with its text. The figure is rendered at 10x10 inches
    and 150 dpi into an in-memory buffer and returned as a PIL image.

    Args:
        pil_img: The image to annotate (passed straight to ``plt.imshow``).
        ocr_result: Iterable of dicts with a ``'bbox'`` entry
            ``[x1, y1, x2, y2]`` and a ``'text'`` entry. The boxes are drawn
            in the image's data coordinates.

    Returns:
        A PIL image opened from the in-memory rendering.
    """
    plt.imshow(pil_img, interpolation='lanczos')
    ax = plt.gca()

    for result in ocr_result:
        bbox = result['bbox']
        text = result['text']

        rect = patches.Rectangle(
            bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1],
            linewidth=2, edgecolor='red', facecolor='none', linestyle="-",
        )
        ax.add_patch(rect)

        # The label uses the same data (pixel) coordinates as the rectangle.
        # The default transform is ax.transData; passing ax.transAxes here
        # would treat pixel values as 0..1 axes fractions and push the text
        # far outside the image.
        ax.text(
            bbox[0], bbox[3], text,
            horizontalalignment='left', verticalalignment='bottom',
            color='blue',
        )

    plt.xticks([], [])
    plt.yticks([], [])

    # Only the size in effect when the figure is saved matters, so set it once.
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')

    # The buffer is intentionally left open: the returned PIL image reads from it.
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=150)
    plt.close()

    return PIL.Image.open(img_buf)
237
 
238
 
239
  def get_bbox_decorations(data_type, label):
 
274
  w = result[2]
275
  h = result[3]
276
 
277
+ if score < structure_class_thresholds[structure_class_names[class_id]]:
278
+ continue
279
+
280
  x1 = int((min_x - w / 2) * width)
281
  y1 = int((min_y - h / 2) * height)
282
  x2 = int((min_x + w / 2) * width)
 
284
  # print(x1, y1, x2, y2)
285
  bbox = [x1, y1, x2, y2]
286
 
287
+ color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)
288
+ # Fill
289
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
290
+ linewidth=linewidth, alpha=alpha,
291
+ edgecolor='none',facecolor=color,
292
+ linestyle=None)
293
+ ax.add_patch(rect)
294
+ # Hatch
295
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
296
+ linewidth=1, alpha=0.4,
297
+ edgecolor=color, facecolor='none',
298
+ linestyle='--',hatch=hatch)
299
+ ax.add_patch(rect)
300
+ # Edge
301
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
302
+ linewidth=linewidth,
303
+ edgecolor=color, facecolor='none',
304
+ linestyle="--")
305
+ ax.add_patch(rect)
 
 
 
 
306
 
307
  plt.xticks([], [])
308
  plt.yticks([], [])
309
 
310
  legend_elements = []
311
+ for class_name in structure_class_names[:-1]:
312
  color, alpha, linewidth, hatch = get_bbox_decorations('recognition', structure_class_map[class_name])
313
  legend_elements.append(
314
  Patch(facecolor=color, edgecolor=color, label=class_name, hatch=hatch, alpha=alpha)
 
548
  workbook = xlsxwriter.Workbook(file_path)
549
 
550
  cell_format = workbook.add_format(
551
+ {'align': 'center', 'valign': 'vcenter'}
 
 
 
552
  )
553
 
554
  worksheet = workbook.add_worksheet(name='Table')
 
612
  with tabs[1]:
613
  st.header('Table Structure Recognition')
614
 
615
+ str_cols = st.columns(4)
616
  str_cols[0].subheader('Table image')
617
  str_cols[1].subheader('OCR result')
618
  str_cols[2].subheader('Structure result')
619
  str_cols[3].subheader('Cells result')
620
 
621
  for i, img in enumerate(crop_images):
622
+ str_cols = st.columns(4)
623
+
624
+ vis_img = visualize_image(img)
625
+ str_cols[0].image(vis_img)
626
+
627
  ocr_result = ocr(img)
628
+ vis_ocr_img = visualize_ocr(img, ocr_result)
629
+ str_cols[1].image(vis_ocr_img)
630
+
631
  structure_result = table_structure(img)
632
+ vis_str_img = visualize_structure(img, structure_result)
633
+ str_cols[2].image(vis_str_img)
634
+
635
  table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
636
  cells = extract_text_from_cells(cells)
 
 
 
 
 
 
 
637
  vis_cells_img = visualize_cells(img, cells)
 
 
 
 
638
  str_cols[3].image(vis_cells_img)
639
 
640
+ all_cells.append(cells)
641
+
642
+ #df, csv_result = cells_to_csv(cells)
643
+ #print(df)
644
 
645
  with tabs[2]:
646
  st.header('Extracted Table(s)')
 
662
  file_name=f'output_{idx}.xlsx',
663
  )
664
 
665
+ for idx, cells in enumerate(all_cells):
666
+ html_result = cells_to_html(cells)
667
+ st.subheader(f'HTML Table {idx + 1}')
668
+ st.markdown(html_result, unsafe_allow_html=True)
669
 
670
  if __name__ == '__main__':
671
  main()