"""Microsoft Table Transformer Extension.

By Neils: https://docs.llamaindex.ai/en/stable/examples/multi_modal/multi_modal_pdf_tables.html#experiment-3-let-s-use-microsoft-table-transformer-to-crop-tables-from-the-images-and-see-if-it-gives-the-correct-answer

Importing this module loads the table-detection and table-structure
recognition models onto ``device``. The helpers below detect tables in a page
image, crop each detected table to its own PNG file, and plot images in a
grid.
"""

import csv
import io
import os

import fitz
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import openai
import pandas as pd
import torch
from matplotlib.patches import Patch
from PIL import Image, ImageDraw
from torchvision import transforms
from transformers import AutoModelForObjectDetection

device = "cuda" if torch.cuda.is_available() else "cpu"


class MaxResize(object):
    """Resize a PIL image so that its longer side equals ``max_size``.

    Both sides are scaled by the same factor, so the aspect ratio is
    preserved up to rounding.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return a resized copy of ``image`` (anything with ``.size``/``.resize``)."""
        width, height = image.size
        scale = self.max_size / max(width, height)
        # Clamp to 1 pixel: for very elongated images the short side can
        # round down to 0, which ``resize`` rejects.
        new_width = max(1, int(round(scale * width)))
        new_height = max(1, int(round(scale * height)))
        return image.resize((new_width, new_height))


# Pre-processing for the detection model: longer side resized to 800 px,
# converted to a tensor and normalized per channel.
detection_transform = transforms.Compose(
    [
        MaxResize(800),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Pre-processing for the structure-recognition model: same pipeline, but with
# the longer side resized to 1000 px.
structure_transform = transforms.Compose(
    [
        MaxResize(1000),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# load table detection model
# processor = TableTransformerImageProcessor(max_size=800)
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection", revision="no_timm"
).to(device)

# load table structure recognition model
# structure_processor = TableTransformerImageProcessor(max_size=1000)
structure_model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition-v1.1-all"
).to(device)


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1).

    Args:
        x: Tensor whose last dimension has size 4.

    Returns:
        Tensor of the same shape holding the corner coordinates.
    """
    x_c, y_c, w, h = x.unbind(-1)
    corners = [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h]
    # Stack on the last axis (the one that was unbound) so that 1-D and
    # batched inputs keep their layout; for 2-D input this equals dim=1.
    return torch.stack(corners, dim=-1)


def rescale_bboxes(out_bbox, size):
    """Scale (cx, cy, w, h) boxes by the image size and return (x0, y0, x1, y1).

    Args:
        out_bbox: Tensor of boxes whose last dimension has size 4; the
            coordinates are multiplied by the image size, so they are
            expected to be relative to it.
        size: ``(width, height)`` of the target image.

    Returns:
        Tensor of corner-format boxes scaled by ``size``.
    """
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on the boxes' device so non-CPU inputs also work.
    scale = torch.tensor(
        [width, height, width, height],
        dtype=torch.float32,
        device=boxes.device,
    )
    return boxes * scale


def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw detection outputs into a list of labelled boxes.

    Only the first element of the batch is processed.

    Args:
        outputs: Model output exposing ``.logits`` and ``["pred_boxes"]``.
        img_size: ``(width, height)`` used to rescale the predicted boxes.
        id2label: Mapping from class index to label; predictions mapped to
            ``"no object"`` are dropped.

    Returns:
        List of dicts with keys ``"label"`` (str), ``"score"`` (float) and
        ``"bbox"`` (list of four floats: x0, y0, x1, y1).
    """
    best = outputs.logits.softmax(-1).max(-1)
    pred_labels = best.indices.detach().cpu().numpy()[0]
    pred_scores = best.values.detach().cpu().numpy()[0]
    pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
    pred_bboxes = rescale_bboxes(pred_bboxes, img_size).tolist()

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if class_label != "no object":
            objects.append(
                {
                    "label": class_label,
                    "score": float(score),
                    "bbox": [float(elem) for elem in bbox],
                }
            )

    return objects


def detect_and_crop_save_table(
    file_path, cropped_table_directory="./table_images/"
):
    """Detect tables in an image and save each one as a cropped PNG.

    Crops are written to ``<cropped_table_directory>/<stem>_<idx>.png``, where
    ``<stem>`` is the input file name without its extension. The directory is
    created if it does not exist. The number of detected tables is printed.

    Args:
        file_path: Path of the image to analyse.
        cropped_table_directory: Directory that receives the cropped tables.
    """
    filename, _ = os.path.splitext(os.path.basename(file_path))
    os.makedirs(cropped_table_directory, exist_ok=True)

    # The context manager closes the underlying file once the crops are saved.
    with Image.open(file_path) as image:
        # prepare image for the model
        # pixel_values = processor(image, return_tensors="pt").pixel_values
        # The normalization expects three channels, so feed the model an RGB
        # view; crops are still taken from the original image.
        pixel_values = (
            detection_transform(image.convert("RGB")).unsqueeze(0).to(device)
        )

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        # postprocess to get detected tables
        # Work on a copy so the shared model config is not mutated (and does
        # not grow by one entry) on every call.
        id2label = dict(model.config.id2label)
        id2label[len(id2label)] = "no object"
        detected_tables = outputs_to_objects(outputs, image.size, id2label)

        print(f"number of tables detected {len(detected_tables)}")

        for idx, table in enumerate(detected_tables):
            # crop detected table out of image
            cropped_table = image.crop(table["bbox"])
            cropped_table.save(
                os.path.join(cropped_table_directory, f"{filename}_{idx}.png")
            )


def plot_images(image_paths):
    """Plot up to nine existing image files in a three-column grid.

    Paths that are not regular files are skipped. Up to six images use a 2x3
    grid; seven to nine images use a 3x3 grid so every image has a slot.

    Args:
        image_paths: Iterable of image file paths.
    """
    max_images = 9
    columns = 3

    existing_paths = []
    for img_path in image_paths:
        if os.path.isfile(img_path):
            existing_paths.append(img_path)
            if len(existing_paths) >= max_images:
                break

    # Keep the 2x3 layout for small sets; add a third row only when needed.
    rows = 2 if len(existing_paths) <= 2 * columns else 3

    plt.figure(figsize=(16, 9))
    for position, img_path in enumerate(existing_paths, start=1):
        with Image.open(img_path) as image:
            plt.subplot(rows, columns, position)
            plt.imshow(image)
        plt.xticks([])
        plt.yticks([])