# transformers / app.py
# Author: msheriff — commit df692d6 ("Update app.py")
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
from PIL import Image, ImageDraw
import numpy as np
import csv
import pandas as pd
from torchvision import transforms
from transformers import AutoModelForObjectDetection
import torch
import easyocr
import gradio as gr
# Choose the inference device once, at import time: a CUDA GPU when PyTorch can
# see one, otherwise the CPU. Kept as a plain string ("cuda" / "cpu") because it
# is only handed to ``.to(...)`` below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class MaxResize:
    """Resize a PIL image so that its longer side equals ``max_size`` pixels.

    The aspect ratio is preserved; each output dimension is rounded with
    Python's ``round``. Used as the first step of a torchvision ``Compose``.
    """

    def __init__(self, max_size=800):
        # Target length, in pixels, for the longer of width / height.
        self.max_size = max_size

    def __call__(self, image):
        w, h = image.size
        factor = self.max_size / max(w, h)
        new_size = (int(round(w * factor)), int(round(h * factor)))
        return image.resize(new_size)
# Per-channel mean / std applied after ToTensor (the standard ImageNet
# normalization statistics).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


def _build_transform(max_size):
    """Return a resize -> to-tensor -> normalize pipeline capped at ``max_size`` px."""
    steps = [
        MaxResize(max_size),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
    return transforms.Compose(steps)


# Inputs to the detection model have their longer side scaled to 800 px.
detection_transform = _build_transform(800)
# Presumably for the structure-recognition model (its caller is currently
# commented out); longer side scaled to 1000 px.
structure_transform = _build_transform(1000)
# Table detection model: locates tables in a page image. The "no_timm"
# revision of the checkpoint is requested explicitly (presumably the variant
# that does not need the timm backbone package — confirm on the model card).
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection",
    revision="no_timm",
).to(device)

# Table structure recognition model: meant to find rows, columns and cells
# inside a cropped table image.
structure_model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition-v1.1-all",
).to(device)

# OCR engine configured for English text.
reader = easyocr.Reader(["en"])
def _box_cxcywh_to_xyxy(boxes):
    """Convert ``(cx, cy, w, h)`` boxes to ``(x_min, y_min, x_max, y_max)``.

    ``boxes`` is a float tensor whose last dimension has size 4.
    """
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack(
        [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1
    )


def _rescale_bboxes(boxes, size):
    """Map normalized ``(cx, cy, w, h)`` boxes to absolute pixel corner boxes.

    ``size`` is a PIL-style ``(width, height)`` tuple.
    """
    width, height = size
    scale = torch.tensor(
        [width, height, width, height], dtype=boxes.dtype, device=boxes.device
    )
    return _box_cxcywh_to_xyxy(boxes) * scale


def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw Table Transformer outputs into a list of detected objects.

    Only the first image of the batch is used.

    Args:
        outputs: Model output exposing ``logits`` (``[batch, queries, classes]``)
            and ``pred_boxes`` (``[batch, queries, 4]``, normalized
            ``cx, cy, w, h``).
        img_size: ``(width, height)`` of the image the boxes are mapped onto.
        id2label: Maps class index to label name, including the trailing
            "no object" class.

    Returns:
        A list of ``{"label": str, "score": float, "bbox": [x0, y0, x1, y1]}``
        dicts, one per query whose most likely class is not "no object",
        in query order (not sorted by score).
    """
    best = outputs.logits.softmax(-1).max(-1)
    pred_labels = best.indices.detach().cpu()[0].tolist()
    pred_scores = best.values.detach().cpu()[0].tolist()
    pred_boxes = outputs["pred_boxes"].detach().cpu()[0]
    # Previously called an undefined ``rescale_bboxes`` (NameError); the helper
    # now lives alongside this function.
    pred_boxes = _rescale_bboxes(pred_boxes, img_size).tolist()

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_boxes):
        class_label = id2label[int(label)]
        if class_label == "no object":
            continue
        objects.append({
            "label": class_label,
            "score": float(score),
            "bbox": [float(v) for v in bbox],
        })
    return objects
def detect_and_crop_table(image):
    """Detect tables in a page image and return a crop of the most confident one.

    Args:
        image: PIL image of a document page.

    Returns:
        A PIL image cropped to the bounding box of the highest-scoring table.

    Raises:
        ValueError: If the detection model finds no table in the image.
    """
    # The transform normalizes exactly 3 channels, so give the model an RGB
    # view; the crop below is still taken from the original image.
    pixel_values = detection_transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(pixel_values)

    # Copy the label map before appending the extra "no object" class. Writing
    # into model.config.id2label directly would mutate the shared config and
    # add another key on every call.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    if not detected_tables:
        raise ValueError("No table was detected in the input image.")

    # outputs_to_objects returns detections in query order, not by confidence,
    # so choose the highest-scoring table instead of the first entry.
    best_table = max(detected_tables, key=lambda obj: obj["score"])
    return image.crop(best_table["bbox"])
def process_pdf(image=None):
    """Gradio callback: return the table detected in the uploaded page image.

    The function previously took no parameters and read an undefined global
    ``image``, so every call from ``gr.Interface`` failed.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` if nothing was uploaded.

    Returns:
        The cropped table as a PIL image.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    print('process_pdf')
    if image is None:
        raise gr.Error("Please upload an image containing a table.")
    cropped_table = detect_and_crop_table(image)
    # TODO: structure recognition and OCR (recognize_table,
    # get_cell_coordinates_by_row, apply_ocr) are not wired up yet; once they
    # are, also return the DataFrame and JSON outputs.
    return cropped_table
# Heading and description text shown on the Gradio page.
title = "Sheriff's Demo: Table Detection & Recognition with Table Transformer (TATR)."
description = (
    "A demo by M Sheriff for table extraction with the Table Transformer.\n"
    "First, table detection is performed on the input image using "
    "https://huggingface.co/microsoft/table-transformer-detection,\n"
    "after which the detected table is extracted and "
    "https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all "
    "recognizes the\n"
    "individual rows, columns and cells. OCR is then performed per cell, row by row."
)
# Example inputs for the interface, currently disabled:
# examples = [['image.png'], ['mistral_paper.png']]
# One input, one output: upload a page image, get back the cropped table.
# The DataFrame / JSON outputs and the example inputs are disabled until the
# structure-recognition and OCR steps are re-enabled.
app = gr.Interface(
    fn=process_pdf,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil", label="Detected table")],
    title=title,
    description=description,
)

# Turn on request queuing, then start the server with Gradio's debug mode.
app.queue()
app.launch(debug=True)