# tatr-demo / app.py
# Source page metadata: last commit "Update app.py" (35723bb) by nielsr (HF staff).
# Virus scan: No virus. File size: 5.87 kB.
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
from PIL import Image, ImageDraw
from transformers import TableTransformerImageProcessor, AutoModelForObjectDetection
import torch
import gradio as gr
# Table detection stage: image processor plus object-detection model.
_DETECTION_CHECKPOINT = "microsoft/table-transformer-detection"
processor = TableTransformerImageProcessor(max_size=800)
model = AutoModelForObjectDetection.from_pretrained(
    _DETECTION_CHECKPOINT,
    revision="no_timm",
)

# Table structure recognition stage: its own processor (larger max_size) and model.
_STRUCTURE_CHECKPOINT = "microsoft/table-structure-recognition-v1.1-all"
structure_processor = TableTransformerImageProcessor(max_size=1000)
structure_model = AutoModelForObjectDetection.from_pretrained(_STRUCTURE_CHECKPOINT)
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from center format to corner format.

    Args:
        x: Tensor whose last dimension has size 4 and holds
            ``(center_x, center_y, width, height)``. Any number of leading
            dimensions is accepted, including none (a single box).

    Returns:
        Tensor with the same shape as ``x`` whose last dimension holds
        ``(x_min, y_min, x_max, y_max)``.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    # Stack on the last axis (not a fixed axis 1) so the output layout mirrors
    # the input for a single box, a list of boxes, or a batch of box lists.
    return torch.stack(corners, dim=-1)
def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format and scale them by an image size.

    Args:
        out_bbox: Tensor of boxes in ``(center_x, center_y, width, height)``
            format. Coordinates are multiplied by the image dimensions, so
            they are presumably normalized to the 0..1 range.
        size: ``(width, height)`` pair of the target image.

    Returns:
        Tensor of boxes in ``(x_min, y_min, x_max, y_max)`` format, with x
        values multiplied by ``width`` and y values by ``height``.
    """
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    # Create the scale factors on the same device as the boxes; a CPU-only
    # scale tensor would fail to multiply with boxes held on another device.
    scale = torch.tensor(
        [width, height, width, height],
        dtype=torch.float32,
        device=boxes.device,
    )
    return boxes * scale
def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw detection outputs into a list of labelled, scored boxes.

    Only the first item of the batch is read. Predictions whose most likely
    class is ``'no object'`` are dropped.

    Args:
        outputs: Model output exposing ``logits`` and ``'pred_boxes'``.
        img_size: ``(width, height)`` used to rescale the predicted boxes.
        id2label: Mapping from class index to class name; must contain the
            ``'no object'`` entry for it to be filtered out.

    Returns:
        List of dicts with keys ``'label'`` (str), ``'score'`` (float) and
        ``'bbox'`` (list of four floats).
    """
    best = outputs.logits.softmax(-1).max(-1)

    # Index 0 picks the first batch item in each case.
    label_ids = best.indices.detach().cpu().numpy()[0]
    confidences = best.values.detach().cpu().numpy()[0]
    raw_boxes = outputs['pred_boxes'].detach().cpu()[0]
    scaled_boxes = [row.tolist() for row in rescale_bboxes(raw_boxes, img_size)]

    found = []
    for label_id, confidence, box in zip(label_ids, confidences, scaled_boxes):
        name = id2label[int(label_id)]
        if name == 'no object':
            continue
        found.append({
            'label': name,
            'score': float(confidence),
            'bbox': [float(coord) for coord in box],
        })
    return found
def fig2img(fig):
    """Render a Matplotlib figure into memory and return it as a PIL Image."""
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so PIL reads the saved figure from the start of the buffer.
    buffer.seek(0)
    return Image.open(buffer)
def visualize_detected_tables(img, det_tables):
    """Draw detected tables over an image on the current Matplotlib figure.

    Each detection labelled ``'table'`` or ``'table rotated'`` is drawn as
    three overlaid rectangles (light fill, outline, hatching) in a
    label-specific colour; detections with any other label are skipped.

    Args:
        img: Image passed to ``plt.imshow``.
        det_tables: Iterable of dicts with a ``'bbox'``
            (``[x_min, y_min, x_max, y_max]``) and a ``'label'``.

    Returns:
        The Matplotlib figure that was current when the image was shown.
    """
    # Colour per drawable label; any label missing here is not drawn.
    colour_by_label = {
        'table': (1, 0, 0.45),
        'table rotated': (0.95, 0.6, 0.1),
    }
    hatch_pattern = '//////'
    outline_width = 2
    outline_alpha = 0.3

    plt.imshow(img, interpolation="lanczos")
    fig = plt.gcf()
    fig.set_size_inches(20, 20)
    ax = plt.gca()

    for detection in det_tables:
        bbox = detection['bbox']
        colour = colour_by_label.get(detection['label'])
        if colour is None:
            continue

        origin = bbox[:2]
        box_width = bbox[2] - bbox[0]
        box_height = bbox[3] - bbox[1]

        # Three layers, added in this order: faint fill, solid outline, hatching.
        layers = (
            dict(linewidth=outline_width, edgecolor='none', facecolor=colour,
                 alpha=0.1),
            dict(linewidth=outline_width, edgecolor=colour, facecolor='none',
                 linestyle='-', alpha=outline_alpha),
            dict(linewidth=0, edgecolor=colour, facecolor='none',
                 linestyle='-', hatch=hatch_pattern, alpha=0.2),
        )
        for layer in layers:
            ax.add_patch(patches.Rectangle(origin, box_width, box_height, **layer))

    plt.xticks([], [])
    plt.yticks([], [])

    legend_elements = [
        Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
              label='Table', hatch='//////', alpha=0.3),
        Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
              label='Table (rotated)', hatch='//////', alpha=0.3),
    ]
    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02),
               loc='upper center', borderaxespad=0, fontsize=10, ncol=2)
    # The figure is resized to 10x10 inches before it is returned.
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')

    return fig
def detect_and_crop_table(image):
    """Detect tables in an image and return a crop of the first one found.

    Args:
        image: PIL image to search for tables.

    Returns:
        PIL image cropped to the bounding box of the first detected table.

    Raises:
        gr.Error: If the detection model finds no table in the image.
    """
    # Prepare the image for the model.
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Forward pass; gradients are not needed for inference.
    with torch.no_grad():
        outputs = model(pixel_values)

    # Extend a *copy* of the label map with the extra "no object" class.
    # Writing into model.config.id2label directly would mutate the shared
    # config and append one more "no object" entry on every call.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    # Indexing an empty result would raise a bare IndexError; report the
    # situation to the user instead.
    if not detected_tables:
        raise gr.Error("No table detected in the image.")

    # Crop the first detected table out of the image.
    return image.crop(detected_tables[0]["bbox"])
def recognize_table(image):
    """Run table structure recognition on an image and outline what is found.

    Args:
        image: PIL image, expected to be a cropped table.

    Returns:
        A copy of ``image`` with a red rectangle drawn around every element
        the structure model detects. The input image is left untouched.
    """
    # Prepare the image for the model (the parameter is `image`; the previous
    # code referenced an undefined `cropped_table`).
    pixel_values = structure_processor(images=image, return_tensors="pt").pixel_values

    # Forward pass; gradients are not needed for inference.
    with torch.no_grad():
        outputs = structure_model(pixel_values)

    # Extend a *copy* of the label map with the extra "no object" class so the
    # shared model config is not mutated on every call.
    id2label = dict(structure_model.config.id2label)
    id2label[len(id2label)] = "no object"
    cells = outputs_to_objects(outputs, image.size, id2label)

    # Draw on a copy so the caller's image is not modified as a side effect.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for cell in cells:
        draw.rectangle(cell["bbox"], outline="red")

    return annotated
def process_pdf(image):
    """Crop the first detected table from ``image`` and return it with its structure drawn.

    Chains the two stages: ``detect_and_crop_table`` followed by
    ``recognize_table`` on the resulting crop.
    """
    return recognize_table(detect_and_crop_table(image))
# Texts and example inputs handed to the Gradio interface.
title = "Demo: table detection with Table Transformer"
description = "Demo for the Table Transformer (TATR)."
examples = [['image.png']]

# Both the input and the output are exchanged as PIL images.
page_image_input = gr.Image(type="pil")
table_image_output = gr.Image(type="pil", label="Detected table")

app = gr.Interface(
    fn=process_pdf,
    inputs=page_image_input,
    outputs=table_image_output,
    title=title,
    description=description,
    examples=examples,
)

# Enable request queuing, then start the server in debug mode.
app.queue()
app.launch(debug=True)