import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
from PIL import Image, ImageDraw
import numpy as np
import csv
import pandas as pd

from torchvision import transforms
from transformers import AutoModelForObjectDetection
import torch
import easyocr
import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"


class MaxResize(object):
    """Resize an image so its longer side equals `max_size`, preserving aspect ratio."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image


detection_transform = transforms.Compose([
    MaxResize(800),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# load table detection model
# processor = TableTransformerImageProcessor(max_size=800)
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection", revision="no_timm"
).to(device)

# load table structure recognition model
# structure_processor = TableTransformerImageProcessor(max_size=1000)
structure_model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition-v1.1-all"
).to(device)

# load EasyOCR reader
reader = easyocr.Reader(['en'])


def box_cxcywh_to_xyxy(x):
    # convert (center_x, center_y, width, height) boxes to (x_min, y_min, x_max, y_max)
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    # rescale normalized boxes to absolute pixel coordinates of the original image
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
    return boxes


def outputs_to_objects(outputs, img_size, id2label):
    m = outputs.logits.softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
    pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if not class_label == 'no object':
            objects.append({'label': class_label,
                            'score': float(score),
                            'bbox': [float(elem) for elem in bbox]})

    return objects


def detect_and_crop_table(image):
    # prepare image for the model
    # pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = detection_transform(image).unsqueeze(0).to(device)

    # forward pass
    with torch.no_grad():
        outputs = model(pixel_values)

    # postprocess to get detected tables; append an extra "no object" class
    # without mutating the model config on every call
    id2label = {**model.config.id2label, len(model.config.id2label): "no object"}
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    if not detected_tables:
        raise gr.Error("No table detected in the image.")

    # visualize
    # fig = visualize_detected_tables(image, detected_tables)
    # image = fig2img(fig)

    # crop first detected table out of image
    cropped_table = image.crop(detected_tables[0]["bbox"])

    return cropped_table
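# The calls in process_pdf below to recognize_table, get_cell_coordinates_by_row
# and apply_ocr are commented out and their definitions are not part of this file.
# What follows is a minimal sketch of what they could look like, assuming the
# structure model's label set includes "table row" and "table column" (as in the
# TATR label set) and that EasyOCR is applied cell by cell. Treat these as
# illustrative, not the original implementations.

def recognize_table(image):
    # prepare the cropped table for the structure recognition model
    pixel_values = structure_transform(image).unsqueeze(0).to(device)

    # forward pass
    with torch.no_grad():
        outputs = structure_model(pixel_values)

    # postprocess to get rows, columns and other structure elements
    id2label = {**structure_model.config.id2label,
                len(structure_model.config.id2label): "no object"}
    cells = outputs_to_objects(outputs, image.size, id2label)

    # draw the recognized elements on the cropped table for visualization
    draw = ImageDraw.Draw(image)
    for cell in cells:
        draw.rectangle(cell["bbox"], outline="red")

    return image, cells


def get_cell_coordinates_by_row(cells):
    # intersect each detected row with each detected column to get cell boxes
    rows = [entry for entry in cells if entry['label'] == 'table row']
    columns = [entry for entry in cells if entry['label'] == 'table column']
    rows.sort(key=lambda x: x['bbox'][1])
    columns.sort(key=lambda x: x['bbox'][0])

    cell_coordinates = []
    for row in rows:
        row_cells = [{'cell': [column['bbox'][0], row['bbox'][1],
                               column['bbox'][2], row['bbox'][3]]}
                     for column in columns]
        cell_coordinates.append({'cells': row_cells, 'cell_count': len(row_cells)})

    return cell_coordinates


def apply_ocr(cell_coordinates, cropped_table):
    # OCR the table cell by cell, row by row
    data = {}
    max_num_columns = 0
    for idx, row in enumerate(cell_coordinates):
        row_text = []
        for cell in row["cells"]:
            cell_image = np.array(cropped_table.crop(cell["cell"]))
            result = reader.readtext(cell_image)
            row_text.append(" ".join(text for _, text, _ in result) if result else "")
        max_num_columns = max(max_num_columns, len(row_text))
        data[idx] = row_text

    # pad short rows so every row has the same number of columns
    for idx, row_text in data.items():
        data[idx] = row_text + [""] * (max_num_columns - len(row_text))

    df = pd.DataFrame.from_dict(data, orient='index')
    return df, data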
def process_pdf(image):
    cropped_table = detect_and_crop_table(image)

    # image, cells = recognize_table(cropped_table)
    # cell_coordinates = get_cell_coordinates_by_row(cells)
    # df, data = apply_ocr(cell_coordinates, image)

    return cropped_table
    # return image, df, data


title = "Sheriff's Demo: Table Detection & Recognition with Table Transformer (TATR)"
description = """A demo by M Sheriff for table extraction with the Table Transformer.

First, table detection is performed on the input image using
https://huggingface.co/microsoft/table-transformer-detection. The detected table is then cropped, and
https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all recognizes the
individual rows, columns and cells. Finally, OCR is performed per cell, row by row."""

# examples = [['image.png'], ['mistral_paper.png']]

app = gr.Interface(
    fn=process_pdf,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil", label="Detected table")],
    # outputs=[gr.Image(type="pil", label="Detected table"),
    #          gr.Dataframe(label="Table as CSV"),
    #          gr.JSON(label="Data as JSON")],
    title=title,
    description=description,
    # examples=examples,
)
app.queue()
app.launch(debug=True)
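# Usage note: running this file directly starts the Gradio server and prints a
# local URL. To expose the full pipeline instead of only the cropped table,
# process_pdf could return (image, df, data) and the commented-out Dataframe and
# JSON outputs above could be re-enabled, assuming the sketched helpers are used.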