File size: 4,475 Bytes
0572c8f
 
 
 
6fa980c
0572c8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e0cdd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0572c8f
 
4e0cdd6
0572c8f
 
 
 
 
4e0cdd6
0572c8f
df692d6
0572c8f
 
 
 
 
 
400411f
 
0572c8f
 
4e0cdd6
df692d6
 
0572c8f
 
400411f
 
0572c8f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
from PIL import Image, ImageDraw
import numpy as np
import csv
import pandas as pd

from torchvision import transforms

from transformers import AutoModelForObjectDetection
import torch

import easyocr

import gradio as gr


# Run the models on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


class MaxResize:
    """Resize an image so that its longer side becomes ``max_size`` pixels.

    Both sides are multiplied by the same factor, so the aspect ratio is
    preserved up to integer rounding. Images whose longer side is below
    ``max_size`` are scaled up; larger ones are scaled down.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image`` resized so that ``max(width, height) == max_size``.

        Args:
            image: Object exposing ``size`` as ``(width, height)`` and a
                ``resize((width, height))`` method, such as a PIL image.

        Returns:
            The result of ``image.resize`` for the scaled dimensions.
        """
        width, height = image.size
        scale = self.max_size / max(width, height)
        # Clamp to one pixel: for very elongated images the shorter side can
        # round down to 0, which is not a valid target size for a resize.
        new_width = max(1, int(round(scale * width)))
        new_height = max(1, int(round(scale * height)))
        return image.resize((new_width, new_height))

def _make_transform(longest_side):
    """Build a resize -> tensor -> per-channel normalize preprocessing pipeline."""
    steps = [
        MaxResize(longest_side),
        transforms.ToTensor(),
        # Positional arguments are (mean, std), one value per colour channel.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


# Preprocessing used for the table detection model input (longer side 800 px).
detection_transform = _make_transform(800)

# Same pipeline with a 1000 px longer side; presumably intended for the
# structure recognition model - it is not used by the code in this file yet.
structure_transform = _make_transform(1000)

# Load the table detection model and move it to the selected device.
# NOTE(review): revision "no_timm" presumably selects a checkpoint that does
# not need the timm package - confirm against the model card.
# Commented-out alternative to `detection_transform`:
# processor = TableTransformerImageProcessor(max_size=800)
model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm").to(device)

# Load the table structure recognition model and move it to the selected device.
# It is not called by any function in this file yet.
# Commented-out alternative to `structure_transform`:
# structure_processor = TableTransformerImageProcessor(max_size=1000)
structure_model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(device)

# Load the EasyOCR reader for the 'en' language code.
# It is not called by any function in this file yet.
reader = easyocr.Reader(['en'])


def _box_cxcywh_to_xyxy(boxes):
    """Convert ``(cx, cy, w, h)`` boxes to ``(x0, y0, x1, y1)`` corner boxes."""
    x_c, y_c, w, h = boxes.unbind(-1)
    corners = [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h]
    return torch.stack(corners, dim=-1)


def _rescale_bboxes(boxes, size):
    """Scale normalized centre-format boxes to corner boxes in pixel units."""
    width, height = size
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    return _box_cxcywh_to_xyxy(boxes) * scale


def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw detection outputs into a list of labelled boxes.

    Only the first item of the batch is processed. Boxes are assumed to be
    normalized ``(cx, cy, w, h)`` as documented for DETR-style models and are
    converted to absolute ``(x0, y0, x1, y1)`` coordinates.

    Args:
        outputs: Model output exposing ``logits`` as an attribute and
            ``'pred_boxes'`` by key, both with a leading batch dimension.
        img_size: ``(width, height)`` of the image the boxes refer to.
        id2label: Mapping from class index to label; predictions whose label
            is ``'no object'`` are dropped.

    Returns:
        A list of dicts with keys ``'label'`` (str), ``'score'`` (float) and
        ``'bbox'`` (list of four floats), in query order.
    """
    # Best class and its probability for every query.
    best = outputs.logits.softmax(-1).max(-1)
    pred_labels = best.indices.detach().cpu()[0].tolist()
    pred_scores = best.values.detach().cpu()[0].tolist()
    pred_bboxes = _rescale_bboxes(
        outputs['pred_boxes'].detach().cpu()[0], img_size
    ).tolist()

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if class_label != 'no object':
            objects.append({'label': class_label, 'score': float(score),
                            'bbox': [float(coord) for coord in bbox]})
    return objects


def detect_and_crop_table(image):
    """Detect tables in ``image`` and return the most confident one, cropped.

    Args:
        image: PIL image to search for tables.

    Returns:
        A PIL image containing the crop of the highest-scoring detection.

    Raises:
        ValueError: If the detection model finds no table in the image.
    """
    # The normalization step expects three channels, so feed the model an RGB
    # view of the input; the crop below is still taken from the original image.
    pixel_values = detection_transform(image.convert("RGB")).unsqueeze(0).to(device)

    # Forward pass without gradient tracking (inference only).
    with torch.no_grad():
        outputs = model(pixel_values)

    # Work on a copy: writing into model.config.id2label directly would add a
    # new "no object" key to the shared config on every call.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    if not detected_tables:
        raise ValueError("No table was detected in the image.")

    # Detections come back in query order, not by confidence, so pick the
    # highest-scoring one instead of whichever happens to be first.
    best_table = max(detected_tables, key=lambda detection: detection["score"])
    return image.crop(best_table["bbox"])


def process_pdf(image=None):
    """Gradio callback: return the table crop for an uploaded image.

    Args:
        image: PIL image supplied by the UI. It is ``None`` when the user
            submits without uploading anything.

    Returns:
        The PIL image returned by ``detect_and_crop_table``.

    Raises:
        ValueError: If no image was provided.
    """
    print('process_pdf')
    if image is None:
        raise ValueError("No input image was provided.")

    # TODO: structure recognition and per-cell OCR are not wired in yet; only
    # the cropped table is returned for now.
    return detect_and_crop_table(image)

# Heading and introductory text shown in the Gradio interface.
# NOTE(review): the description mentions structure recognition and per-cell
# OCR, but process_pdf in this file currently returns only the cropped table.
title = "Sheriff's Demo: Table Detection & Recognition with Table Transformer (TATR)."
description = """A demo by M Sheriff for table extraction with the Table Transformer.
First, table detection is performed on the input image using https://huggingface.co/microsoft/table-transformer-detection,
after which the detected table is extracted and https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all recognizes the
    individual rows, columns and cells. OCR is then performed per cell, row by row."""

# Example inputs for the UI (currently disabled).
# examples = [['image.png'], ['mistral_paper.png']]

# Build the Gradio interface: one PIL image in, one PIL image out.
app = gr.Interface(fn=process_pdf, 
                     inputs=gr.Image(type="pil"),
                     outputs=[gr.Image(type="pil", label="Detected table")],
                     # Disabled alternative that also returns a dataframe and JSON:
                     # outputs=[gr.Image(type="pil", label="Detected table"), gr.Dataframe(label="Table as CSV"), gr.JSON(label="Data as JSON")],
                     title=title,
                     description=description,
                     # examples=examples
                  )
# Enable the request queue, then start the server in debug mode.
app.queue()
app.launch(debug=True)