bachpc committed on
Commit
ba538d2
1 Parent(s): 76f6ada

Add detection

Browse files
Files changed (4) hide show
  1. app.py +261 -102
  2. postprocess.py +35 -37
  3. requirements.txt +77 -18
  4. weights/detection_wts.pt +3 -0
app.py CHANGED
@@ -1,30 +1,28 @@
1
  import streamlit as st
2
 
3
  import PIL
 
4
  import numpy as np
 
5
  import torch
6
- from collections import defaultdict
7
-
8
- import cv2
9
- from doctr.io import DocumentFile
10
- from doctr.models import ocr_predictor
11
- from doctr.utils.visualization import visualize_page
12
 
 
13
  import pytesseract
14
  from pytesseract import Output
15
 
16
- from bs4 import BeautifulSoup as bs
17
- from html import escape
18
-
19
- import sys, json
20
-
21
  import postprocess
22
 
23
 
24
- ocr_predictor = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True)
 
25
  structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)
26
  imgsz = 640
27
 
 
28
  structure_class_names = [
29
  'table', 'table column', 'table row', 'table column header',
30
  'table projected row header', 'table spanning cell', 'no object'
@@ -42,15 +40,22 @@ structure_class_thresholds = {
42
 
43
 
44
  def PIL_to_cv(pil_img):
45
- return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
46
 
47
 
48
  def cv_to_PIL(cv_img):
49
  return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
50
 
51
 
52
- def table_structure(filename):
53
- pil_img = PIL.Image.open(filename)
 
 
 
 
 
 
 
54
  image = PIL_to_cv(pil_img)
55
  pred = structure_model(image, size=imgsz)
56
  pred = pred.xywhn[0]
@@ -58,32 +63,59 @@ def table_structure(filename):
58
  return result
59
 
60
 
61
- def ocr(filename):
62
- doc = DocumentFile.from_images(filename.read())
63
- result = ocr_predictor(doc).export()
64
- result = result['pages'][0]
65
- H, W = result['dimensions']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  ocr_res = []
67
- for block in result['blocks']:
68
- for line in block['lines']:
69
- for word in line['words']:
70
- bbox = word['geometry']
71
- word_info = {
72
- 'bbox': [int(bbox[0][0] * W), int(bbox[0][1] * H), int(bbox[1][0] * W), int(bbox[1][1] * H)],
73
- 'text': word['value']
74
- }
75
- ocr_res.append(word_info)
 
 
 
76
  return ocr_res
77
 
78
 
79
- def convert_stucture(page_tokens, filename, structure_result):
80
- pil_img = PIL.Image.open(filename)
81
  image = PIL_to_cv(pil_img)
82
 
83
  width = image.shape[1]
84
  height = image.shape[0]
85
  # print(width, height)
86
-
87
  bboxes = []
88
  scores = []
89
  labels = []
@@ -94,11 +126,11 @@ def convert_stucture(page_tokens, filename, structure_result):
94
  min_y = result[1]
95
  w = result[2]
96
  h = result[3]
97
-
98
- x1 = int((min_x-w/2)*width)
99
- y1 = int((min_y-h/2)*height)
100
- x2 = int((min_x+w/2)*width)
101
- y2 = int((min_y+h/2)*height)
102
  # print(x1, y1, x2, y2)
103
 
104
  bboxes.append([x1, y1, x2, y2])
@@ -109,9 +141,9 @@ def convert_stucture(page_tokens, filename, structure_result):
109
  for bbox, score, label in zip(bboxes, scores, labels):
110
  table_objects.append({'bbox': bbox, 'score': score, 'label': label})
111
  # print('table_objects:', table_objects)
112
-
113
  table = {'objects': table_objects, 'page_num': 0}
114
-
115
  table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
116
  if len(table_class_objects) > 1:
117
  table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
@@ -121,17 +153,54 @@ def convert_stucture(page_tokens, filename, structure_result):
121
  table_bbox = (0,0,1000,1000)
122
  # print('table_class_objects:', table_class_objects)
123
  # print('table_bbox:', table_bbox)
124
-
125
  tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
126
  # print('tokens_in_table:', tokens_in_table)
127
-
128
  table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)
129
-
130
  return table_structures, cells, confidence_score
131
 
132
 
133
- def visualize_cells(filename, cells, ax):
134
- pil_img = PIL.Image.open(filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  image = PIL_to_cv(pil_img)
136
  for i, cell in enumerate(cells):
137
  bbox = cell['bbox']
@@ -140,7 +209,7 @@ def visualize_cells(filename, cells, ax):
140
  x2 = int(bbox[2])
141
  y2 = int(bbox[3])
142
  cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
143
- ax.image(cv_to_PIL(image))
144
 
145
 
146
  def pytess(cell_pil_img):
@@ -175,55 +244,125 @@ def remove_noise_and_smooth(pil_img):
175
  return pil_img
176
 
177
 
178
- def extract_text_from_cells(filename, cells):
179
- pil_img = PIL.Image.open(filename)
180
- pil_img, factor = resize(pil_img)
181
- #pil_img = remove_noise_and_smooth(pil_img)
182
- #display(pil_img)
 
 
 
 
 
 
 
 
 
183
  for cell in cells:
184
- bbox = [x * factor for x in cell['bbox']]
185
- cell_pil_img = pil_img.crop(bbox)
186
- #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
187
- #cell_pil_img = tess_prep(cell_pil_img)
188
- cell['text'] = pytess(cell_pil_img)
 
189
  return cells
190
 
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  def cells_to_html(cells):
 
 
 
 
 
 
193
  for cell in cells:
194
- cell['column_nums'].sort()
195
- cell['row_nums'].sort()
196
- n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
197
- n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
198
- html_code = ''
199
- for r in range(n_rows):
200
- r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
201
- r_cells.sort(key=lambda x: x['column_nums'][0])
202
- r_html = ''
203
- for cell in r_cells:
204
- rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
205
- colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
206
- r_html += f'<td rowspan="{rowspan}" colspan="{colspan}">{escape(cell["text"])}</td>'
207
- html_code += f'<tr>{r_html}</tr>'
208
- html_code = '''<html>
209
- <head>
210
- <meta charset="UTF-8">
211
- <style>
212
- table, th, td {
213
- border: 1px solid black;
214
- font-size: 10px;
215
- }
216
- </style>
217
- </head>
218
- <body>
219
- <table frame="hsides" rules="groups" width="100%%">
220
- %s
221
- </table>
222
- </body>
223
- </html>''' % html_code
224
- soup = bs(html_code)
225
- html_code = soup.prettify()
226
- return html_code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
 
229
  def main():
@@ -234,7 +373,7 @@ def main():
234
 
235
  cols = st.columns((1, 1))
236
  cols[0].subheader("Input page")
237
- cols[1].subheader("Structure output")
238
 
239
  st.sidebar.title("Image upload")
240
  st.set_option('deprecation.showfileUploaderEncoding', False)
@@ -247,19 +386,39 @@ def main():
247
 
248
  else:
249
  print(filename)
250
-
251
- cols[0].image(filename)
252
-
253
- ocr_res = ocr(filename)
254
- structure_result = table_structure(filename)
255
- table_structures, cells, confidence_score = convert_stucture(ocr_res, filename, structure_result)
256
- visualize_cells(filename, cells, cols[1])
257
-
258
- cells = extract_text_from_cells(filename, cells)
259
- html_code = cells_to_html(cells)
260
-
261
- st.markdown("\nHTML output:")
262
- st.markdown(html_code, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
 
265
  if __name__ == '__main__':
 
1
  import streamlit as st
2
 
3
  import PIL
4
+ import cv2
5
  import numpy as np
6
+ import pandas as pd
7
  import torch
8
+ # import sys
9
+ # import json
10
+ from collections import OrderedDict, defaultdict
11
+ import xml.etree.ElementTree as ET
 
 
12
 
13
+ from paddleocr import PaddleOCR
14
  import pytesseract
15
  from pytesseract import Output
16
 
 
 
 
 
 
17
  import postprocess
18
 
19
 
20
+ ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
21
+ detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True)
22
  structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)
23
  imgsz = 640
24
 
25
+ detection_class_names = ['table', 'table rotated']
26
  structure_class_names = [
27
  'table', 'table column', 'table row', 'table column header',
28
  'table projected row header', 'table spanning cell', 'no object'
 
40
 
41
 
42
def PIL_to_cv(pil_img):
    """Convert a PIL image to a NumPy array with the channel order swapped RGB -> BGR.

    NOTE(review): assumes ``pil_img`` is a 3-channel RGB image -- confirm for
    uploads in other modes (e.g. RGBA or grayscale).
    """
    rgb_array = np.array(pil_img)
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
44
 
45
 
46
def cv_to_PIL(cv_img):
    """Convert a BGR image array into a PIL image (channel order swapped BGR -> RGB)."""
    rgb_array = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(rgb_array)
48
 
49
 
50
def table_detection(pil_img):
    """Run the table-detection model on a PIL image.

    The image is converted to BGR, passed to ``detection_model`` at the
    module-level inference size ``imgsz``, and the first entry of the
    prediction's ``xywhn`` list is returned as a NumPy array.

    Other code in this file reads each row as
    ``[x, y, w, h, score, class_id]`` with coordinates scaled by the image size.
    """
    bgr_image = PIL_to_cv(pil_img)
    prediction = detection_model(bgr_image, size=imgsz)
    first_image_boxes = prediction.xywhn[0]
    return first_image_boxes.cpu().numpy()
56
+
57
+
58
+ def table_structure(pil_img):
59
  image = PIL_to_cv(pil_img)
60
  pred = structure_model(image, size=imgsz)
61
  pred = pred.xywhn[0]
 
63
  return result
64
 
65
 
66
def crop_image(pil_img, detection_result, padding=0.02):
    """Crop every detected table out of the page and draw the boxes on a preview.

    Args:
        pil_img: Page image (PIL).
        detection_result: Iterable of rows whose first four entries are read as
            center-x, center-y, width and height. They are multiplied by the
            image size here, so they are expected to be normalized fractions.
        padding: Extra margin added on every side of a box, expressed as a
            fraction of the image width/height. Defaults to the previously
            hard-coded 0.02.

    Returns:
        ``(crop_images, annotated)``: a list of PIL crops (one per usable
        detection) and a PIL copy of the page with each box drawn in green.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]

    # Draw on a separate copy: crops are sliced from ``image``, so drawing on
    # it would leak rectangles of earlier detections into later, overlapping
    # crops.
    annotated = image.copy()

    crop_images = []
    for result in detection_result:
        center_x, center_y, w, h = result[0], result[1], result[2], result[3]

        # Pad the box, then clamp it to the image bounds.
        x1 = max(0, int((center_x - w / 2 - padding) * width))
        y1 = max(0, int((center_y - h / 2 - padding) * height))
        x2 = min(width, int((center_x + w / 2 + padding) * width))
        y2 = min(height, int((center_y + h / 2 + padding) * height))

        # A zero-area box would produce an empty array that cannot be
        # colour-converted; skip it instead of crashing.
        if x2 <= x1 or y2 <= y1:
            continue

        crop_images.append(cv_to_PIL(image[y1:y2, x1:x2, :]))
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color=(0, 255, 0))

    return crop_images, cv_to_PIL(annotated)
91
+
92
+
93
def ocr(pil_img):
    """Run OCR on a PIL image and return its text tokens with bounding boxes.

    Each recognized line is reported by ``ocr_instance.ocr`` as
    ``(points, (text, score))``; the points are collapsed into an
    axis-aligned box.

    Returns:
        A list of dicts ``{'bbox': [x1, y1, x2, y2], 'text': str}``. The list
        is empty when nothing was recognized.
    """
    image = PIL_to_cv(pil_img)
    result = ocr_instance.ocr(image)

    # NOTE(review): for an image without text the first entry appears to be
    # None (or the result empty) in some PaddleOCR releases -- confirm against
    # the installed version. Guard so that case yields no tokens rather than
    # a TypeError.
    lines = result[0] if result else None

    ocr_res = []
    for points, (text, _score) in lines or []:
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        ocr_res.append({
            'bbox': [min(xs), min(ys), max(xs), max(ys)],
            'text': text,
        })

    return ocr_res
110
 
111
 
112
+ def convert_stucture(page_tokens, pil_img, structure_result):
 
113
  image = PIL_to_cv(pil_img)
114
 
115
  width = image.shape[1]
116
  height = image.shape[0]
117
  # print(width, height)
118
+
119
  bboxes = []
120
  scores = []
121
  labels = []
 
126
  min_y = result[1]
127
  w = result[2]
128
  h = result[3]
129
+
130
+ x1 = int((min_x - w / 2) * width)
131
+ y1 = int((min_y - h / 2) * height)
132
+ x2 = int((min_x + w / 2) * width)
133
+ y2 = int((min_y + h / 2) * height)
134
  # print(x1, y1, x2, y2)
135
 
136
  bboxes.append([x1, y1, x2, y2])
 
141
  for bbox, score, label in zip(bboxes, scores, labels):
142
  table_objects.append({'bbox': bbox, 'score': score, 'label': label})
143
  # print('table_objects:', table_objects)
144
+
145
  table = {'objects': table_objects, 'page_num': 0}
146
+
147
  table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
148
  if len(table_class_objects) > 1:
149
  table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
 
153
  table_bbox = (0,0,1000,1000)
154
  # print('table_class_objects:', table_class_objects)
155
  # print('table_bbox:', table_bbox)
156
+
157
  tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
158
  # print('tokens_in_table:', tokens_in_table)
159
+
160
  table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)
161
+
162
  return table_structures, cells, confidence_score
163
 
164
 
165
def visualize_ocr(pil_img, ocr_result):
    """Return a PIL copy of the image with a green box around each OCR token.

    ``ocr_result`` is a sequence of dicts whose ``'bbox'`` entry holds
    ``[x1, y1, x2, y2]``; coordinates are truncated to ints before drawing.
    """
    canvas = PIL_to_cv(pil_img)
    for token in ocr_result:
        box = token['bbox']
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(canvas, top_left, bottom_right, color=(0, 255, 0))
    return cv_to_PIL(canvas)
175
+
176
+
177
def visualize_structure(pil_img, structure_result):
    """Return a PIL copy of the image with the confident structure boxes drawn.

    Args:
        pil_img: Table image (PIL).
        structure_result: Iterable of rows read as
            ``[center_x, center_y, w, h, score, class_id]``; the coordinates
            are scaled by the image size here.

    A box is drawn only when its score reaches the per-class threshold in
    ``structure_class_thresholds``, the same thresholds that are passed to
    ``postprocess.objects_to_cells`` elsewhere in this file.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]

    for result in structure_result:
        class_name = structure_class_names[int(result[5])]
        score = float(result[4])

        # Compare against the class threshold. ``structure_class_map`` is used
        # elsewhere as a name -> label lookup, so it is not a meaningful
        # cut-off for a confidence score.
        if score < structure_class_thresholds[class_name]:
            continue

        center_x, center_y, w, h = result[0], result[1], result[2], result[3]
        x1 = int((center_x - w / 2) * width)
        y1 = int((center_y - h / 2) * height)
        x2 = int((center_x + w / 2) * width)
        y2 = int((center_y + h / 2) * height)
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255))

    return cv_to_PIL(image)
201
+
202
+
203
+ def visualize_cells(pil_img, cells):
204
  image = PIL_to_cv(pil_img)
205
  for i, cell in enumerate(cells):
206
  bbox = cell['bbox']
 
209
  x2 = int(bbox[2])
210
  y2 = int(bbox[3])
211
  cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
212
+ return cv_to_PIL(image)
213
 
214
 
215
  def pytess(cell_pil_img):
 
244
  return pil_img
245
 
246
 
247
+ # def extract_text_from_cells(pil_img, cells):
248
+ # pil_img, factor = resize(pil_img)
249
+ # #pil_img = remove_noise_and_smooth(pil_img)
250
+ # #display(pil_img)
251
+ # for cell in cells:
252
+ # bbox = [x * factor for x in cell['bbox']]
253
+ # cell_pil_img = pil_img.crop(bbox)
254
+ # #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
255
+ # #cell_pil_img = tess_prep(cell_pil_img)
256
+ # cell['cell text'] = pytess(cell_pil_img)
257
+ # return cells
258
+
259
+
260
def extract_text_from_cells(cells, sep=' '):
    """Fill ``cell['cell_text']`` for every cell from the text of its spans.

    Args:
        cells: List of cell dicts; each must have a ``'spans'`` list of dicts.
            Spans without a ``'text'`` key are ignored.
        sep: Separator placed *between* consecutive span texts.

    Returns:
        The same ``cells`` list, modified in place.
    """
    for cell in cells:
        # join() puts the separator only between texts, so the cell text no
        # longer ends with a dangling separator.
        cell['cell_text'] = sep.join(
            span['text'] for span in cell['spans'] if 'text' in span
        )
    return cells
269
 
270
 
271
def cells_to_csv(cells):
    """Convert table cells into a DataFrame and its CSV text.

    Args:
        cells: List of cell dicts with ``'row_nums'``, ``'column_nums'``,
            ``'header'`` and ``'cell_text'`` entries. A spanning cell's text
            is copied into every grid slot it covers.

    Returns:
        ``(df, csv_text)``. Header rows (rows up to the last row touched by a
        header cell) are flattened per column into a single name, joining the
        distinct texts with ``' | '``. For an empty ``cells`` list an empty
        DataFrame and an empty string are returned so callers can always
        unpack two values.
    """
    if not cells:
        return pd.DataFrame(), ''

    num_columns = max(max(cell['column_nums']) for cell in cells) + 1
    num_rows = max(max(cell['row_nums']) for cell in cells) + 1

    # -1 means "no header rows": the header slice below is then empty.
    max_header_row = max(
        (max(cell['row_nums']) for cell in cells if cell['header']),
        default=-1,
    )

    table_array = np.empty([num_rows, num_columns], dtype="object")
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                table_array[row_num, column_num] = cell["cell_text"]

    header = table_array[:max_header_row + 1, :]
    # Slots no cell covers stay None; drop them so join() does not raise.
    # OrderedDict.fromkeys de-duplicates while keeping first-seen order.
    flattened_header = [
        ' | '.join(OrderedDict.fromkeys(part for part in col if part is not None))
        for col in header.transpose()
    ]
    df = pd.DataFrame(table_array[max_header_row + 1:, :], index=None, columns=flattened_header)

    return df, df.to_csv(index=False)
298
+
299
+
300
def cells_to_html(cells):
    """Serialize table cells to an HTML ``<table>`` string.

    Args:
        cells: List of cell dicts with ``'row_nums'``, ``'column_nums'``,
            ``'header'`` and ``'cell_text'`` entries.

    Returns:
        The table markup as a ``str``. Cells are emitted in (first row, first
        column) order. A new row starts whenever a cell's first row number
        exceeds the current one, and the first cell of that row decides
        whether it is a header row (``<th>`` cells inside ``<thead>``) or a
        body row (``<td>`` cells). ``colspan``/``rowspan`` are written only
        when greater than 1. Cell text is escaped by ElementTree.
    """
    # One stable sort on (row, column) orders cells the same way as sorting
    # by column and then by row.
    cells = sorted(cells, key=lambda k: (min(k['row_nums']), min(k['column_nums'])))

    table = ET.Element("table")
    thead = None
    row = None
    cell_tag = "td"
    current_row = -1

    for cell in cells:
        this_row = min(cell['row_nums'])

        attrib = {}
        colspan = len(cell['column_nums'])
        if colspan > 1:
            attrib['colspan'] = str(colspan)
        rowspan = len(cell['row_nums'])
        if rowspan > 1:
            attrib['rowspan'] = str(rowspan)

        if this_row > current_row:
            current_row = this_row
            if cell['header']:
                cell_tag = "th"
                # Keep a single <thead> and give each header row its own
                # <tr>, so <th> cells are never direct children of <thead>.
                if thead is None:
                    thead = ET.SubElement(table, "thead")
                row = ET.SubElement(thead, "tr")
            else:
                cell_tag = "td"
                row = ET.SubElement(table, "tr")

        tcell = ET.SubElement(row, cell_tag, attrib=attrib)
        tcell.text = cell['cell_text']

    return ET.tostring(table, encoding="unicode", short_empty_elements=False)
329
+
330
+
331
+ # def cells_to_html(cells):
332
+ # for cell in cells:
333
+ # cell['column_nums'].sort()
334
+ # cell['row_nums'].sort()
335
+ # n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
336
+ # n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
337
+ # html_code = ''
338
+ # for r in range(n_rows):
339
+ # r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
340
+ # r_cells.sort(key=lambda x: x['column_nums'][0])
341
+ # r_html = ''
342
+ # for cell in r_cells:
343
+ # rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
344
+ # colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
345
+ # r_html += f'<td rowspan="{rowspan}" colspan="{colspan}">{escape(cell["text"])}</td>'
346
+ # html_code += f'<tr>{r_html}</tr>'
347
+ # html_code = '''<html>
348
+ # <head>
349
+ # <meta charset="UTF-8">
350
+ # <style>
351
+ # table, th, td {
352
+ # border: 1px solid black;
353
+ # font-size: 10px;
354
+ # }
355
+ # </style>
356
+ # </head>
357
+ # <body>
358
+ # <table frame="hsides" rules="groups" width="100%%">
359
+ # %s
360
+ # </table>
361
+ # </body>
362
+ # </html>''' % html_code
363
+ # soup = bs(html_code)
364
+ # html_code = soup.prettify()
365
+ # return html_code
366
 
367
 
368
  def main():
 
373
 
374
  cols = st.columns((1, 1))
375
  cols[0].subheader("Input page")
376
+ cols[1].subheader("Table(s) detected")
377
 
378
  st.sidebar.title("Image upload")
379
  st.set_option('deprecation.showfileUploaderEncoding', False)
 
386
 
387
  else:
388
  print(filename)
389
+ pil_img = PIL.Image.open(filename)
390
+
391
+ detection_result = table_detection(pil_img)
392
+ crop_images, vis_det_img = crop_image(pil_img, detection_result)
393
+ cols[0].image(vis_det_img)
394
+
395
+ str_cols = st.columns((len(crop_images), ) * 6)
396
+ str_cols[0].subheader("Table image")
397
+ str_cols[1].subheader("OCR result")
398
+ str_cols[2].subheader("Structure result")
399
+ str_cols[3].subheader("Cells result")
400
+ str_cols[4].subheader("HTML result")
401
+ str_cols[5].subheader("CSV result")
402
+
403
+ for img in crop_images:
404
+ ocr_result = ocr(img)
405
+ structure_result = table_structure(img)
406
+ table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
407
+ cells = extract_text_from_cells(cells)
408
+ html_result = cells_to_html(cells)
409
+ df, csv_result = cells_to_csv(cells)
410
+
411
+ vis_ocr_img = visualize_ocr(img, ocr_result)
412
+ vis_str_img = visualize_structure(img, structure_result)
413
+ vis_cells_img = visualize_cells(img, cells)
414
+
415
+ str_cols[0].image(img)
416
+ str_cols[1].image(vis_ocr_img)
417
+ str_cols[2].image(vis_str_img)
418
+ str_cols[3].image(vis_cells_img)
419
+ str_cols[4].markdown(html_result, unsafe_allow_html=True)
420
+ str_cols[5].dataframe(df)
421
+ str_cols[5].download_button("Download table", csv_result, "file.csv", "text/csv", key='download-csv')
422
 
423
 
424
  if __name__ == '__main__':
postprocess.py CHANGED
@@ -38,9 +38,9 @@ def iou(bbox1, bbox2):
38
  intersection = Rect(bbox1).intersect(bbox2)
39
  union = Rect(bbox1).include_rect(bbox2)
40
 
41
- union_area = union.get_area() # getArea()
42
  if union_area > 0:
43
- return intersection.get_area() / union.get_area() # .getArea()
44
 
45
  return 0
46
 
@@ -51,9 +51,9 @@ def iob(bbox1, bbox2):
51
  """
52
  intersection = Rect(bbox1).intersect(bbox2)
53
 
54
- bbox1_area = Rect(bbox1).get_area() # .getArea()
55
  if bbox1_area > 0:
56
- return intersection.get_area() / bbox1_area # getArea()
57
 
58
  return 0
59
 
@@ -144,36 +144,36 @@ def objects_to_table_structures(table_object, objects_in_table, tokens_in_table,
144
  return table_structures
145
 
146
 
147
- def refine_rows(rows, page_spans, score_threshold):
148
  """
149
  Apply operations to the detected rows, such as
150
  thresholding, NMS, and alignment.
151
  """
152
-
153
- #MODIFY
154
- rows = [obj for obj in rows if obj['score'] >= score_threshold or obj['header']]
155
- ###
156
-
157
- rows = nms_by_containment(rows, page_spans, overlap_threshold=0.5)
158
- # remove_objects_without_content(page_spans, rows) # TODO
159
  if len(rows) > 1:
160
  rows = sort_objects_top_to_bottom(rows)
161
 
162
  return rows
163
 
164
 
165
- def refine_columns(columns, page_spans, score_threshold):
166
  """
167
  Apply operations to the detected columns, such as
168
  thresholding, NMS, and alignment.
169
  """
170
 
171
- #MODIFY
172
- columns = [obj for obj in columns if obj['score'] >= score_threshold]
173
- ###
174
-
175
- columns = nms_by_containment(columns, page_spans, overlap_threshold=0.5)
176
- # remove_objects_without_content(page_spans, columns) # TODO
177
  if len(columns) > 1:
178
  columns = sort_objects_left_to_right(columns)
179
 
@@ -222,10 +222,10 @@ def slot_into_containers(container_objects, package_objects, overlap_threshold=0
222
  for package_num, package in enumerate(package_objects):
223
  match_scores = []
224
  package_rect = Rect(package['bbox'])
225
- package_area = package_rect.get_area() # getArea()
226
  for container_num, container in enumerate(container_objects):
227
  container_rect = Rect(container['bbox'])
228
- intersect_area = container_rect.intersect(package['bbox']).get_area() # getArea()
229
  overlap_fraction = intersect_area / package_area
230
  match_scores.append({'container': container, 'container_num': container_num, 'score': overlap_fraction})
231
 
@@ -298,10 +298,10 @@ def overlaps(bbox1, bbox2, threshold=0.5):
298
  Test if more than "threshold" fraction of bbox1 overlaps with bbox2.
299
  """
300
  rect1 = Rect(list(bbox1))
301
- area1 = rect1.get_area() # .getArea()
302
  if area1 == 0:
303
  return False
304
- return rect1.intersect(list(bbox2)).get_area()/area1 >= threshold # getArea()
305
 
306
 
307
  def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscripts=True):
@@ -317,6 +317,8 @@ def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscr
317
 
318
  if remove_integer_superscripts:
319
  for span in spans:
 
 
320
  flags = span['flags']
321
  if flags & 2**0: # superscript flag
322
  if is_int(span['text']):
@@ -438,7 +440,7 @@ def refine_table_structures(table_bbox, table_structures, page_spans, class_thre
438
  return table_structures
439
 
440
 
441
- def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_metric="score", keep_higher=True):
442
  """
443
  A customizable version of non-maxima suppression (NMS).
444
 
@@ -448,28 +450,24 @@ def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_me
448
  objects: set of dicts; each object dict must have a 'bbox' and a 'score' field
449
  match_criteria: how to measure how much two objects "overlap"
450
  match_threshold: the cutoff for determining that overlap requires suppression of one object
451
- keep_metric: which metric to use to determine the object to keep
452
  keep_higher: if True, keep the object with the higher metric; otherwise, keep the lower
453
  """
454
  if len(objects) == 0:
455
  return []
456
 
457
- if keep_metric=="score":
458
- objects = sort_objects_by_score(objects, reverse=keep_higher)
459
- elif keep_metric=="area":
460
- objects = sort_objects_by_area(objects, reverse=keep_higher)
461
 
462
  num_objects = len(objects)
463
  suppression = [False for obj in objects]
464
 
465
  for object2_num in range(1, num_objects):
466
  object2_rect = Rect(objects[object2_num]['bbox'])
467
- object2_area = object2_rect.get_area() # .getArea()
468
  for object1_num in range(object2_num):
469
  if not suppression[object1_num]:
470
  object1_rect = Rect(objects[object1_num]['bbox'])
471
- object1_area = object1_rect.get_area() # .getArea()
472
- intersect_area = object1_rect.intersect(object2_rect).get_area() # .getArea()
473
  try:
474
  if match_criteria=="object1_overlap":
475
  metric = intersect_area / object1_area
@@ -719,8 +717,8 @@ def table_structure_to_cells(table_structures, table_spans, table_bbox):
719
  cell['subcell'] = False
720
  for supercell in supercells:
721
  supercell_rect = Rect(list(supercell['bbox']))
722
- if (supercell_rect.intersect(cell_rect).get_area() # .getArea()
723
- / cell_rect.get_area()) > 0.5: # getArea()
724
  cell['subcell'] = True
725
  break
726
 
@@ -740,8 +738,8 @@ def table_structure_to_cells(table_structures, table_spans, table_bbox):
740
  header = True
741
  for subcell in subcells:
742
  subcell_rect = Rect(list(subcell['bbox']))
743
- subcell_rect_area = subcell_rect.get_area() # .getArea()
744
- if (subcell_rect.intersect(supercell_rect).get_area() # .getArea()
745
  / subcell_rect_area) > 0.5:
746
  if cell_rect is None:
747
  cell_rect = Rect(list(subcell['bbox']))
@@ -838,7 +836,7 @@ def table_structure_to_cells(table_structures, table_spans, table_bbox):
838
  for column_num in cell['column_nums']:
839
  column_rect.include_rect(list(columns[column_num]['bbox']))
840
  cell_rect = row_rect.intersect(column_rect)
841
- if cell_rect.get_area() > 0: # getArea()
842
  cell['bbox'] = list(cell_rect)
843
  pass
844
 
 
38
  intersection = Rect(bbox1).intersect(bbox2)
39
  union = Rect(bbox1).include_rect(bbox2)
40
 
41
+ union_area = union.get_area()
42
  if union_area > 0:
43
+ return intersection.get_area() / union.get_area()
44
 
45
  return 0
46
 
 
51
  """
52
  intersection = Rect(bbox1).intersect(bbox2)
53
 
54
+ bbox1_area = Rect(bbox1).get_area()
55
  if bbox1_area > 0:
56
+ return intersection.get_area() / bbox1_area
57
 
58
  return 0
59
 
 
144
  return table_structures
145
 
146
 
147
def refine_rows(rows, tokens, score_threshold):
    """
    De-duplicate the detected rows and order them top to bottom.

    With tokens available, overlapping rows are suppressed by containment
    (nms_by_containment); without tokens, plain NMS on box overlap is used.
    NOTE(review): score_threshold is accepted but not used in this body.
    """
    if len(tokens) == 0:
        rows = nms(rows, match_criteria="object2_overlap",
                   match_threshold=0.5, keep_higher=True)
    else:
        rows = nms_by_containment(rows, tokens, overlap_threshold=0.5)
        # remove_objects_without_content(tokens, rows) # TODO

    if len(rows) <= 1:
        return rows
    return sort_objects_top_to_bottom(rows)
163
 
164
 
165
+ def refine_columns(columns, tokens, score_threshold):
166
  """
167
  Apply operations to the detected columns, such as
168
  thresholding, NMS, and alignment.
169
  """
170
 
171
+ if len(tokens) > 0:
172
+ columns = nms_by_containment(columns, tokens, overlap_threshold=0.5)
173
+ # remove_objects_without_content(tokens, columns) # TODO
174
+ else:
175
+ columns = nms(columns, match_criteria="object2_overlap",
176
+ match_threshold=0.25, keep_higher=True)
177
  if len(columns) > 1:
178
  columns = sort_objects_left_to_right(columns)
179
 
 
222
  for package_num, package in enumerate(package_objects):
223
  match_scores = []
224
  package_rect = Rect(package['bbox'])
225
+ package_area = package_rect.get_area()
226
  for container_num, container in enumerate(container_objects):
227
  container_rect = Rect(container['bbox'])
228
+ intersect_area = container_rect.intersect(package['bbox']).get_area()
229
  overlap_fraction = intersect_area / package_area
230
  match_scores.append({'container': container, 'container_num': container_num, 'score': overlap_fraction})
231
 
 
298
  Test if more than "threshold" fraction of bbox1 overlaps with bbox2.
299
  """
300
  rect1 = Rect(list(bbox1))
301
+ area1 = rect1.get_area()
302
  if area1 == 0:
303
  return False
304
+ return rect1.intersect(list(bbox2)).get_area()/area1 >= threshold
305
 
306
 
307
  def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscripts=True):
 
317
 
318
  if remove_integer_superscripts:
319
  for span in spans:
320
+ if not 'flags' in span:
321
+ continue
322
  flags = span['flags']
323
  if flags & 2**0: # superscript flag
324
  if is_int(span['text']):
 
440
  return table_structures
441
 
442
 
443
+ def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_higher=True):
444
  """
445
  A customizable version of non-maxima suppression (NMS).
446
 
 
450
  objects: set of dicts; each object dict must have a 'bbox' and a 'score' field
451
  match_criteria: how to measure how much two objects "overlap"
452
  match_threshold: the cutoff for determining that overlap requires suppression of one object
 
453
  keep_higher: if True, keep the object with the higher metric; otherwise, keep the lower
454
  """
455
  if len(objects) == 0:
456
  return []
457
 
458
+ objects = sort_objects_by_score(objects, reverse=keep_higher)
 
 
 
459
 
460
  num_objects = len(objects)
461
  suppression = [False for obj in objects]
462
 
463
  for object2_num in range(1, num_objects):
464
  object2_rect = Rect(objects[object2_num]['bbox'])
465
+ object2_area = object2_rect.get_area()
466
  for object1_num in range(object2_num):
467
  if not suppression[object1_num]:
468
  object1_rect = Rect(objects[object1_num]['bbox'])
469
+ object1_area = object1_rect.get_area()
470
+ intersect_area = object1_rect.intersect(object2_rect).get_area()
471
  try:
472
  if match_criteria=="object1_overlap":
473
  metric = intersect_area / object1_area
 
717
  cell['subcell'] = False
718
  for supercell in supercells:
719
  supercell_rect = Rect(list(supercell['bbox']))
720
+ if (supercell_rect.intersect(cell_rect).get_area()
721
+ / cell_rect.get_area()) > 0.5:
722
  cell['subcell'] = True
723
  break
724
 
 
738
  header = True
739
  for subcell in subcells:
740
  subcell_rect = Rect(list(subcell['bbox']))
741
+ subcell_rect_area = subcell_rect.get_area()
742
+ if (subcell_rect.intersect(supercell_rect).get_area()
743
  / subcell_rect_area) > 0.5:
744
  if cell_rect is None:
745
  cell_rect = Rect(list(subcell['bbox']))
 
836
  for column_num in cell['column_nums']:
837
  column_rect.include_rect(list(columns[column_num]['bbox']))
838
  cell_rect = row_rect.intersect(column_rect)
839
+ if cell_rect.get_area() > 0:
840
  cell['bbox'] = list(cell_rect)
841
  pass
842
 
requirements.txt CHANGED
@@ -1,19 +1,78 @@
1
- -e git+https://github.com/mindee/doctr.git#egg=python-doctr[tf]
2
- streamlit>=0.65.0
3
- PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12,!=1.19.5
4
- tf2onnx==1.13.0
5
- Pillow==9.2.0
6
- pytesseract==0.3.10
7
- torch==1.12.0
8
- torchvision==0.13.0
9
- beautifulsoup4==4.11.1
10
- psutil
11
- numpy>=1.21.6
12
- scipy>=1.7.3
13
- thop>=0.1.1
14
- tqdm>=4.64.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  gitpython>=3.1.30
16
- matplotlib>=3.5.3
17
- pandas>=1.3.5
18
- seaborn>=0.12.0
19
- setuptools>=65.5.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PaddleOCR
2
+ shapely
3
+ scikit-image
4
+ imgaug
5
+ pyclipper
6
+ lmdb
7
+ tqdm
8
+ numpy
9
+ visualdl
10
+ rapidfuzz
11
+ opencv-python==4.6.0.66
12
+ opencv-contrib-python==4.6.0.66
13
+ cython
14
+ lxml
15
+ premailer
16
+ openpyxl
17
+ attrdict
18
+ Polygon3
19
+ lanms-neo==1.0.2
20
+ PyMuPDF<1.21.0
21
+ paddleocr
22
+ paddlepaddle
23
+ paddlehub
24
+
25
+ # YOLOv5
26
+ # YOLOv5 requirements
27
+ # Usage: pip install -r requirements.txt
28
+
29
+ # Base ------------------------------------------------------------------------
30
  gitpython>=3.1.30
31
+ matplotlib>=3.2.2
32
+ numpy>=1.18.5
33
+ opencv-python>=4.1.1
34
+ Pillow>=7.1.2
35
+ psutil # system resources
36
+ PyYAML>=5.3.1
37
+ requests>=2.23.0
38
+ scipy>=1.4.1
39
+ thop>=0.1.1 # FLOPs computation
40
+ torch>=1.7.0 # see https://pytorch.org/get-started/locally (recommended)
41
+ torchvision>=0.8.1
42
+ tqdm>=4.64.0
43
+ # protobuf<=3.20.1 # https://github.com/ultralytics/yolov5/issues/8012
44
+
45
+ # Logging ---------------------------------------------------------------------
46
+ tensorboard>=2.4.1
47
+ # clearml>=1.2.0
48
+ # comet
49
+
50
+ # Plotting --------------------------------------------------------------------
51
+ pandas>=1.1.4
52
+ seaborn>=0.11.0
53
+
54
+ # Export ----------------------------------------------------------------------
55
+ # coremltools>=6.0 # CoreML export
56
+ # onnx>=1.12.0 # ONNX export
57
+ # onnx-simplifier>=0.4.1 # ONNX simplifier
58
+ # nvidia-pyindex # TensorRT export
59
+ # nvidia-tensorrt # TensorRT export
60
+ # scikit-learn<=1.1.2 # CoreML quantization
61
+ # tensorflow>=2.4.1 # TF exports (-cpu, -aarch64, -macos)
62
+ # tensorflowjs>=3.9.0 # TF.js export
63
+ # openvino-dev # OpenVINO export
64
+
65
+ # Deploy ----------------------------------------------------------------------
66
+ setuptools>=65.5.1 # Snyk vulnerability fix
67
+ # tritonclient[all]~=2.24.0
68
+
69
+ # Extras ----------------------------------------------------------------------
70
+ # ipython # interactive notebook
71
+ # mss # screenshots
72
+ # albumentations>=1.0.3
73
+ # pycocotools>=2.0.6 # COCO mAP
74
+ # ultralytics # HUB https://hub.ultralytics.com
75
+
76
+ # Other
77
+ pytesseract==0.3.10
78
+ # beautifulsoup4==4.11.1
weights/detection_wts.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32127c7362c16c5839cb95c942cbc9ad1412fd953eb4b0b93758a49f01e312cb
3
+ size 14397685