nielsr HF staff commited on
Commit
4e66f95
1 Parent(s): c62ee55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -5
app.py CHANGED
@@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
2
  import matplotlib.patches as patches
3
  from matplotlib.patches import Patch
4
  import io
5
- from PIL import Image
6
 
7
  from transformers import TableTransformerImageProcessor, AutoModelForObjectDetection
8
  import torch
@@ -13,6 +13,10 @@ import gradio as gr
13
  processor = TableTransformerImageProcessor(max_size=800)
14
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
15
 
 
 
 
 
16
 
17
  # for output bounding box post-processing
18
  def box_cxcywh_to_xyxy(x):
@@ -103,7 +107,7 @@ def visualize_detected_tables(img, det_tables):
103
  return fig
104
 
105
 
106
- def detect_table(image):
107
  # prepare image for the model
108
  pixel_values = processor(image, return_tensors="pt").pixel_values
109
 
@@ -117,8 +121,41 @@ def detect_table(image):
117
  detected_tables = outputs_to_objects(outputs, image.size, id2label)
118
 
119
  # visualize
120
- fig = visualize_detected_tables(image, detected_tables)
121
- image = fig2img(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  return image
124
 
@@ -127,7 +164,7 @@ title = "Demo: table detection with Table Transformer"
127
  description = "Demo for the Table Transformer (TATR)."
128
  examples =[['image.png']]
129
 
130
- app = gr.Interface(fn=detect_table,
131
  inputs=gr.Image(type="pil"),
132
  outputs=gr.Image(type="pil", label="Detected table"),
133
  title=title,
 
2
  import matplotlib.patches as patches
3
  from matplotlib.patches import Patch
4
  import io
5
+ from PIL import Image, ImageDraw
6
 
7
  from transformers import TableTransformerImageProcessor, AutoModelForObjectDetection
8
  import torch
 
13
  processor = TableTransformerImageProcessor(max_size=800)
14
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
15
 
16
+ # load table structure recognition model
17
+ structure_processor = TableTransformerImageProcessor(max_size=1000)
18
+ structure_model = AutoModelForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
19
+
20
 
21
  # for output bounding box post-processing
22
  def box_cxcywh_to_xyxy(x):
 
107
  return fig
108
 
109
 
110
+ def detect_and_crop_table(image):
111
  # prepare image for the model
112
  pixel_values = processor(image, return_tensors="pt").pixel_values
113
 
 
121
  detected_tables = outputs_to_objects(outputs, image.size, id2label)
122
 
123
  # visualize
124
+ # fig = visualize_detected_tables(image, detected_tables)
125
+ # image = fig2img(fig)
126
+
127
+ # crop first detected table out of image
128
+ cropped_table = image.crop(objects[0]["bbox"])
129
+
130
+ return cropped_table
131
+
132
+
133
+ def recognize_table(image):
134
+ # prepare image for the model
135
+ pixel_values = structure_processor(images=cropped_table, return_tensors="pt").pixel_values
136
+
137
+ # forward pass
138
+ with torch.no_grad():
139
+ outputs = structure_model(pixel_values)
140
+
141
+ # postprocess to get individual elements
142
+ id2label = structure_modelmodel.config.id2label
143
+ id2label[len(structure_modelmodel.config.id2label)] = "no object"
144
+ detected_tables = outputs_to_objects(outputs, image.size, id2label)
145
+
146
+ # visualize cells on cropped table
147
+ draw = ImageDraw.Draw(image)
148
+
149
+ for cell in cells:
150
+ draw.rectangle(cell["bbox"], outline="red")
151
+
152
+ return image
153
+
154
+
155
+ def process_pdf(image):
156
+ cropped_table = detect_and_crop_table(image)
157
+
158
+ image = recognize_table(cropped_table)
159
 
160
  return image
161
 
 
164
  description = "Demo for the Table Transformer (TATR)."
165
  examples =[['image.png']]
166
 
167
+ app = gr.Interface(fn=process_pdf,
168
  inputs=gr.Image(type="pil"),
169
  outputs=gr.Image(type="pil", label="Detected table"),
170
  title=title,