nielsr committed
Commit f1d8127
Parent(s): c4e3f03

Create app.py

Files changed (1): app.py (+137, -0)
app.py ADDED
@@ -0,0 +1,137 @@
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
from PIL import Image

from transformers import TableTransformerImageProcessor, AutoModelForObjectDetection
import torch

import gradio as gr

# load the table detection model (the processor caps the longest image edge at 800 px)
processor = TableTransformerImageProcessor(max_size=800)
model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")

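# Inference below runs on CPU; as an optional tweak (not part of the original app)
# one could move the model to GPU with model.to("cuda") and send pixel_values to
# the same device before the forward pass.
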
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    # convert (center_x, center_y, width, height) boxes to (x_min, y_min, x_max, y_max) corners
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    # scale normalized (0-1) boxes up to absolute pixel coordinates
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

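# Quick sanity check for the conversion (values chosen for illustration only):
#   rescale_bboxes(torch.tensor([[0.5, 0.5, 0.2, 0.2]]), (100, 200))
#   -> tensor([[40., 80., 60., 120.]]), i.e. (x_min, y_min, x_max, y_max) in pixels
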
def outputs_to_objects(outputs, img_size, id2label):
    # per query, keep the highest-scoring class and its probability
    m = outputs.logits.softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if class_label != 'no object':
            objects.append({'label': class_label, 'score': float(score),
                            'bbox': [float(elem) for elem in bbox]})

    return objects

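# Each entry in the returned list has the form
#   {'label': 'table', 'score': 0.98, 'bbox': [x_min, y_min, x_max, y_max]}
# with the bbox in absolute pixel coordinates (the 0.98 score is illustrative).
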
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img

def visualize_detected_tables(img, det_tables):
    plt.imshow(img, interpolation="lanczos")
    fig = plt.gcf()
    fig.set_size_inches(20, 20)
    ax = plt.gca()

    for det_table in det_tables:
        bbox = det_table['bbox']

        if det_table['label'] == 'table':
            facecolor = (1, 0, 0.45)
            edgecolor = (1, 0, 0.45)
            alpha = 0.3
            linewidth = 2
            hatch = '//////'
        elif det_table['label'] == 'table rotated':
            facecolor = (0.95, 0.6, 0.1)
            edgecolor = (0.95, 0.6, 0.1)
            alpha = 0.3
            linewidth = 2
            hatch = '//////'
        else:
            continue

        # draw each table as three stacked rectangles: a translucent fill,
        # a solid border, and a hatched overlay
        rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
                                 edgecolor='none', facecolor=facecolor, alpha=0.1)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
                                 edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha)
        ax.add_patch(rect)
        rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=0,
                                 edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2)
        ax.add_patch(rect)

    plt.xticks([], [])
    plt.yticks([], [])

    legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
                             label='Table', hatch='//////', alpha=0.3),
                       Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
                             label='Table (rotated)', hatch='//////', alpha=0.3)]
    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
               fontsize=10, ncol=2)
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')

    return fig

def detect_table(image):
    # prepare image for the model
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # forward pass
    with torch.no_grad():
        outputs = model(pixel_values)

    # postprocess to get detected tables; add an extra "no object" class id
    id2label = model.config.id2label
    id2label[len(model.config.id2label)] = "no object"
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    # visualize
    fig = visualize_detected_tables(image, detected_tables)
    image = fig2img(fig)

    return image

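# Minimal local sanity check, assuming a document image "page.png" exists on
# disk ("page.png" is a placeholder, not shipped with the demo):
#   result = detect_table(Image.open("page.png").convert("RGB"))
#   result.save("page_with_tables.png")
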
title = "Demo: table detection with Table Transformer"
description = "Demo for the Table Transformer (TATR), a DETR-based model for detecting tables in document images."
examples = [['example_pdf.jpg']]

interface = gr.Interface(fn=detect_table,
                         inputs=gr.Image(type="pil"),
                         outputs=gr.Image(type="pil", label="Detected table"),
                         title=title,
                         description=description,
                         examples=examples,
                         enable_queue=True)
interface.launch(debug=True)
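# Note: enable_queue as an Interface argument is tied to older Gradio releases;
# on newer versions queuing is enabled via interface.queue() instead.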