from transformers import AutoProcessor, Pix2StructForConditionalGeneration import gradio as gr import torch import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from typing import Tuple from PIL import Image import os import sys os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'") os.system("git clone https://github.com/microsoft/unilm.git; cd unilm; git checkout 9102ed91f8e56baa31d7ae7e09e0ec98e77d779c; cd ..") sys.path.append("unilm") from unilm.dit.object_detection.ditod import add_vit_config from detectron2.config import CfgNode as CN from detectron2.config import get_cfg from detectron2.data import MetadataCatalog from detectron2.engine import DefaultPredictor #Plot settings sns.set_style("darkgrid") palette = sns.color_palette("pastel") sns.set_palette(palette) plt.switch_backend("Agg") # Load the DiT model config cfg = get_cfg() add_vit_config(cfg) cfg.merge_from_file("unilm/dit/object_detection/publaynet_configs/cascade/cascade_dit_base.yaml") # Get the model weights cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth" cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Define the model predictor predictor = DefaultPredictor(cfg) # Load the DePlot model model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(cfg.MODEL.DEVICE) processor = AutoProcessor.from_pretrained("google/deplot") def crop_figure(img: Image.Image , threshold: float = 0.5) -> Image.Image: """Prediction function for the figure cropping model using DiT backend. Args: img (Image.Image): Input document image. threshold (float, optional): Detection threshold. Defaults to 0.5. Returns: (Image.Image): The cropped figure image. """ md = MetadataCatalog.get(cfg.DATASETS.TEST[0]) md.set(thing_classes=["text","title","list","table","figure"]) output = predictor(np.array(img))["instances"] boxes, scores, classes = output.pred_boxes.tensor.cpu().numpy(), output.scores.cpu().numpy(), output.pred_classes.cpu().numpy() boxes = boxes[classes == 4] # 4 is the class for figures scores = scores[classes == 4] if len(boxes) == 0: return [] print(boxes, scores) # sort boxes by score crop_box = boxes[np.argsort(scores)[::-1]][0] # Add white space around the figure margin = 0.1 box_size = crop_box[-2:] - crop_box[:2] size = tuple((box_size + np.array([margin, margin]) * box_size).astype(int)) new = Image.new('RGB', size, (255, 255, 255)) image = img.crop(crop_box) new.paste(image, (int((size[0] - image.size[0]) / 2), int(((size[1]) - image.size[1]) / 2))) return new def extract_tables(image: Image.Image) -> Tuple[str]: """Prediction function for the table extraction model using DePlot backend. Args: image (Image.Image): Input figure image. Returns: Tuple[str]: The table title, the table as a pandas dataframe, and the table as an HTML string, if the table was successfully extracted. """ inputs = processor(image, text="Generate a data table using only the data you see in the graph below: ", return_tensors="pt").to(cfg.MODEL.DEVICE) with torch.no_grad(): outputs = model.generate(**inputs, max_new_tokens=512) decoded = processor.decode(outputs[0], skip_special_tokens=True) print(decoded.replace("<0x0A>", "\n") ) data = [row.split(" | ") for row in decoded.split("<0x0A>")] try: if data[0][0].lower().startswith("title"): title = data[0][1] table = pd.DataFrame(data[2:], columns=data[1]) else: title = "Table" table = pd.DataFrame(data[1:], columns=data[0]) return title, table, table.to_html() except: return "Table", list(list()), decoded.replace("<0x0A>", "\n") def update(df: pd.DataFrame, plot_type: str) -> plt.figure: """Update callback for the gradio interface, that updates the plot based on the table data and selected plot type. Args: df (pd.DataFrame): The extracted table data. plot_type (str): The selected plot type to generate. Returns: plt.figure: The updated plot. """ plt.close("all") df = df.apply(pd.to_numeric, errors="ignore") fig = plt.figure(figsize=(8, 6)) ax = fig.add_subplot(111) cols = df.columns if len(cols) == 0: return fig if len(cols) > 1: df.set_index(cols[0], inplace=True) try: if plot_type == "Line": sns.lineplot(data=df, ax=ax) elif plot_type == "Bar": df = df.reset_index() if len(cols) == 1: sns.barplot(x=df.index, y=df[df.columns[0]], ax=ax) elif len(cols) == 2: sns.barplot(x=df[df.columns[0]], y=df[df.columns[1]], ax=ax) else: df = df.melt(id_vars=cols[0], value_vars=cols[1:], value_name="Value") sns.barplot(x=df[cols[0]], y=df["Value"], hue=df["variable"], ax=ax) elif plot_type == "Scatter": sns.scatterplot(data=df, ax=ax) elif plot_type == "Pie": ax.pie(df[df.columns[0]], labels=df.index, autopct='%1.1f%%', colors=palette) ax.axis('equal') except: pass plt.tight_layout() return fig with gr.Blocks() as demo: gr.Markdown("