Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import TableTransformerForObjectDetection | |
import matplotlib.pyplot as plt | |
from transformers import DetrFeatureExtractor | |
import pandas as pd | |
import uuid | |
from surya.ocr import run_ocr | |
# from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor | |
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor | |
from surya.model.recognition.model import load_model as load_rec_model | |
from surya.model.recognition.processor import load_processor as load_rec_processor | |
from PIL import ImageDraw, Image | |
import os | |
from pdf2image import convert_from_path | |
import tempfile | |
from ultralyticsplus import YOLO, render_result | |
import cv2 | |
import numpy as np | |
from fpdf import FPDF | |
def convert_pdf_images(pdf_path):
    """Render the first page of a PDF into a temporary PNG file.

    Only page 1 is used by the rest of the pipeline, so only page 1 is
    rasterised. Previously every page was written to its own temporary file
    and all but the first were leaked on disk.

    Args:
        pdf_path: Path of the PDF to convert.

    Returns:
        Path of a temporary PNG holding page 1. It is created with
        ``delete=False``; the caller is responsible for removing it.

    Raises:
        ValueError: If no page could be rendered from the PDF.
    """
    # first_page/last_page are 1-indexed and limit rendering to page 1.
    pages = convert_from_path(pdf_path, first_page=1, last_page=1)
    if not pages:
        raise ValueError(f"No pages could be rendered from {pdf_path!r}")
    # Write through the handle and close it right away; only the path is kept.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        pages[0].save(temp_file, 'PNG')
    return temp_file.name
# Table detector used by crop_table(); loaded once when the module is imported.
model_yolo = YOLO('keremberke/yolov8m-table-extraction')
# Settings stored in the model's `overrides` dict (presumably picked up by
# model_yolo.predict — confirm against the ultralytics version in use).
model_yolo.overrides.update(
    conf=0.25,  # NMS confidence threshold
    iou=0.45,  # NMS IoU threshold
    agnostic_nms=False,  # class-agnostic NMS disabled
    max_det=1000,  # maximum number of detections per image
)
def resize_image(image, max_dimension=4200, min_dimension=50):
    """Rescale *image* so its sides fall within [min_dimension, max_dimension].

    The aspect ratio is preserved. An image whose sides are already in range
    (or which has zero area) is returned unchanged as the same object.

    - If the longest side exceeds ``max_dimension``, the image is shrunk so
      that the longest side equals ``max_dimension``.
    - Otherwise, if the shortest side is below ``min_dimension``, the image is
      enlarged so that the shortest side reaches ``min_dimension``. The
      enlargement is capped so the longest side never exceeds
      ``max_dimension``.

    Args:
        image: A PIL image (anything with ``.size`` and ``.resize``).
        max_dimension: Upper bound for either side, in pixels.
        min_dimension: Lower bound for either side, in pixels.

    Returns:
        The LANCZOS-resampled image, or the original image if no change was needed.
    """
    width, height = image.size
    longest, shortest = max(width, height), min(width, height)
    # A degenerate (zero-area) crop cannot be rescaled; this used to raise
    # ZeroDivisionError.
    if shortest <= 0:
        return image
    if longest > max_dimension:
        # The old min(max/longest, min/shortest) also applied the *upscale*
        # bound here, collapsing large images to ~min_dimension pixels.
        scaling_factor = max_dimension / longest
    elif shortest < min_dimension:
        scaling_factor = min(min_dimension / shortest, max_dimension / longest)
    else:
        return image
    # Clamp to 1px so extremely thin images never resize to a zero size.
    new_width = max(1, int(width * scaling_factor))
    new_height = max(1, int(height * scaling_factor))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def crop_table(filename):
    """Detect the table on a page image and save the crop to a temporary PNG.

    Runs ``model_yolo`` on the image and crops to the first detected box
    (this code assumes the page holds a single table).

    Args:
        filename: Path of the page image.

    Returns:
        Path of a temporary PNG holding the cropped table. It is created with
        ``delete=False``; the caller must remove it.

    Raises:
        ValueError: If the detector found no table on the page.
    """
    results = model_yolo.predict(filename)
    boxes = results[0].boxes
    if len(boxes) == 0:
        raise ValueError(f"No table detected in {filename!r}")
    x1, y1, x2, y2 = map(int, boxes[0].xyxy[0])  # first detection's corners
    with Image.open(filename) as page:
        # Crop in PIL directly. The previous numpy + cv2.COLOR_BGR2RGB path
        # swapped the red/blue channels of an image that was already RGB, and
        # it failed on grayscale/RGBA input.
        cropped = page.convert("RGB").crop((max(0, x1), max(0, y1), x2, y2))
    # Save the crop through a handle that is closed right away.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        cropped.save(temp_file, format="PNG")
    return temp_file.name
# Module-level model state for OCR (surya) and table-structure recognition,
# loaded once at import time.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# NOTE(review): `device` is not passed to any model in this file as shown.

# Languages handed to surya's run_ocr for every image (optional but recommended).
langs = ["en", "th"]

det_processor = load_det_processor()
det_model = load_det_model()
rec_model = load_rec_model()
rec_processor = load_rec_processor()

# RGB triples in [0, 1]; NOTE(review): not referenced elsewhere in this file as shown.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

# new v1.1 checkpoints require no timm anymore
feature_extractor = DetrFeatureExtractor()
model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition-v1.1-all"
)
def compute_boxes(image_path, threshold=0.7):
    """Run the table-structure model on an image and return its detections.

    Args:
        image_path: Path of the (cropped) table image.
        threshold: Minimum score for a detection to be kept. The default of
            0.7 is the value that was previously hard-coded.

    Returns:
        ``(boxes, labels)``: parallel lists.
        - Each box is ``[x0, y0, x1, y1]`` in pixel coordinates of the input
          image, because post-processing receives its original ``(height, width)``.
        - Each label is the model's integer class id for that box.
    """
    # convert() returns an independent copy, so the file can be closed at once
    # instead of leaking the handle.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    width, height = image.size
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**encoding)
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[(height, width)]
    )[0]
    return results['boxes'].tolist(), results['labels'].tolist()
def _save_detection_overlay(image, boxes):
    """Draw every detected box in red on a copy of *image*; return the saved PNG path."""
    overlay = image.copy()
    draw = ImageDraw.Draw(overlay)
    for box in boxes:
        draw.rectangle(box, outline="red")
    overlay_path = f"{uuid.uuid4()}.png"
    overlay.save(overlay_path)
    return overlay_path


def _build_cell_grid(boxes, labels):
    """Intersect row boxes (label 2) with column boxes (label 1) into a grid of cells.

    Rows are ordered top-to-bottom by their top edge, and columns left-to-right
    by their left edge. Each cell is ``(x0, y0, x1, y1)``: x comes from the
    column box and y from the row box. Returns a list of rows of cells.
    NOTE(review): 2/1 presumably map to "table row"/"table column" in the
    model's id2label — confirm.
    """
    rows = sorted(
        (box for box, label in zip(boxes, labels) if label == 2),
        key=lambda box: box[1],
    )
    columns = sorted(
        (box for box, label in zip(boxes, labels) if label == 1),
        key=lambda box: box[0],
    )
    return [[(col[0], row[1], col[2], row[3]) for col in columns] for row in rows]


def _show_cell(cell_image):
    """Display one cropped cell with matplotlib (debug aid), then free the figure."""
    figure = plt.figure()
    plt.imshow(cell_image)
    plt.axis("off")
    plt.title("Cropped Cell Image")
    plt.show()
    plt.close(figure)  # without this, every cell leaked a figure in the server


def extract_table(image_path, *, debug=False):
    """OCR a cropped table image into a CSV file.

    Detects rows and columns with ``compute_boxes``, intersects them into a
    grid of cells, and OCRs every cell with surya in one batched call. The
    top-most row becomes the CSV header; the remaining rows become the data.

    Args:
        image_path: Path of an image that contains only the table.
        debug: If True, show each cropped cell with matplotlib and print its
            recognised text. This used to happen unconditionally, leaking one
            open figure per cell.

    Returns:
        ``(csv_path, overlay_path)``. ``csv_path`` is the CSV written to the
        working directory. ``overlay_path`` is a PNG of the image with all
        detected boxes drawn in red. Both file names are random UUIDs.

    Raises:
        ValueError: If no table rows or no table columns were detected.
    """
    with Image.open(image_path) as source:
        image = source.convert("RGB")  # independent copy; file closed promptly
    boxes, labels = compute_boxes(image_path)

    # Build the grid from the true row/column counts. The old scan of sorted
    # cells dropped the last column of a single-row table and crashed on an
    # empty detection.
    grid = _build_cell_grid(boxes, labels)
    if not grid or not grid[0]:
        raise ValueError(f"No table rows/columns detected in {image_path!r}")

    overlay_path = _save_detection_overlay(image, boxes)

    cell_images = [resize_image(image.crop(cell)) for row in grid for cell in row]
    # One batched OCR call instead of one model invocation per cell.
    predictions = run_ocr(
        cell_images,
        [langs] * len(cell_images),
        det_model,
        det_processor,
        rec_model,
        rec_processor,
    )

    column_count = len(grid[0])
    texts = []
    for index, prediction in enumerate(predictions):
        # Header cells join OCR lines with a space; body cells are joined
        # without a separator, as before.
        # NOTE(review): confirm that '' is intended for body cells (Thai text?).
        separator = ' ' if index < column_count else ''
        texts.append(separator.join(line.text for line in prediction.text_lines))

    if debug:
        for cell_image, text in zip(cell_images, texts):
            _show_cell(cell_image)
            print(text)

    headers = texts[:column_count]
    body = [
        texts[start:start + column_count]
        for start in range(column_count, len(texts), column_count)
    ]
    # Build the frame in one go (duplicate/empty header names are allowed)
    # rather than growing it row by row with df.loc.
    df = pd.DataFrame(body, columns=headers)
    csv_path = f"{uuid.uuid4()}.csv"
    df.to_csv(csv_path, index=False)
    return csv_path, overlay_path
# Gradio callback: run the whole pipeline on the uploaded file.
def process_file(uploaded_file):
    """Convert an uploaded PDF's table into a CSV.

    Renders page 1 of the PDF, crops the detected table, and OCRs it. The
    intermediate temporary images are always removed, even if a later step
    raises.

    Args:
        uploaded_file: Filesystem path of the uploaded PDF (the upload
            component is declared with ``type="filepath"``).

    Returns:
        ``(csv_path, overlay_path)`` for the download and image outputs.
    """
    page_image = convert_pdf_images(uploaded_file)
    try:
        cropped_table = crop_table(page_image)
        try:
            return extract_table(cropped_table)
        finally:
            os.remove(cropped_table)
    finally:
        os.remove(page_image)
# Callback for the "Clear" button.
def clear_inputs():
    """Return one ``None`` per component wired to the Clear button.

    The three values empty the upload input, the download output, and the
    image output, in that order.
    """
    return (None,) * 3
# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Upload a PDF, Process it, and Download the Processed File")
    with gr.Row():
        upload = gr.File(label="Upload PDF", type="filepath", file_types=[".pdf"])
        # process_file returns a CSV path, so the label no longer says "PDF".
        download = gr.File(label="Download Processed CSV")
    with gr.Row():
        process_button = gr.Button("Process")
        clear_button = gr.Button("Clear")  # Custom clear button
    image_display = gr.Image(label="Processed Image")
    # Process button -> (CSV for download, detection overlay image).
    process_button.click(process_file, inputs=upload, outputs=[download, image_display])
    # Clear button -> empty the upload, download, and image components.
    clear_button.click(clear_inputs, inputs=None, outputs=[upload, download, image_display])

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()
# print(process_file("/content/give me a example table - give me a example table.pdf")) |