KoonJamesZ committed on
Commit
0157c4f
·
verified ·
1 Parent(s): 22b3d7d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import TableTransformerForObjectDetection
4
+ import matplotlib.pyplot as plt
5
+ from transformers import DetrFeatureExtractor
6
+ import pandas as pd
7
+ import uuid
8
+ from surya.ocr import run_ocr
9
+ # from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
10
+ from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
11
+ from surya.model.recognition.model import load_model as load_rec_model
12
+ from surya.model.recognition.processor import load_processor as load_rec_processor
13
+ from PIL import ImageDraw, Image
14
+ import os
15
+ from pdf2image import convert_from_path
16
+ import tempfile
17
+ from ultralyticsplus import YOLO, render_result
18
+ import cv2
19
+ import numpy as np
20
+ from fpdf import FPDF
21
+
22
def convert_pdf_images(pdf_path):
    """Render the first page of a PDF to a temporary PNG file.

    Args:
        pdf_path: Path of the PDF to rasterize.

    Returns:
        Path of a temporary ``.png`` file containing page 1. The file is
        created with ``delete=False``, so the caller must remove it.

    Raises:
        ValueError: If no page could be rendered from the PDF.
    """
    # Only the first page is ever returned, so do not rasterize (and write
    # to disk) the remaining pages just to abandon them.
    pages = convert_from_path(pdf_path, first_page=1, last_page=1)
    if not pages:
        raise ValueError(f"No pages could be rendered from {pdf_path!r}")

    # Reserve a unique file name, then close the handle before saving by
    # name so no open file descriptor is left behind.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        temp_path = temp_file.name
    pages[0].save(temp_path, 'PNG')

    return temp_path
35
+
36
+
37
# Table detector used by crop_table() to locate the table on a page image.
model_yolo = YOLO('keremberke/yolov8m-table-extraction')

# Prediction settings for the detector.
model_yolo.overrides.update({
    'conf': 0.25,           # NMS confidence threshold
    'iou': 0.45,            # NMS IoU threshold
    'agnostic_nms': False,  # NMS class-agnostic
    'max_det': 1000,        # maximum number of detections per image
})

# Preferred torch device; computed here but not referenced by the functions
# visible in this file.
device = "cuda" if torch.cuda.is_available() else "cpu"

# OCR languages handed to run_ocr() - optional but recommended.
langs = ["en", "th"]

# Surya OCR text-detection and text-recognition components.
det_processor = load_det_processor()
det_model = load_det_model()
rec_model = load_rec_model()
rec_processor = load_rec_processor()

feature_extractor = DetrFeatureExtractor()

# new v1.1 checkpoints require no timm anymore
model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition-v1.1-all"
)
56
+
57
def crop_table(filename):
    """Crop the first detected table out of a page image.

    Runs the module-level ``model_yolo`` detector on the image and crops the
    image to the first bounding box it reports (the code assumes there is a
    single table of interest on the page).

    Args:
        filename: Path of the page image to search.

    Returns:
        Path of a temporary ``.png`` file containing the cropped table in RGB.
        The file is created with ``delete=False``, so the caller must remove
        it.

    Raises:
        ValueError: If the detector reports no table in the image.
    """
    results = model_yolo.predict(filename)
    boxes = results[0].boxes
    if len(boxes) == 0:
        raise ValueError(f"No table was detected in {filename!r}")

    # Corner coordinates (x1, y1, x2, y2) of the first detection.
    x1, y1, x2, y2 = (int(value) for value in boxes[0].xyxy[0])

    with Image.open(filename) as image:
        width, height = image.size
        # Clamp to the image so a slightly out-of-range box cannot produce
        # padding or a wrapped-around crop.
        region = (max(x1, 0), max(y1, 0), min(x2, width), min(y2, height))
        # PIL images are already RGB-ordered, so no BGR<->RGB channel swap is
        # needed; convert() only normalizes other modes (RGBA, L, P, ...).
        cropped_image = image.convert("RGB").crop(region)

    # Reserve a unique file name, then close the handle before saving by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        temp_path = temp_file.name
    cropped_image.save(temp_path)

    return temp_path
83
+
84
def _group_text_into_rows(items, row_threshold=5, column_tolerance=10):
    """Arrange text boxes into a row-major grid of cell strings.

    Args:
        items: ``(bbox, text)`` pairs, already sorted by ``bbox[1]``.
            ``bbox[0]`` is used as the horizontal position and ``bbox[1]``
            as the vertical position of each box.
        row_threshold: A box starts a new row when its ``bbox[1]`` differs
            from that of the previous box by more than this amount.
        column_tolerance: A ``bbox[0]`` closer than this to an existing column
            anchor is assigned to that column instead of opening a new one.

    Returns:
        A list of rows. Every row has one string per detected column, with
        ``''`` in cells that received no text.
    """
    # First pass: collect the distinct column anchors.
    column_boundaries = []
    for bbox, _text in items:
        x_min = bbox[0]
        if not any(abs(x - x_min) < column_tolerance for x in column_boundaries):
            column_boundaries.append(x_min)
    column_boundaries.sort()  # left-to-right column order

    # Second pass: walk the boxes top-to-bottom and fill the grid.
    rows = []
    current_row = []
    previous_y = None
    for bbox, text in items:
        if previous_y is None or abs(bbox[1] - previous_y) > row_threshold:
            if current_row:
                rows.append(current_row)
            current_row = [''] * len(column_boundaries)

        # Place the text in the first column whose anchor is close enough.
        for col_index, x_bound in enumerate(column_boundaries):
            if abs(bbox[0] - x_bound) < column_tolerance:
                current_row[col_index] = text
                break

        previous_y = bbox[1]

    if current_row:
        rows.append(current_row)
    return rows


def extract_table(image_path):
    """OCR a cropped table image and export it as CSV plus an annotated image.

    The recognized text lines are grouped into rows and columns by the
    position of their bounding boxes. The first row becomes the CSV header.

    Args:
        image_path: Path of the table image to read.

    Returns:
        A ``(csv_path, table_with_bbox_path)`` tuple: the CSV file holding the
        table and a PNG copy of the image with every text box outlined in
        red. Both files are written to the current working directory under
        random UUID names.

    Raises:
        ValueError: If OCR finds no text in the image.
    """
    # Load the pixels eagerly so the file handle can be closed right away;
    # RGB mode also guarantees the red outlines below are drawn in colour.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)

    # Sort by the vertical coordinate so rows can be separated in one sweep.
    text_lines = sorted(predictions[0].text_lines, key=lambda line: line.bbox[1])
    if not text_lines:
        raise ValueError(f"No text was detected in {image_path!r}")

    rows = _group_text_into_rows([(line.bbox, line.text) for line in text_lines])

    # Promote the first row to the header and keep the rest as data.
    df = pd.DataFrame(rows)
    df.columns = df.iloc[0]
    df = df.iloc[1:]

    csv_path = f'{uuid.uuid4()}.csv'
    df.to_csv(csv_path, index=False)

    # Outline every recognized text box and save the annotated image once.
    table_with_bbox_path = f"{uuid.uuid4()}.png"
    draw = ImageDraw.Draw(image)
    for line in text_lines:
        draw.rectangle(line.bbox, outline='red', width=1)
    image.save(table_with_bbox_path)

    return csv_path, table_with_bbox_path
154
+
155
+
156
+
157
+ # Function to process the uploaded file
158
def process_file(uploaded_file):
    """Run the PDF -> page image -> cropped table -> CSV pipeline.

    Args:
        uploaded_file: Path of the uploaded PDF, or ``None`` when nothing has
            been uploaded.

    Returns:
        A ``(csv_path, annotated_image_path)`` tuple as produced by
        ``extract_table``.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if uploaded_file is None:
        raise gr.Error("Please upload a PDF file first.")

    images_table = None
    croped_table = None
    try:
        images_table = convert_pdf_images(uploaded_file)
        croped_table = crop_table(images_table)
        filepath, bbox_table = extract_table(croped_table)
        return filepath, bbox_table  # Return the file paths for download/display
    finally:
        # Remove the intermediate images even when a stage above failed, so
        # temporary files do not pile up across requests.
        for temp_path in (images_table, croped_table):
            if temp_path is not None and os.path.exists(temp_path):
                os.remove(temp_path)
167
+
168
+ # Function to clear the inputs and outputs
169
def clear_inputs():
    """Return one ``None`` per wired component (upload, download, image) to reset them."""
    return (None,) * 3
171
+
172
+ # Define the Gradio interface
173
with gr.Blocks() as demo:
    gr.Markdown("## Upload a PDF, Process it, and Download the Processed File")

    with gr.Row():
        upload = gr.File(label="Upload PDF", type="filepath", file_types=[".pdf"])
        # process_file returns a CSV path for this component, so label it as such.
        download = gr.File(label="Download Processed CSV")
    with gr.Row():
        process_button = gr.Button("Process")
        clear_button = gr.Button("Clear")  # Custom clear button
    # NOTE(review): assumed to sit below the button row rather than inside
    # it - confirm against the intended layout.
    image_display = gr.Image(label="Processed Image")

    # Trigger the file processing with the button click
    process_button.click(process_file, inputs=upload, outputs=[download, image_display])

    # Trigger clearing inputs and outputs
    clear_button.click(clear_inputs, inputs=None, outputs=[upload, download, image_display])

# Launch the interface
demo.launch()
192
+
193
+ # print(process_file("/content/ขอ ตารางกริยาช่องที่ 1 ในภาษาไทย (กริยาคำกริยา) ซ... - ขอ ตารางกริยาช่องที่ 1 ในภาษาไทย (กริย.pdf"))