EC2 Default User committed
Commit 12b050a · 1 parent: 5d7adbb

init commit
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ttc filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
SimSong.ttc ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cff39c3a0d87e3851297b35826489032448851df948e41fc56b2fe39c38d58e3
+size 36859516
app.py ADDED
@@ -0,0 +1,266 @@
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+from matplotlib.patches import Patch
+import io
+import cv2
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
+import csv
+import pandas as pd
+
+from ultralytics import YOLO
+import torch
+
+from paddleocr import PaddleOCR
+import postprocess
+
+import gradio as gr
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+detection_model = YOLO('yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt').to(device)
+structure_model = YOLO('yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt').to(device)
+ocr_model = PaddleOCR(use_angle_cls=True, lang="ch", det_limit_side_len=1920)  # TODO use large det_limit_side_len to get better OCR result
+
+detection_class_names = ['table', 'table rotated']
+structure_class_names = [
+    'table', 'table column', 'table row', 'table column header',
+    'table projected row header', 'table spanning cell', 'no object'
+]
+structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
+structure_class_thresholds = {
+    "table": 0.5,
+    "table column": 0.5,
+    "table row": 0.5,
+    "table column header": 0.5,
+    "table projected row header": 0.5,
+    "table spanning cell": 0.5,
+    "no object": 10
+}
+
+
+def table_detection(image):
+    imgsz = 800
+    pred = detection_model.predict(image, imgsz=imgsz)
+    pred = pred[0].boxes
+    result = pred.cpu().numpy()
+    result_list = [list(result.xywhn[i]) + [result.conf[i], result.cls[i]] for i in range(result.shape[0])]
+    return result_list
+
+
+def table_structure(image):
+    imgsz = 1024
+    pred = structure_model.predict(image, imgsz=imgsz)
+    pred = pred[0].boxes
+    result = pred.cpu().numpy()
+    result_list = [list(result.xywhn[i]) + [result.conf[i], result.cls[i]] for i in range(result.shape[0])]
+    return result_list
+
+
+def crop_image(image, detection_result):
+    # crop_filenames = []
+    width = image.shape[1]
+    height = image.shape[0]
+    # print(width, height)
+    for i, result in enumerate(detection_result[:1]):  # TODO only return first detected table
+        class_id = int(result[5])
+        score = float(result[4])
+        min_x = result[0]
+        min_y = result[1]
+        w = result[2]
+        h = result[3]
+
+        # x1 = max(0, int((min_x-w/2-0.02)*width))  # TODO expand 2%
+        # y1 = max(0, int((min_y-h/2-0.02)*height))  # TODO expand 2%
+        # x2 = min(width, int((min_x+w/2+0.02)*width))  # TODO expand 2%
+        # y2 = min(height, int((min_y+h/2+0.02)*height))  # TODO expand 2%
+        x1 = max(0, int((min_x-w/2)*width)-10)  # TODO expand 10px
+        y1 = max(0, int((min_y-h/2)*height)-10)  # TODO expand 10px
+        x2 = min(width, int((min_x+w/2)*width)+10)  # TODO expand 10px
+        y2 = min(height, int((min_y+h/2)*height)+10)  # TODO expand 10px
+        # print(x1, y1, x2, y2)
+        crop_image = image[y1:y2, x1:x2, :]
+        # crop_filename = filename[:-4]+'_'+str(i)+'_'+detection_class_names[class_id]+filename[-4:]
+        # crop_filenames.append(crop_filename)
+        # cv2.imwrite(crop_filename, crop_image)
+    return crop_image
+
+
+def convert_structure(ocr_result, image, structure_result):
+    width = image.shape[1]
+    height = image.shape[0]
+    # print(width, height)
+
+    bboxes = []
+    scores = []
+    labels = []
+    for i, result in enumerate(structure_result):
+        class_id = int(result[5])
+        score = float(result[4])
+        min_x = result[0]
+        min_y = result[1]
+        w = result[2]
+        h = result[3]
+
+        x1 = int((min_x-w/2)*width)
+        y1 = int((min_y-h/2)*height)
+        x2 = int((min_x+w/2)*width)
+        y2 = int((min_y+h/2)*height)
+        # print(x1, y1, x2, y2)
+
+        bboxes.append([x1, y1, x2, y2])
+        scores.append(score)
+        labels.append(class_id)
+
+    table_objects = []
+    for bbox, score, label in zip(bboxes, scores, labels):
+        table_objects.append({'bbox': bbox, 'score': score, 'label': label})
+    # print('table_objects:', table_objects)
+
+    table = {'objects': table_objects, 'page_num': 0}
+
+    table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
+    if len(table_class_objects) > 1:
+        table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
+    try:
+        table_bbox = list(table_class_objects[0]['bbox'])
+    except IndexError:  # no "table" object detected; fall back to a large default box
+        table_bbox = (0, 0, 1000, 1000)
+    # print('table_class_objects:', table_class_objects)
+    # print('table_bbox:', table_bbox)
+
+    page_tokens = ocr_result
+    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
+    # print('tokens_in_table:', tokens_in_table)
+
+    table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)
+
+    return table_structures, cells, confidence_score
+
+
+def visualize_cells(image, table_structures, cells):
+    width = image.shape[1]
+    height = image.shape[0]
+    # print(width, height)
+    empty_image = np.zeros((height, width, 3), np.uint8)
+    empty_image.fill(255)
+    empty_image = Image.fromarray(cv2.cvtColor(empty_image, cv2.COLOR_BGR2RGB))
+    draw = ImageDraw.Draw(empty_image)
+    fontStyle = ImageFont.truetype("SimSong.ttc", 10, encoding="utf-8")
+
+    num_cols = len(table_structures['columns'])
+    num_rows = len(table_structures['rows'])
+    data_rows = [['' for _ in range(num_cols)] for _ in range(num_rows)]
+    for i, cell in enumerate(cells):
+        bbox = cell['bbox']
+        x1 = int(bbox[0])
+        y1 = int(bbox[1])
+        x2 = int(bbox[2])
+        y2 = int(bbox[3])
+        col_num = cell['column_nums'][0]
+        row_num = cell['row_nums'][0]
+        spans = cell['spans']
+        text = ''
+        for span in spans:
+            if 'text' in span:
+                text += span['text']
+        data_rows[row_num][col_num] = text
+
+        # print('text:', text)
+        text_len = len(text)
+        # print('text_len:', text_len)
+        cell_width = x2-x1
+        # print('cell_width:', cell_width)
+        num_per_line = cell_width//10
+        # print('num_per_line:', num_per_line)
+        if num_per_line != 0:
+            line_num = text_len//num_per_line
+        else:
+            line_num = 0
+        # print('line_num:', line_num)
+        new_text = text[:num_per_line]+'\n'
+        for j in range(line_num):
+            new_text += text[(j+1)*num_per_line:(j+2)*num_per_line]+'\n'
+        # print('new_text:', new_text)
+        text = new_text
+
+        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0,255,0))
+        cv2.putText(image, str(row_num)+'-'+str(col_num), (x1, y1+30), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
+
+        # cv2.rectangle(empty_image, (x1, y1), (x2, y2), color=(0,0,255))
+        # cv2.putText(empty_image, str(row_num)+'-'+str(col_num), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
+        # cv2.putText(empty_image, text, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
+        draw.rectangle([(x1, y1), (x2, y2)], (255,255,255), (0,255,0))
+        draw.text((x1-20, y1), str(row_num)+'-'+str(col_num), (255,0,0), font=fontStyle)
+        draw.text((x1, y1), text, (0,0,255), font=fontStyle)
+
+    df = pd.DataFrame(data_rows[1:], columns=data_rows[0])
+    return image, df, df.to_json()
+
+
+def ocr(image):
+    result = ocr_model.ocr(image, cls=True)
+    result = result[0]
+    new_result = []
+    if result is not None:
+        bounding_boxes = [line[0] for line in result]
+        txts = [line[1][0] for line in result]
+        scores = [line[1][1] for line in result]
+        # print('txts:', txts)
+        # print('scores:', scores)
+        # print('bounding_boxes:', bounding_boxes)
+        for label, bbox in zip(txts, bounding_boxes):
+            new_result.append({'bbox': [bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]], 'text': label})
+
+    return new_result
+
+
+def detect_and_crop_table(image):
+    detection_result = table_detection(image)
+    # print('detection_result:', detection_result)
+    cropped_table = crop_image(image, detection_result)
+
+    return cropped_table
+
+
+def recognize_table(image, ocr_result):
+    structure_result = table_structure(image)
+    print('structure_result:', structure_result)
+    table_structures, cells, confidence_score = convert_structure(ocr_result, image, structure_result)
+    print('table_structures:', table_structures)
+    print('cells:', cells)
+    print('confidence_score:', confidence_score)
+    image, df, data = visualize_cells(image, table_structures, cells)
+
+    return image, df, data
+
+
+def process_pdf(image):
+    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+    cropped_table = detect_and_crop_table(image)
+
+    ocr_result = ocr(cropped_table)
+    # print('ocr_result:', ocr_result)
+
+    image, df, data = recognize_table(cropped_table, ocr_result)
+    print('df:', df)
+
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    return image, df, data
+
+
+title = "Demo: table detection & recognition with Table Structure Recognition (YOLOv8)"
+description = """Demo for table extraction with a YOLOv8-based Table Structure Recognition model."""
+examples = [['image.png'], ['mistral_paper.png']]
+
+app = gr.Interface(fn=process_pdf,
+                   inputs=gr.Image(type="numpy"),
+                   outputs=[gr.Image(type="numpy", label="Detected table"), gr.Dataframe(label="Table as CSV"), gr.JSON(label="Data as JSON")],
+                   title=title,
+                   description=description,
+                   examples=examples)
+app.queue()
+# app.launch(debug=True, share=True)
+app.launch()
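
For reference, a minimal sketch of how the pipeline in app.py could be exercised outside the Gradio UI (an assumption-laden example, not part of the commit: it presumes the two YOLOv8 weight files and SimSong.ttc above are present, and that the module-level app.launch() call is guarded behind `if __name__ == "__main__":` so importing the module does not start the server):

```python
# Local smoke test for the app.py pipeline (a sketch, not part of the commit).
import cv2
from app import process_pdf  # assumes app.launch() is guarded against import

image = cv2.imread("image.png")                 # OpenCV loads images as BGR
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # process_pdf expects RGB (Gradio numpy convention)

vis, df, data = process_pdf(image)              # annotated table crop, DataFrame, JSON string
cv2.imwrite("table_vis.png", cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
print(df)
```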
clip_paper.png ADDED

Git LFS Details

  • SHA256: 90e0cca54433cbf6f85ae3a9acd8d6c7bd30beaac75b6ed607fc20c86834806a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
image.png ADDED

Git LFS Details

  • SHA256: ca1ccba9032a8f41511c0770111dbc7f363674caf0ae474231a443821312c6ec
  • Pointer size: 131 Bytes
  • Size of remote file: 253 kB
mistral_paper.png ADDED

Git LFS Details

  • SHA256: d40877f122327935f575bca94bb8e74bec1fe9d2355049dd8e6e50cd53114474
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
postprocess.py ADDED
@@ -0,0 +1,887 @@
+"""
+Copyright (C) 2021 Microsoft Corporation
+"""
+from collections import defaultdict
+
+from fitz import Rect
+
+
+def apply_threshold(objects, threshold):
+    """
+    Filter out objects below a certain score.
+    """
+    return [obj for obj in objects if obj['score'] >= threshold]
+
+
+def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
+    """
+    Filter out bounding boxes whose confidence is below the confidence threshold for
+    its associated class label.
+    """
+    # Apply class-specific thresholds
+    indices_above_threshold = [idx for idx, (score, label) in enumerate(zip(scores, labels))
+                               if score >= class_thresholds[class_names[label]]]
+    bboxes = [bboxes[idx] for idx in indices_above_threshold]
+    scores = [scores[idx] for idx in indices_above_threshold]
+    labels = [labels[idx] for idx in indices_above_threshold]
+
+    return bboxes, scores, labels
+
+
+def iou(bbox1, bbox2):
+    """
+    Compute the intersection-over-union of two bounding boxes.
+    """
+    intersection = Rect(bbox1).intersect(bbox2)
+    union = Rect(bbox1).include_rect(bbox2)
+
+    union_area = union.get_area()
+    if union_area > 0:
+        return intersection.get_area() / union_area
+
+    return 0
+
+
+def iob(bbox1, bbox2):
+    """
+    Compute the intersection area over box area, for bbox1.
+    """
+    intersection = Rect(bbox1).intersect(bbox2)
+
+    bbox1_area = Rect(bbox1).get_area()
+    if bbox1_area > 0:
+        return intersection.get_area() / bbox1_area
+
+    return 0
+
+
+def objects_to_cells(table, objects_in_table, tokens_in_table, class_names, class_thresholds):
+    """
+    Process the bounding boxes produced by the table structure recognition model
+    and the token/word/span bounding boxes into table cells.
+
+    Also return a confidence score based on how well the text was able to be
+    uniquely slotted into the cells detected by the table model.
+    """
+    table_structures = objects_to_table_structures(table, objects_in_table, tokens_in_table, class_names,
+                                                   class_thresholds)
+
+    # Check for a valid table
+    if len(table_structures['columns']) < 1 or len(table_structures['rows']) < 1:
+        cells = []
+        confidence_score = 0
+    else:
+        cells, confidence_score = table_structure_to_cells(table_structures, tokens_in_table, table['bbox'])
+
+    return table_structures, cells, confidence_score
+
+
+def objects_to_table_structures(table_object, objects_in_table, tokens_in_table, class_names, class_thresholds):
+    """
+    Process the bounding boxes produced by the table structure recognition model into
+    a *consistent* set of table structures (rows, columns, supercells, headers).
+    This entails resolving conflicts/overlaps, and ensuring the boxes meet certain alignment
+    conditions (for example: rows should all have the same width, etc.).
+    """
+    page_num = table_object['page_num']
+
+    table_structures = {}
+
+    columns = [obj for obj in objects_in_table if class_names[obj['label']] == 'table column']
+    rows = [obj for obj in objects_in_table if class_names[obj['label']] == 'table row']
+    headers = [obj for obj in objects_in_table if class_names[obj['label']] == 'table column header']
+    supercells = [obj for obj in objects_in_table if class_names[obj['label']] == 'table spanning cell']
+    for obj in supercells:
+        obj['subheader'] = False
+    subheaders = [obj for obj in objects_in_table if class_names[obj['label']] == 'table projected row header']
+    for obj in subheaders:
+        obj['subheader'] = True
+    supercells += subheaders
+    for obj in rows:
+        obj['header'] = False
+        for header_obj in headers:
+            if iob(obj['bbox'], header_obj['bbox']) >= 0.5:
+                obj['header'] = True
+
+    for row in rows:
+        row['page'] = page_num
+
+    for column in columns:
+        column['page'] = page_num
+
+    # Refine table structures
+    rows = refine_rows(rows, tokens_in_table, class_thresholds['table row'])
+    columns = refine_columns(columns, tokens_in_table, class_thresholds['table column'])
+
+    # Shrink table bbox to just the total height of the rows
+    # and the total width of the columns
+    row_rect = Rect()
+    for obj in rows:
+        row_rect.include_rect(obj['bbox'])
+    column_rect = Rect()
+    for obj in columns:
+        column_rect.include_rect(obj['bbox'])
+    table_object['row_column_bbox'] = [column_rect[0], row_rect[1], column_rect[2], row_rect[3]]
+    table_object['bbox'] = table_object['row_column_bbox']
+
+    # Process the rows and columns into a complete segmented table
+    columns = align_columns(columns, table_object['row_column_bbox'])
+    rows = align_rows(rows, table_object['row_column_bbox'])
+
+    table_structures['rows'] = rows
+    table_structures['columns'] = columns
+    table_structures['headers'] = headers
+    table_structures['supercells'] = supercells
+
+    if len(rows) > 0 and len(columns) > 1:
+        table_structures = refine_table_structures(table_object['bbox'], table_structures, tokens_in_table, class_thresholds)
+
+    return table_structures
+
+
+def refine_rows(rows, page_spans, score_threshold):
+    """
+    Apply operations to the detected rows, such as
+    thresholding, NMS, and alignment.
+    """
+    rows = nms_by_containment(rows, page_spans, overlap_threshold=0.5)
+    # remove_objects_without_content(page_spans, rows)  # TODO
+    if len(rows) > 1:
+        rows = sort_objects_top_to_bottom(rows)
+
+    return rows
+
+
+def refine_columns(columns, page_spans, score_threshold):
+    """
+    Apply operations to the detected columns, such as
+    thresholding, NMS, and alignment.
+    """
+    columns = nms_by_containment(columns, page_spans, overlap_threshold=0.5)
+    # remove_objects_without_content(page_spans, columns)  # TODO
+    if len(columns) > 1:
+        columns = sort_objects_left_to_right(columns)
+
+    return columns
+
+
+def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5):
+    """
+    Non-maxima suppression (NMS) of objects based on shared containment of other objects.
+    """
+    container_objects = sort_objects_by_score(container_objects)
+    num_objects = len(container_objects)
+    suppression = [False for obj in container_objects]
+
+    packages_by_container, _, _ = slot_into_containers(container_objects, package_objects, overlap_threshold=overlap_threshold,
+                                                       unique_assignment=True, forced_assignment=False)
+
+    for object2_num in range(1, num_objects):
+        object2_packages = set(packages_by_container[object2_num])
+        if len(object2_packages) == 0:
+            suppression[object2_num] = True
+        for object1_num in range(object2_num):
+            if not suppression[object1_num]:
+                object1_packages = set(packages_by_container[object1_num])
+                if len(object2_packages.intersection(object1_packages)) > 0:
+                    suppression[object2_num] = True
+
+    final_objects = [obj for idx, obj in enumerate(container_objects) if not suppression[idx]]
+    return final_objects
+
+
+def slot_into_containers(container_objects, package_objects, overlap_threshold=0.5,
+                         unique_assignment=True, forced_assignment=False):
+    """
+    Slot a collection of objects into the container they occupy most (the container which holds the largest fraction of the object).
+    """
+    best_match_scores = []
+
+    container_assignments = [[] for container in container_objects]
+    package_assignments = [[] for package in package_objects]
+
+    if len(container_objects) == 0 or len(package_objects) == 0:
+        return container_assignments, package_assignments, best_match_scores
+
+    for package_num, package in enumerate(package_objects):
+        match_scores = []
+        package_rect = Rect(package['bbox'])
+        package_area = package_rect.get_area()
+        for container_num, container in enumerate(container_objects):
+            container_rect = Rect(container['bbox'])
+            intersect_area = container_rect.intersect(package['bbox']).get_area()
+            overlap_fraction = intersect_area / package_area
+            match_scores.append({'container': container, 'container_num': container_num, 'score': overlap_fraction})
+
+        sorted_match_scores = sort_objects_by_score(match_scores)
+
+        best_match_score = sorted_match_scores[0]
+        best_match_scores.append(best_match_score['score'])
+        if forced_assignment or best_match_score['score'] >= overlap_threshold:
+            container_assignments[best_match_score['container_num']].append(package_num)
+            package_assignments[package_num].append(best_match_score['container_num'])
+
+        if not unique_assignment:  # slot package into all eligible slots
+            for match_score in sorted_match_scores[1:]:
+                if match_score['score'] >= overlap_threshold:
+                    container_assignments[match_score['container_num']].append(package_num)
+                    package_assignments[package_num].append(match_score['container_num'])
+                else:
+                    break
+
+    return container_assignments, package_assignments, best_match_scores
+
+
+def sort_objects_by_score(objects, reverse=True):
+    """
+    Put any set of objects in order from high score to low score.
+    """
+    if reverse:
+        sign = -1
+    else:
+        sign = 1
+    return sorted(objects, key=lambda k: sign*k['score'])
+
+
+
+def remove_objects_without_content(page_spans, objects):
+    """
+    Remove any objects (these can be rows, columns, supercells, etc.) that don't
+    have any text associated with them.
+    """
+    for obj in objects[:]:
+        object_text, _ = extract_text_inside_bbox(page_spans, obj['bbox'])
+        if len(object_text.strip()) == 0:
+            objects.remove(obj)
+
+
+def extract_text_inside_bbox(spans, bbox):
+    """
+    Extract the text inside a bounding box.
+    """
+    bbox_spans = get_bbox_span_subset(spans, bbox)
+    bbox_text = extract_text_from_spans(bbox_spans, remove_integer_superscripts=True)
+
+    return bbox_text, bbox_spans
+
+
+def get_bbox_span_subset(spans, bbox, threshold=0.5):
+    """
+    Reduce the set of spans to those that fall within a bounding box.
+
+    threshold: the fraction of the span that must overlap with the bbox.
+    """
+    span_subset = []
+    for span in spans:
+        if overlaps(span['bbox'], bbox, threshold):
+            span_subset.append(span)
+    return span_subset
+
+
+def overlaps(bbox1, bbox2, threshold=0.5):
+    """
+    Test if more than "threshold" fraction of bbox1 overlaps with bbox2.
+    """
+    rect1 = Rect(list(bbox1))
+    area1 = rect1.get_area()
+    if area1 == 0:
+        return False
+    return rect1.intersect(list(bbox2)).get_area()/area1 >= threshold
+
+
+def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscripts=True):
+    """
+    Convert a collection of page tokens/words/spans into a single text string.
+    """
+    if join_with_space:
+        join_char = " "
+    else:
+        join_char = ""
+    spans_copy = spans[:]
+
+    if remove_integer_superscripts:
+        for span in spans:
+            flags = span['flags']
+            if flags & 2**0:  # superscript flag
+                if is_int(span['text']):
+                    spans_copy.remove(span)
+                else:
+                    span['superscript'] = True
+
+    if len(spans_copy) == 0:
+        return ""
+
+    spans_copy.sort(key=lambda span: span['span_num'])
+    spans_copy.sort(key=lambda span: span['line_num'])
+    spans_copy.sort(key=lambda span: span['block_num'])
+
+    # Force the span at the end of every line within a block to have exactly one space
+    # unless the line ends with a space or ends with a non-space followed by a hyphen
+    line_texts = []
+    line_span_texts = [spans_copy[0]['text']]
+    for span1, span2 in zip(spans_copy[:-1], spans_copy[1:]):
+        if not span1['block_num'] == span2['block_num'] or not span1['line_num'] == span2['line_num']:
+            line_text = join_char.join(line_span_texts).strip()
+            if (len(line_text) > 0
+                    and not line_text[-1] == ' '
+                    and not (len(line_text) > 1 and line_text[-1] == "-" and not line_text[-2] == ' ')):
+                if not join_with_space:
+                    line_text += ' '
+            line_texts.append(line_text)
+            line_span_texts = [span2['text']]
+        else:
+            line_span_texts.append(span2['text'])
+    line_text = join_char.join(line_span_texts)
+    line_texts.append(line_text)
+
+    return join_char.join(line_texts).strip()
+
+
+def sort_objects_left_to_right(objs):
+    """
+    Put the objects in order from left to right.
+    """
+    return sorted(objs, key=lambda k: k['bbox'][0] + k['bbox'][2])
+
+
+def sort_objects_top_to_bottom(objs):
+    """
+    Put the objects in order from top to bottom.
+    """
+    return sorted(objs, key=lambda k: k['bbox'][1] + k['bbox'][3])
+
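+
+def is_int(val):
+    # Hypothetical helper, labeled as such: the committed file never defines
+    # is_int(), yet extract_text_from_spans above calls it, so the superscript
+    # filter would raise NameError if exercised. A definition like this would fix it.
+    try:
+        int(val)
+        return True
+    except ValueError:
+        return False
+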
+
+def align_columns(columns, bbox):
+    """
+    For every column, align the top and bottom boundaries to the final
+    table bounding box.
+    """
+    try:
+        for column in columns:
+            column['bbox'][1] = bbox[1]
+            column['bbox'][3] = bbox[3]
+    except Exception as err:
+        print("Could not align columns: {}".format(err))
+
+    return columns
+
+
+def align_rows(rows, bbox):
+    """
+    For every row, align the left and right boundaries to the final
+    table bounding box.
+    """
+    try:
+        for row in rows:
+            row['bbox'][0] = bbox[0]
+            row['bbox'][2] = bbox[2]
+    except Exception as err:
+        print("Could not align rows: {}".format(err))
+
+    return rows
+
+
+def refine_table_structures(table_bbox, table_structures, page_spans, class_thresholds):
+    """
+    Apply operations to the detected table structure objects such as
+    thresholding, NMS, and alignment.
+    """
+    rows = table_structures["rows"]
+    columns = table_structures['columns']
+
+    # columns = fill_column_gaps(columns, table_bbox)
+    # rows = fill_row_gaps(rows, table_bbox)
+
+    # Process the headers
+    headers = table_structures['headers']
+    headers = apply_threshold(headers, class_thresholds["table column header"])
+    headers = nms(headers)
+    headers = align_headers(headers, rows)
+
+    # Process supercells
+    supercells = [elem for elem in table_structures['supercells'] if not elem['subheader']]
+    subheaders = [elem for elem in table_structures['supercells'] if elem['subheader']]
+    supercells = apply_threshold(supercells, class_thresholds["table spanning cell"])
+    subheaders = apply_threshold(subheaders, class_thresholds["table projected row header"])
+    supercells += subheaders
+    # Align before NMS for supercells because alignment brings them into agreement
+    # with rows and columns first; if supercells still overlap after this operation,
+    # the threshold for NMS can basically be lowered to just above 0
+    supercells = align_supercells(supercells, rows, columns)
+    supercells = nms_supercells(supercells)
+
+    header_supercell_tree(supercells)
+
+    table_structures['columns'] = columns
+    table_structures['rows'] = rows
+    table_structures['supercells'] = supercells
+    table_structures['headers'] = headers
+
+    return table_structures
+
+
+def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_metric="score", keep_higher=True):
+    """
+    A customizable version of non-maxima suppression (NMS).
+
+    Default behavior: If a lower-confidence object overlaps more than 5% of its area
+    with a higher-confidence object, remove the lower-confidence object.
+
+    objects: set of dicts; each object dict must have a 'bbox' and a 'score' field
+    match_criteria: how to measure how much two objects "overlap"
+    match_threshold: the cutoff for determining that overlap requires suppression of one object
+    keep_metric: which metric to use to determine the object to keep
+    keep_higher: if True, keep the object with the higher metric; otherwise, keep the lower
+    """
+    if len(objects) == 0:
+        return []
+
+    if keep_metric == "score":
+        objects = sort_objects_by_score(objects, reverse=keep_higher)
+    elif keep_metric == "area":
+        objects = sort_objects_by_area(objects, reverse=keep_higher)
+
+    num_objects = len(objects)
+    suppression = [False for obj in objects]
+
+    for object2_num in range(1, num_objects):
+        object2_rect = Rect(objects[object2_num]['bbox'])
+        object2_area = object2_rect.get_area()
+        for object1_num in range(object2_num):
+            if not suppression[object1_num]:
+                object1_rect = Rect(objects[object1_num]['bbox'])
+                object1_area = object1_rect.get_area()
+                intersect_area = object1_rect.intersect(object2_rect).get_area()
+                try:
+                    if match_criteria == "object1_overlap":
+                        metric = intersect_area / object1_area
+                    elif match_criteria == "object2_overlap":
+                        metric = intersect_area / object2_area
+                    elif match_criteria == "iou":
+                        metric = intersect_area / (object1_area + object2_area - intersect_area)
+                    if metric >= match_threshold:
+                        suppression[object2_num] = True
+                        break
+                except Exception:
+                    # Intended to recover from divide-by-zero
+                    pass
+
+    return [obj for idx, obj in enumerate(objects) if not suppression[idx]]
+
+
+
+def align_headers(headers, rows):
+    """
+    Adjust the header boundary to be the convex hull of the rows it intersects
+    at least 50% of the height of.
+
+    For now, we are not supporting tables with multiple headers, so we need to
+    eliminate anything besides the top-most header.
+    """
+    aligned_headers = []
+
+    for row in rows:
+        row['header'] = False
+
+    header_row_nums = []
+    for header in headers:
+        for row_num, row in enumerate(rows):
+            row_height = row['bbox'][3] - row['bbox'][1]
+            min_row_overlap = max(row['bbox'][1], header['bbox'][1])
+            max_row_overlap = min(row['bbox'][3], header['bbox'][3])
+            overlap_height = max_row_overlap - min_row_overlap
+            if overlap_height / row_height >= 0.5:
+                header_row_nums.append(row_num)
+
+    if len(header_row_nums) == 0:
+        return aligned_headers
+
+    header_rect = Rect()
+    if header_row_nums[0] > 0:
+        header_row_nums = list(range(header_row_nums[0]+1)) + header_row_nums
+
+    last_row_num = -1
+    for row_num in header_row_nums:
+        if row_num == last_row_num + 1:
+            row = rows[row_num]
+            row['header'] = True
+            header_rect = header_rect.include_rect(row['bbox'])
+            last_row_num = row_num
+        else:
+            # Break as soon as a non-header row is encountered.
+            # This ignores any subsequent rows in the table labeled as a header.
+            # Having more than 1 header is not supported currently.
+            break
+
+    header = {'bbox': list(header_rect)}
+    aligned_headers.append(header)
+
+    return aligned_headers
+
+
+def align_supercells(supercells, rows, columns):
+    """
+    For each supercell, align it to the rows it intersects 50% of the height of,
+    and the columns it intersects 50% of the width of.
+    Eliminate supercells for which there are no rows and columns it intersects 50% with.
+    """
+    aligned_supercells = []
+
+    for supercell in supercells:
+        supercell['header'] = False
+        row_bbox_rect = None
+        col_bbox_rect = None
+        intersecting_header_rows = set()
+        intersecting_data_rows = set()
+        for row_num, row in enumerate(rows):
+            row_height = row['bbox'][3] - row['bbox'][1]
+            supercell_height = supercell['bbox'][3] - supercell['bbox'][1]
+            min_row_overlap = max(row['bbox'][1], supercell['bbox'][1])
+            max_row_overlap = min(row['bbox'][3], supercell['bbox'][3])
+            overlap_height = max_row_overlap - min_row_overlap
+            if 'span' in supercell:
+                overlap_fraction = max(overlap_height/row_height,
+                                       overlap_height/supercell_height)
+            else:
+                overlap_fraction = overlap_height / row_height
+            if overlap_fraction >= 0.5:
+                if 'header' in row and row['header']:
+                    intersecting_header_rows.add(row_num)
+                else:
+                    intersecting_data_rows.add(row_num)
+
+        # Supercell cannot span across the header boundary; eliminate whichever
+        # group of rows is the smallest
+        supercell['header'] = False
+        if len(intersecting_data_rows) > 0 and len(intersecting_header_rows) > 0:
+            if len(intersecting_data_rows) > len(intersecting_header_rows):
+                intersecting_header_rows = set()
+            else:
+                intersecting_data_rows = set()
+        if len(intersecting_header_rows) > 0:
+            supercell['header'] = True
+        elif 'span' in supercell:
+            continue  # Require span supercell to be in the header
+        intersecting_rows = intersecting_data_rows.union(intersecting_header_rows)
+        # Determine vertical span of aligned supercell
+        for row_num in intersecting_rows:
+            if row_bbox_rect is None:
+                row_bbox_rect = Rect(rows[row_num]['bbox'])
+            else:
+                row_bbox_rect = row_bbox_rect.include_rect(rows[row_num]['bbox'])
+        if row_bbox_rect is None:
+            continue
+
+        intersecting_cols = []
+        for col_num, col in enumerate(columns):
+            col_width = col['bbox'][2] - col['bbox'][0]
+            supercell_width = supercell['bbox'][2] - supercell['bbox'][0]
+            min_col_overlap = max(col['bbox'][0], supercell['bbox'][0])
+            max_col_overlap = min(col['bbox'][2], supercell['bbox'][2])
+            overlap_width = max_col_overlap - min_col_overlap
+            if 'span' in supercell:
+                overlap_fraction = max(overlap_width/col_width,
+                                       overlap_width/supercell_width)
+                # Multiply by 2 effectively lowers the threshold to 0.25
+                if supercell['header']:
+                    overlap_fraction = overlap_fraction * 2
+            else:
+                overlap_fraction = overlap_width / col_width
+            if overlap_fraction >= 0.5:
+                intersecting_cols.append(col_num)
+                if col_bbox_rect is None:
+                    col_bbox_rect = Rect(col['bbox'])
+                else:
+                    col_bbox_rect = col_bbox_rect.include_rect(col['bbox'])
+        if col_bbox_rect is None:
+            continue
+
+        supercell_bbox = list(row_bbox_rect.intersect(col_bbox_rect))
+        supercell['bbox'] = supercell_bbox
+
+        # Only a true supercell if it joins across multiple rows or columns
+        if (len(intersecting_rows) > 0 and len(intersecting_cols) > 0
+                and (len(intersecting_rows) > 1 or len(intersecting_cols) > 1)):
+            supercell['row_numbers'] = list(intersecting_rows)
+            supercell['column_numbers'] = intersecting_cols
+            aligned_supercells.append(supercell)
+
+            # A span supercell in the header means there must be supercells above it in the header
+            if 'span' in supercell and supercell['header'] and len(supercell['column_numbers']) > 1:
+                for row_num in range(0, min(supercell['row_numbers'])):
+                    new_supercell = {'row_numbers': [row_num], 'column_numbers': supercell['column_numbers'],
+                                     'score': supercell['score'], 'propagated': True}
+                    new_supercell_columns = [columns[idx] for idx in supercell['column_numbers']]
+                    new_supercell_rows = [rows[idx] for idx in supercell['row_numbers']]
+                    bbox = [min([column['bbox'][0] for column in new_supercell_columns]),
+                            min([row['bbox'][1] for row in new_supercell_rows]),
+                            max([column['bbox'][2] for column in new_supercell_columns]),
+                            max([row['bbox'][3] for row in new_supercell_rows])]
+                    new_supercell['bbox'] = bbox
+                    aligned_supercells.append(new_supercell)
+
+    return aligned_supercells
+
+
+
+def nms_supercells(supercells):
+    """
+    A NMS scheme for supercells that first attempts to shrink supercells to
+    resolve overlap.
+    If two supercells overlap the same (sub)cell, shrink the lower confidence
+    supercell to resolve the overlap. If shrunk supercell is empty, remove it.
+    """
+    supercells = sort_objects_by_score(supercells)
+    num_supercells = len(supercells)
+    suppression = [False for supercell in supercells]
+
+    for supercell2_num in range(1, num_supercells):
+        supercell2 = supercells[supercell2_num]
+        for supercell1_num in range(supercell2_num):
+            supercell1 = supercells[supercell1_num]
+            remove_supercell_overlap(supercell1, supercell2)
+        if ((len(supercell2['row_numbers']) < 2 and len(supercell2['column_numbers']) < 2)
+                or len(supercell2['row_numbers']) == 0 or len(supercell2['column_numbers']) == 0):
+            suppression[supercell2_num] = True
+
+    return [obj for idx, obj in enumerate(supercells) if not suppression[idx]]
+
+
+def header_supercell_tree(supercells):
+    """
+    Make sure no supercell in the header is below more than one supercell in any row above it.
+    The cells in the header form a tree, but a supercell with more than one supercell in a row
+    above it means that some cell has more than one parent, which is not allowed. Eliminate
+    any supercell that would cause this to be violated.
+    """
+    header_supercells = [supercell for supercell in supercells if 'header' in supercell and supercell['header']]
+    header_supercells = sort_objects_by_score(header_supercells)
+
+    for header_supercell in header_supercells[:]:
+        ancestors_by_row = defaultdict(int)
+        min_row = min(header_supercell['row_numbers'])
+        for header_supercell2 in header_supercells:
+            max_row2 = max(header_supercell2['row_numbers'])
+            if max_row2 < min_row:
+                if (set(header_supercell['column_numbers']).issubset(
+                        set(header_supercell2['column_numbers']))):
+                    for row2 in header_supercell2['row_numbers']:
+                        ancestors_by_row[row2] += 1
+        for row in range(0, min_row):
+            if not ancestors_by_row[row] == 1:
+                supercells.remove(header_supercell)
+                break
+
+
+def table_structure_to_cells(table_structures, table_spans, table_bbox):
+    """
+    Assuming the row, column, supercell, and header bounding boxes have
+    been refined into a set of consistent table structures, process these
+    table structures into table cells. This is a universal representation
+    format for the table, which can later be exported to Pandas or CSV formats.
+    Classify the cells as header/access cells or data cells
+    based on if they intersect with the header bounding box.
+    """
+    columns = table_structures['columns']
+    rows = table_structures['rows']
+    supercells = table_structures['supercells']
+    cells = []
+    subcells = []
+
+    # Identify complete cells and subcells
+    for column_num, column in enumerate(columns):
+        for row_num, row in enumerate(rows):
+            column_rect = Rect(list(column['bbox']))
+            row_rect = Rect(list(row['bbox']))
+            cell_rect = row_rect.intersect(column_rect)
+            header = 'header' in row and row['header']
+            cell = {'bbox': list(cell_rect), 'column_nums': [column_num], 'row_nums': [row_num],
+                    'header': header}
+
+            cell['subcell'] = False
+            for supercell in supercells:
+                supercell_rect = Rect(list(supercell['bbox']))
+                if (supercell_rect.intersect(cell_rect).get_area()
+                        / cell_rect.get_area()) > 0.5:
+                    cell['subcell'] = True
+                    break
+
+            if cell['subcell']:
+                subcells.append(cell)
+            else:
+                # cell_text = extract_text_inside_bbox(table_spans, cell['bbox'])
+                # cell['cell_text'] = cell_text
+                cell['subheader'] = False
+                cells.append(cell)
+
+    for supercell in supercells:
+        supercell_rect = Rect(list(supercell['bbox']))
+        cell_columns = set()
+        cell_rows = set()
+        cell_rect = None
+        header = True
+        for subcell in subcells:
+            subcell_rect = Rect(list(subcell['bbox']))
+            subcell_rect_area = subcell_rect.get_area()
+            if (subcell_rect.intersect(supercell_rect).get_area()
+                    / subcell_rect_area) > 0.5:
+                if cell_rect is None:
+                    cell_rect = Rect(list(subcell['bbox']))
+                else:
+                    cell_rect.include_rect(Rect(list(subcell['bbox'])))
+                cell_rows = cell_rows.union(set(subcell['row_nums']))
+                cell_columns = cell_columns.union(set(subcell['column_nums']))
+                # By convention here, all subcells must be classified
+                # as header cells for a supercell to be classified as a header cell;
+                # otherwise, this could lead to a non-rectangular header region
+                header = header and 'header' in subcell and subcell['header']
+        if len(cell_rows) > 0 and len(cell_columns) > 0:
+            cell = {'bbox': list(cell_rect), 'column_nums': list(cell_columns), 'row_nums': list(cell_rows),
+                    'header': header, 'subheader': supercell['subheader']}
+            cells.append(cell)
+
+    # Compute a confidence score based on how well the page tokens
+    # slot into the cells reported by the model
+    _, _, cell_match_scores = slot_into_containers(cells, table_spans)
+    try:
+        mean_match_score = sum(cell_match_scores) / len(cell_match_scores)
+        min_match_score = min(cell_match_scores)
+        confidence_score = (mean_match_score + min_match_score)/2
+    except ZeroDivisionError:  # no cells were matched
+        confidence_score = 0
+
+    # Dilate rows and columns before final extraction
+    # dilated_columns = fill_column_gaps(columns, table_bbox)
+    dilated_columns = columns
+    # dilated_rows = fill_row_gaps(rows, table_bbox)
+    dilated_rows = rows
+    for cell in cells:
+        column_rect = Rect()
+        for column_num in cell['column_nums']:
+            column_rect.include_rect(list(dilated_columns[column_num]['bbox']))
+        row_rect = Rect()
+        for row_num in cell['row_nums']:
+            row_rect.include_rect(list(dilated_rows[row_num]['bbox']))
+        cell_rect = column_rect.intersect(row_rect)
+        cell['bbox'] = list(cell_rect)
+
+    span_nums_by_cell, _, _ = slot_into_containers(cells, table_spans, overlap_threshold=0.001,
+                                                   unique_assignment=True, forced_assignment=False)
+
+    for cell, cell_span_nums in zip(cells, span_nums_by_cell):
+        cell_spans = [table_spans[num] for num in cell_span_nums]
+        # TODO: Refine how text is extracted; should be character-based, not span-based;
+        # but need to associate
+        # cell['cell_text'] = extract_text_from_spans(cell_spans, remove_integer_superscripts=False)  # TODO
+        cell['spans'] = cell_spans
+
+    # Adjust the row, column, and cell bounding boxes to reflect the extracted text
+    num_rows = len(rows)
+    rows = sort_objects_top_to_bottom(rows)
+    num_columns = len(columns)
+    columns = sort_objects_left_to_right(columns)
+    min_y_values_by_row = defaultdict(list)
+    max_y_values_by_row = defaultdict(list)
+    min_x_values_by_column = defaultdict(list)
+    max_x_values_by_column = defaultdict(list)
+    for cell in cells:
+        min_row = min(cell["row_nums"])
+        max_row = max(cell["row_nums"])
+        min_column = min(cell["column_nums"])
+        max_column = max(cell["column_nums"])
+        for span in cell['spans']:
+            min_x_values_by_column[min_column].append(span['bbox'][0])
+            min_y_values_by_row[min_row].append(span['bbox'][1])
+            max_x_values_by_column[max_column].append(span['bbox'][2])
+            max_y_values_by_row[max_row].append(span['bbox'][3])
+    for row_num, row in enumerate(rows):
+        if len(min_x_values_by_column[0]) > 0:
+            row['bbox'][0] = min(min_x_values_by_column[0])
+        if len(min_y_values_by_row[row_num]) > 0:
+            row['bbox'][1] = min(min_y_values_by_row[row_num])
+        if len(max_x_values_by_column[num_columns-1]) > 0:
+            row['bbox'][2] = max(max_x_values_by_column[num_columns-1])
+        if len(max_y_values_by_row[row_num]) > 0:
+            row['bbox'][3] = max(max_y_values_by_row[row_num])
+    for column_num, column in enumerate(columns):
+        if len(min_x_values_by_column[column_num]) > 0:
+            column['bbox'][0] = min(min_x_values_by_column[column_num])
+        if len(min_y_values_by_row[0]) > 0:
+            column['bbox'][1] = min(min_y_values_by_row[0])
+        if len(max_x_values_by_column[column_num]) > 0:
+            column['bbox'][2] = max(max_x_values_by_column[column_num])
+        if len(max_y_values_by_row[num_rows-1]) > 0:
+            column['bbox'][3] = max(max_y_values_by_row[num_rows-1])
+    for cell in cells:
+        row_rect = Rect()
+        column_rect = Rect()
+        for row_num in cell['row_nums']:
+            row_rect.include_rect(list(rows[row_num]['bbox']))
+        for column_num in cell['column_nums']:
+            column_rect.include_rect(list(columns[column_num]['bbox']))
+        cell_rect = row_rect.intersect(column_rect)
+        if cell_rect.get_area() > 0:
+            cell['bbox'] = list(cell_rect)
+
+    return cells, confidence_score
+
+
+def remove_supercell_overlap(supercell1, supercell2):
+    """
+    This function resolves overlap between supercells (supercells must be
+    disjoint) by iteratively shrinking supercells by the fewest grid cells
+    necessary to resolve the overlap.
+    Example:
+    If two supercells overlap at grid cell (R, C), and supercell #1 is less
+    confident than supercell #2, we eliminate either row R from supercell #1
+    or column C from supercell #1 by comparing the number of columns in row R
+    versus the number of rows in column C. If the number of columns in row R
+    is less than the number of rows in column C, we eliminate row R from
+    supercell #1. This resolves the overlap by removing fewer grid cells from
+    supercell #1 than if we eliminated column C from it.
+    """
+    common_rows = set(supercell1['row_numbers']).intersection(set(supercell2['row_numbers']))
+    common_columns = set(supercell1['column_numbers']).intersection(set(supercell2['column_numbers']))
+
+    # While the supercells have overlapping grid cells, continue shrinking the less-confident
+    # supercell one row or one column at a time
+    while len(common_rows) > 0 and len(common_columns) > 0:
+        # Try to shrink the supercell as little as possible to remove the overlap;
+        # if the supercell has fewer rows than columns, remove an overlapping column,
+        # because this removes fewer grid cells from the supercell;
+        # otherwise remove an overlapping row
+        if len(supercell2['row_numbers']) < len(supercell2['column_numbers']):
+            min_column = min(supercell2['column_numbers'])
+            max_column = max(supercell2['column_numbers'])
+            if max_column in common_columns:
+                common_columns.remove(max_column)
+                supercell2['column_numbers'].remove(max_column)
+            elif min_column in common_columns:
+                common_columns.remove(min_column)
+                supercell2['column_numbers'].remove(min_column)
+            else:
+                supercell2['column_numbers'] = []
+                common_columns = set()
+        else:
+            min_row = min(supercell2['row_numbers'])
+            max_row = max(supercell2['row_numbers'])
+            if max_row in common_rows:
+                common_rows.remove(max_row)
+                supercell2['row_numbers'].remove(max_row)
+            elif min_row in common_rows:
+                common_rows.remove(min_row)
+                supercell2['row_numbers'].remove(min_row)
+            else:
+                supercell2['row_numbers'] = []
+                common_rows = set()
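
Worth noting: app.py keeps only OCR tokens whose `iob` score against the detected table box is at least 0.5, and `iou` as written divides by the area of the bounding rectangle of both boxes (`include_rect`), not the true union. A small worked example of the two helpers (a sketch; requires PyMuPDF for `fitz.Rect`):

```python
# Worked example for the overlap helpers in postprocess.py (not part of the commit).
from postprocess import iob, iou

token = [10, 10, 30, 20]   # (x1, y1, x2, y2): area 20 * 10 = 200
table = [0, 0, 25, 25]     # intersection with token: 15 * 10 = 150

print(iob(token, table))   # 150 / 200 = 0.75 -> token counts as "in the table" (>= 0.5)
print(iou(token, table))   # 150 / 750 = 0.2  -> denominator is the bounding rect [0, 0, 30, 25]
```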
requirements.txt ADDED
@@ -0,0 +1,13 @@
+torch
+#git+https://github.com/nielsrogge/transformers.git@convert_table_transformer_new_checkpoints
+transformers
+easyocr
+matplotlib
+Pillow
+pandas
+ultralytics
+PyMuPDF
+opencv-python
+gradio
+paddlepaddle-gpu
+paddleocr
tatr-app.py ADDED
@@ -0,0 +1,294 @@
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+from matplotlib.patches import Patch
+import io
+from PIL import Image, ImageDraw
+import numpy as np
+import csv
+import pandas as pd
+
+from torchvision import transforms
+
+from transformers import AutoModelForObjectDetection
+import torch
+
+import easyocr
+
+import gradio as gr
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+class MaxResize(object):
+    def __init__(self, max_size=800):
+        self.max_size = max_size
+
+    def __call__(self, image):
+        width, height = image.size
+        current_max_size = max(width, height)
+        scale = self.max_size / current_max_size
+        resized_image = image.resize((int(round(scale*width)), int(round(scale*height))))
+
+        return resized_image
+
+
+detection_transform = transforms.Compose([
+    MaxResize(800),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+structure_transform = transforms.Compose([
+    MaxResize(1000),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+# load table detection model
+# processor = TableTransformerImageProcessor(max_size=800)
+model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm").to(device)
+
+# load table structure recognition model
+# structure_processor = TableTransformerImageProcessor(max_size=1000)
+structure_model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(device)
+
+# load EasyOCR reader
+reader = easyocr.Reader(['en'])
+
+
+# for output bounding box post-processing
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=1)
+
+
+def rescale_bboxes(out_bbox, size):
+    width, height = size
+    boxes = box_cxcywh_to_xyxy(out_bbox)
+    boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
+    return boxes
+
+
+def outputs_to_objects(outputs, img_size, id2label):
+    m = outputs.logits.softmax(-1).max(-1)
+    pred_labels = list(m.indices.detach().cpu().numpy())[0]
+    pred_scores = list(m.values.detach().cpu().numpy())[0]
+    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
+    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
+
+    objects = []
+    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
+        class_label = id2label[int(label)]
+        if not class_label == 'no object':
+            objects.append({'label': class_label, 'score': float(score),
+                            'bbox': [float(elem) for elem in bbox]})
+
+    return objects
+
+
+def fig2img(fig):
+    """Convert a Matplotlib figure to a PIL Image and return it"""
+    buf = io.BytesIO()
+    fig.savefig(buf)
+    buf.seek(0)
+    image = Image.open(buf)
+    return image
+
+
+def visualize_detected_tables(img, det_tables):
+    plt.imshow(img, interpolation="lanczos")
+    fig = plt.gcf()
+    fig.set_size_inches(20, 20)
+    ax = plt.gca()
+
+    for det_table in det_tables:
+        bbox = det_table['bbox']
+
+        if det_table['label'] == 'table':
+            facecolor = (1, 0, 0.45)
+            edgecolor = (1, 0, 0.45)
+            alpha = 0.3
+            linewidth = 2
+            hatch = '//////'
+        elif det_table['label'] == 'table rotated':
+            facecolor = (0.95, 0.6, 0.1)
+            edgecolor = (0.95, 0.6, 0.1)
+            alpha = 0.3
+            linewidth = 2
+            hatch = '//////'
+        else:
+            continue
+
+        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
+                                 edgecolor='none', facecolor=facecolor, alpha=0.1)
+        ax.add_patch(rect)
+        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
+                                 edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha)
+        ax.add_patch(rect)
+        rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0,
+                                 edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2)
+        ax.add_patch(rect)
+
+    plt.xticks([], [])
+    plt.yticks([], [])
+
+    legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
+                             label='Table', hatch='//////', alpha=0.3),
+                       Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
+                             label='Table (rotated)', hatch='//////', alpha=0.3)]
+    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
+               fontsize=10, ncol=2)
+    plt.gcf().set_size_inches(10, 10)
+    plt.axis('off')
+
+    return fig
+
+
+def detect_and_crop_table(image):
+    # prepare image for the model
+    # pixel_values = processor(image, return_tensors="pt").pixel_values
+    pixel_values = detection_transform(image).unsqueeze(0).to(device)
+
+    # forward pass
+    with torch.no_grad():
+        outputs = model(pixel_values)
+
+    # postprocess to get detected tables
+    id2label = model.config.id2label
+    id2label[len(model.config.id2label)] = "no object"
+    detected_tables = outputs_to_objects(outputs, image.size, id2label)
+
+    # visualize
+    # fig = visualize_detected_tables(image, detected_tables)
+    # image = fig2img(fig)
+
+    # crop first detected table out of image
+    cropped_table = image.crop(detected_tables[0]["bbox"])
+
+    return cropped_table
+
+
+def recognize_table(image):
+    # prepare image for the model
+    # pixel_values = structure_processor(images=image, return_tensors="pt").pixel_values
+    pixel_values = structure_transform(image).unsqueeze(0).to(device)
+
+    # forward pass
+    with torch.no_grad():
+        outputs = structure_model(pixel_values)
+
+    # postprocess to get individual elements
+    id2label = structure_model.config.id2label
+    id2label[len(structure_model.config.id2label)] = "no object"
+    cells = outputs_to_objects(outputs, image.size, id2label)
+
+    # visualize cells on cropped table
+    draw = ImageDraw.Draw(image)
+
+    for cell in cells:
+        draw.rectangle(cell["bbox"], outline="red")
+
+    return image, cells
+
+
+def get_cell_coordinates_by_row(table_data):
+    # Extract rows and columns
+    rows = [entry for entry in table_data if entry['label'] == 'table row']
+    columns = [entry for entry in table_data if entry['label'] == 'table column']
+
+    # Sort rows and columns by their Y and X coordinates, respectively
+    rows.sort(key=lambda x: x['bbox'][1])
+    columns.sort(key=lambda x: x['bbox'][0])
+
+    # Function to find cell coordinates
+    def find_cell_coordinates(row, column):
+        cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
+        return cell_bbox
+
+    # Generate cell coordinates and count cells in each row
+    cell_coordinates = []
+
+    for row in rows:
+        row_cells = []
+        for column in columns:
+            cell_bbox = find_cell_coordinates(row, column)
+            row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
+
+        # Sort cells in the row by X coordinate
+        row_cells.sort(key=lambda x: x['column'][0])
+
+        # Append row information to cell_coordinates
+        cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
+
+    # Sort rows from top to bottom
+    cell_coordinates.sort(key=lambda x: x['row'][1])
+
+    return cell_coordinates
+
+
+def apply_ocr(cell_coordinates, cropped_table):
+    # let's OCR row by row
+    data = dict()
+    max_num_columns = 0
+    for idx, row in enumerate(cell_coordinates):
+        row_text = []
+        for cell in row["cells"]:
+            # crop cell out of image
+            cell_image = np.array(cropped_table.crop(cell["cell"]))
+            # apply OCR
+            result = reader.readtext(np.array(cell_image))
+            if len(result) > 0:
+                text = " ".join([x[1] for x in result])
+                row_text.append(text)
+
+        if len(row_text) > max_num_columns:
+            max_num_columns = len(row_text)
+
+        data[str(idx)] = row_text
+
+    # pad rows which don't have max_num_columns elements
+    # to make sure all rows have the same number of columns
+    for idx, row_data in data.copy().items():
+        if len(row_data) != max_num_columns:
+            row_data = row_data + ["" for _ in range(max_num_columns - len(row_data))]
+        data[str(idx)] = row_data
+
+    # write to csv
+    with open('output.csv', 'w') as result_file:
+        wr = csv.writer(result_file, dialect='excel')
+
+        for row, row_text in data.items():
+            wr.writerow(row_text)
+
+    # return as Pandas dataframe
+    df = pd.read_csv('output.csv')
+
+    return df, data
+
+
+def process_pdf(image):
+    cropped_table = detect_and_crop_table(image)
+
+    image, cells = recognize_table(cropped_table)
+
+    cell_coordinates = get_cell_coordinates_by_row(cells)
+
+    df, data = apply_ocr(cell_coordinates, image)
+
+    return image, df, data
+
+
+title = "Demo: table detection & recognition with Table Transformer (TATR)"
+description = """Demo for table extraction with the Table Transformer. First, table detection is performed on the input image using https://huggingface.co/microsoft/table-transformer-detection,
+after which the detected table is cropped out and https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all is used to recognize the individual rows, columns and cells. OCR is then performed per cell, row by row."""
+examples = [['image.png'], ['mistral_paper.png']]
+
+app = gr.Interface(fn=process_pdf,
+                   inputs=gr.Image(type="pil"),
+                   outputs=[gr.Image(type="pil", label="Detected table"), gr.Dataframe(label="Table as CSV"), gr.JSON(label="Data as JSON")],
+                   title=title,
+                   description=description,
+                   examples=examples)
+app.queue()
+app.launch(debug=True)
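
To make the box post-processing above concrete: TATR emits boxes as normalized (cx, cy, w, h), and rescale_bboxes maps them to absolute (x1, y1, x2, y2) pixel coordinates. A self-contained check (helpers copied inline, since the hyphen in tatr-app.py prevents a plain import):

```python
# Sketch of the (cx, cy, w, h) -> (x1, y1, x2, y2) conversion used in tatr-app.py.
import torch

def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    width, height = size
    return box_cxcywh_to_xyxy(out_bbox) * torch.tensor([width, height, width, height], dtype=torch.float32)

pred = torch.tensor([[0.5, 0.5, 0.2, 0.1]])  # one normalized center-format box
print(rescale_bboxes(pred, (800, 600)))      # -> tensor([[320., 270., 480., 330.]])
```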
yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d7c77d9723839f582e02377794dccee2ed745f72e08f4818d1ed5a7f7c3e591
+size 22520345
yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41877a8b980bd7c28139c267b3c3b1f0ffefb6babeaee8340c1680d2d623a794
+size 22527577