aagamjtdev committed
Commit 58a2a88 · 1 Parent(s): bbc2086

add pipeline helper file

Files changed (1)
  1. working_yolo_pipeline.py +1045 -0
working_yolo_pipeline.py ADDED
@@ -0,0 +1,1045 @@
import json
import argparse
import os
import re
import torch
import torch.nn as nn
from TorchCRF import CRF
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
from typing import List, Dict, Any, Optional, Union, Tuple
import fitz  # PyMuPDF
import numpy as np
import cv2
from ultralytics import YOLO
import glob
import pytesseract
from PIL import Image
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import sys
import io
import base64
import tempfile  # Recommended for robust temporary file handling

# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================

# NOTE: Update these paths to match your environment before running!
WEIGHTS_PATH = '/home/dipesh/Downloads/api-mcq/YOLO_MATH/yolo_split_data/runs/detect/math_figure_detector_v3/weights/best.pt'
DEFAULT_LAYOUTLMV3_MODEL_PATH = "checkpoints/layoutlmv3_trained_20251031_102846_recovered.pth"

# DIRECTORY CONFIGURATION
# NOTE: These are used for temporary data extraction/storage.
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'  # Default Phase 1 output dir (not referenced by run_document_pipeline)
FIGURE_EXTRACTION_DIR = './figure_extraction'
TEMP_IMAGE_DIR = './temp_pdf_images'

# Detection parameters
CONF_THRESHOLD = 0.2
TARGET_CLASSES = ['figure', 'equation']
IOU_MERGE_THRESHOLD = 0.4
IOA_SUPPRESSION_THRESHOLD = 0.7
LINE_TOLERANCE = 15

# Global counters for sequential numbering across the entire PDF
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0

# LayoutLMv3 Labels
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION", 2: "I-QUESTION",
    3: "B-OPTION", 4: "I-OPTION",
    5: "B-ANSWER", 6: "I-ANSWER",
    7: "B-SECTION_HEADING", 8: "I-SECTION_HEADING",
    9: "B-PASSAGE", 10: "I-PASSAGE"
}
NUM_LABELS = len(ID_TO_LABEL)


# ============================================================================
# --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS (Word Extraction) ---
# ============================================================================

def calculate_iou(box1, box2):
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    x_left = max(x1_a, x1_b)
    y_top = max(y1_a, y1_b)
    x_right = min(x2_a, x2_b)
    y_bottom = min(y2_a, y2_b)
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    box_b_area = (x2_b - x1_b) * (y2_b - y1_b)
    union_area = float(box_a_area + box_b_area - intersection_area)
    return intersection_area / union_area if union_area > 0 else 0


def calculate_ioa(box1, box2):
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    x_left = max(x1_a, x1_b)
    y_top = max(y1_a, y1_b)
    x_right = min(x2_a, x2_b)
    y_bottom = min(y2_a, y2_b)
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    return intersection_area / box_a_area if box_a_area > 0 else 0


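# IoU (intersection over union) vs IoA (intersection over area of the first box):
# IoU is symmetric and is used by merge_overlapping_boxes() below to fuse duplicate
# detections of the same class, while IoA measures how much of box1 lies inside box2
# and is used later to suppress OCR words that fall inside a detected figure/equation.
# Example: box1=(0, 0, 10, 10), box2=(5, 0, 15, 10) -> intersection 50, union 150,
# so IoU ~= 0.33 but IoA(box1, box2) = 0.5.
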
def merge_overlapping_boxes(detections, iou_threshold):
    if not detections: return []
    detections.sort(key=lambda d: d['conf'], reverse=True)
    merged_detections = []
    is_merged = [False] * len(detections)
    for i in range(len(detections)):
        if is_merged[i]: continue
        current_box = detections[i]['coords']
        current_class = detections[i]['class']
        merged_x1, merged_y1, merged_x2, merged_y2 = current_box
        for j in range(i + 1, len(detections)):
            if is_merged[j] or detections[j]['class'] != current_class: continue
            other_box = detections[j]['coords']
            iou = calculate_iou(current_box, other_box)
            if iou > iou_threshold:
                merged_x1 = min(merged_x1, other_box[0])
                merged_y1 = min(merged_y1, other_box[1])
                merged_x2 = max(merged_x2, other_box[2])
                merged_y2 = max(merged_y2, other_box[3])
                is_merged[j] = True
        merged_detections.append({
            'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
            'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf']
        })
    return merged_detections


def pdf_to_images(pdf_path, temp_dir):
    print("\n[YOLO/OCR STEP 1.1: PDF CONVERSION]")
    try:
        doc = fitz.open(pdf_path)
        pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
        image_paths = []
        mat = fitz.Matrix(2.0, 2.0)
        for page_num in range(doc.page_count):
            page = doc.load_page(page_num)
            pix = page.get_pixmap(matrix=mat)
            img_filename = f"{pdf_name}_page{page_num + 1}.png"
            img_path = os.path.join(temp_dir, img_filename)
            pix.save(img_path)
            image_paths.append(img_path)
        doc.close()
        print(f" ✅ PDF Conversion complete. {len(image_paths)} images generated.")
        return image_paths
    except Exception as e:
        print(f"❌ ERROR processing PDF {pdf_path}: {e}")
        return []


def preprocess_and_ocr_page(image_path, model, pdf_name, page_num):
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    page_filename = os.path.basename(image_path)
    original_img = cv2.imread(image_path)
    if original_img is None: return None

    # --- A. YOLO DETECTION AND MERGING ---
    results = model.predict(source=image_path, conf=CONF_THRESHOLD, imgsz=640, verbose=False)
    relevant_detections = []
    if results and results[0].boxes:
        for box in results[0].boxes:
            class_id = int(box.cls[0])
            class_name = model.names[class_id]
            if class_name in TARGET_CLASSES:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                relevant_detections.append(
                    {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])})

    merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD)

    # --- B. COMPONENT EXTRACTION AND TAGGING ---
    component_metadata = []
    for detection in merged_detections:
        x1, y1, x2, y2 = detection['coords']
        class_name = detection['class']

        if class_name == 'figure':
            GLOBAL_FIGURE_COUNT += 1
            counter = GLOBAL_FIGURE_COUNT
            component_word = f"FIGURE{counter}"
        elif class_name == 'equation':
            GLOBAL_EQUATION_COUNT += 1
            counter = GLOBAL_EQUATION_COUNT
            component_word = f"EQUATION{counter}"
        else:
            continue

        component_crop = original_img[y1:y2, x1:x2]
        component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png"
        cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop)

        y_midpoint = (y1 + y2) // 2
        component_metadata.append({
            'type': class_name, 'word': component_word,
            'bbox': [int(x1), int(y1), int(x2), int(y2)],
            'y0': int(y_midpoint), 'x0': int(x1)
        })

    # --- C. TESSERACT OCR ---
    try:
        pil_img = Image.fromarray(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
        hocr_data = pytesseract.image_to_data(pil_img, output_type=pytesseract.Output.DICT)
        raw_ocr_output = []
        for i in range(len(hocr_data['level'])):
            text = hocr_data['text'][i].strip()
            if text and hocr_data['conf'][i] > -1:
                x1 = int(hocr_data['left'][i])
                y1 = int(hocr_data['top'][i])
                x2 = x1 + int(hocr_data['width'][i])
                y2 = y1 + int(hocr_data['height'][i])
                raw_ocr_output.append({
                    'type': 'text', 'word': text, 'confidence': float(hocr_data['conf'][i]),
                    'bbox': [x1, y1, x2, y2], 'y0': y1, 'x0': x1
                })
    except Exception as e:
        print(f" ❌ Tesseract OCR Error on {page_filename}: {e}")
        return None

    # --- D. OCR CLEANING AND MERGING (Using IoA) ---
    items_to_sort = []
    for ocr_word in raw_ocr_output:
        is_suppressed = False
        for component in component_metadata:
            ioa = calculate_ioa(ocr_word['bbox'], component['bbox'])
            if ioa > IOA_SUPPRESSION_THRESHOLD:
                is_suppressed = True
                break
        if not is_suppressed:
            items_to_sort.append(ocr_word)

    items_to_sort.extend(component_metadata)

    # --- E. SOPHISTICATED LINE-BASED SORTING ---
    items_to_sort.sort(key=lambda x: (x['y0'], x['x0']))
    lines = []
    for item in items_to_sort:
        placed = False
        for line in lines:
            y_ref = min(it['y0'] for it in line)
            if abs(y_ref - item['y0']) < LINE_TOLERANCE:
                line.append(item)
                placed = True
                break
        if not placed and item['type'] in ['equation', 'figure']:
            for line in lines:
                y_ref = min(it['y0'] for it in line)
                if abs(y_ref - item['y0']) < 20:
                    line.append(item)
                    placed = True
                    break
        if not placed:
            lines.append([item])

    for line in lines:
        line.sort(key=lambda x: x['x0'])

    final_output = []
    for line in lines:
        for item in line:
            data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]}
            if 'tag' in item: data_item['tag'] = item['tag']
            if 'confidence' in item: data_item['confidence'] = item['confidence']
            final_output.append(data_item)

    return final_output


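# The per-page result above is a reading-ordered list of items of the form
# {"word": ..., "bbox": [x1, y1, x2, y2], "type": "text" | "figure" | "equation"}
# (OCR words also carry "confidence"). Detected figures/equations appear as placeholder
# words such as "FIGURE1" or "EQUATION2", and all bboxes are in the pixel space of the
# 2x-rendered page image produced by pdf_to_images().
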
259
+ word_data = page.get_text("words")
260
+ if len(word_data) == 0:
261
+ try:
262
+ pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
263
+ img_bytes = pix.tobytes("png")
264
+ img = Image.open(io.BytesIO(img_bytes))
265
+ data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
266
+ full_word_data = []
267
+ for i in range(len(data['level'])):
268
+ if data['text'][i].strip():
269
+ x1, y1 = data['left'][i] / 3, data['top'][i] / 3
270
+ x2, y2 = x1 + data['width'][i] / 3, y1 + data['height'][i] / 3
271
+ full_word_data.append((data['text'][i], x1, y1, x2, y2))
272
+ word_data = full_word_data
273
+ except Exception:
274
+ return []
275
+ else:
276
+ word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data]
277
+
278
+ page_height = page.rect.height
279
+ y_min = page_height * top_margin_percent
280
+ y_max = page_height * (1 - bottom_margin_percent)
281
+ return [d for d in word_data if d[2] >= y_min and d[4] <= y_max]
282
+
283
+
def calculate_x_gutters(word_data: list, params: Dict) -> List[int]:
    if not word_data: return []
    x_points = []
    for _, x1, _, x2, _ in word_data: x_points.extend([x1, x2])
    max_x = max(x_points)
    bin_size = params['cluster_bin_size']
    num_bins = int(np.ceil(max_x / bin_size))
    hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x))
    smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=params['cluster_smoothing'])
    inverted_signal = np.max(smoothed_hist) - smoothed_hist

    peaks, properties = find_peaks(
        inverted_signal, height=0, distance=params['cluster_min_width'] / bin_size
    )

    if not peaks.size: return []

    threshold_value = np.percentile(smoothed_hist, params['cluster_threshold_percentile'])
    inverted_threshold = np.max(smoothed_hist) - threshold_value
    significant_peaks = peaks[properties['peak_heights'] >= inverted_threshold]
    separator_x_coords = [int(bin_edges[p]) for p in significant_peaks]

    final_separators = []
    prominence_threshold = params['cluster_prominence'] * np.max(smoothed_hist)

    for x_coord in separator_x_coords:
        bin_idx = np.searchsorted(bin_edges, x_coord) - 1
        window_size = int(params['cluster_min_width'] / bin_size)

        left_start, left_end = max(0, bin_idx - window_size), bin_idx
        right_start, right_end = bin_idx + 1, min(len(smoothed_hist), bin_idx + 1 + window_size)

        if left_end <= left_start or right_end <= right_start: continue

        avg_left_density = np.mean(smoothed_hist[left_start:left_end])
        avg_right_density = np.mean(smoothed_hist[right_start:right_end])

        if avg_left_density >= prominence_threshold and avg_right_density >= prominence_threshold:
            final_separators.append(x_coord)

    return sorted(final_separators)


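# Gutter detection works on the x-coordinates of word start/end points: a histogram of
# these x-values is Gaussian-smoothed and inverted, so low-density valleys (candidate
# column gutters) become peaks for find_peaks(). A candidate is kept only if the text
# density on both sides of it exceeds cluster_prominence * max density, which prevents
# the empty page margins from being reported as gutters.
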
def detect_column_gutters(pdf_path: str, page_num: int, **params) -> Optional[int]:
    try:
        doc = fitz.open(pdf_path)
        page = doc.load_page(page_num)
        word_data = get_word_data_for_detection(page, params.get('top_margin_percent', 0.10),
                                                params.get('bottom_margin_percent', 0.10))
        page_width = page.rect.width  # capture before closing the document
        doc.close()
        if not word_data: return None

        separators = calculate_x_gutters(word_data, params)
        if len(separators) == 1:
            return separators[0]
        elif len(separators) > 1:
            center_x = page_width / 2
            return min(separators, key=lambda x: abs(x - center_x))
        return None
    except Exception:
        return None


def _merge_integrity(all_words_by_page: List[str], all_bboxes_raw: List[List[int]],
                     column_separator_x: Optional[int]) -> List[List[str]]:
    if column_separator_x is None: return [all_words_by_page]
    left_column_words, right_column_words = [], []
    for word, bbox_raw in zip(all_words_by_page, all_bboxes_raw):
        center_x = (bbox_raw[0] + bbox_raw[2]) / 2
        if center_x < column_separator_x:
            left_column_words.append(word)
        else:
            right_column_words.append(word)
    return [c for c in [left_column_words, right_column_words] if c]


def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
    """Runs the YOLO/OCR pipeline and returns the path to the combined JSON output."""
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    # Reset globals for a new PDF run
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0

    print("\n" + "=" * 80)
    print("--- 1. STARTING YOLO/OCR PREPROCESSING PIPELINE ---")
    print("=" * 80)

    if not os.path.exists(pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.")
        return None
    if not os.path.exists(WEIGHTS_PATH):
        print(f"❌ FATAL ERROR: YOLO Weights not found at {WEIGHTS_PATH}.")
        return None

    # Ensure required directories exist (handle bare filenames with no directory part)
    os.makedirs(os.path.dirname(preprocessed_json_path) or '.', exist_ok=True)
    os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True)
    os.makedirs(TEMP_IMAGE_DIR, exist_ok=True)

    model = YOLO(WEIGHTS_PATH)

    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]

    all_pages_data = []
    image_paths = pdf_to_images(pdf_path, TEMP_IMAGE_DIR)

    if not image_paths:
        print("❌ Pipeline halted. Could not convert any pages from PDF.")
        return None

    print("\n[STEP 1.2: ITERATING PAGES AND RUNNING YOLO/OCR]")
    total_pages_processed = 0
    for i, image_path in enumerate(image_paths):
        page_num = i + 1
        print(f" -> Processing Page {page_num}/{len(image_paths)}...")

        final_output = preprocess_and_ocr_page(image_path, model, pdf_name, page_num)

        if final_output is not None:
            page_data = {"page_number": page_num, "data": final_output}
            all_pages_data.append(page_data)
            total_pages_processed += 1
        else:
            print(f" ❌ Skipped page {page_num} due to processing error.")

    # --- FINAL SAVE STEP ---
    if all_pages_data:
        try:
            with open(preprocessed_json_path, 'w') as f:
                json.dump(all_pages_data, f, indent=4)
            print(f"\n ✅ Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}")
        except Exception as e:
            print(f"❌ ERROR saving combined JSON output: {e}")
            return None
    else:
        print("❌ WARNING: No page data generated. Halting pipeline.")
        return None

    print("\n" + "=" * 80)
    print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---")
    print("=" * 80)

    return preprocessed_json_path


# ============================================================================
# --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS (Raw BIO Tagging) ---
# ============================================================================

class LayoutLMv3ForTokenClassification(nn.Module):
    def __init__(self, num_labels: int = NUM_LABELS):
        super().__init__()
        self.num_labels = num_labels
        config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels)
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.crf = CRF(num_labels)
        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None: nn.init.zeros_(self.classifier.bias)

    def forward(
            self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor,
            labels: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[List[List[int]], Any]]:
        outputs = self.layoutlmv3(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True)
        sequence_output = outputs.last_hidden_state
        emissions = self.classifier(sequence_output)
        mask = attention_mask.bool()
        if labels is not None:
            loss = -self.crf(emissions, labels, mask=mask).mean()
            return loss
        else:
            return self.crf.viterbi_decode(emissions, mask=mask)

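# The model is a LayoutLMv3 encoder feeding a linear emission layer and a CRF head:
# with labels it returns a scalar loss (negative mean CRF log-likelihood), without
# labels it returns the Viterbi-decoded label-id sequence for each item in the batch,
# e.g. model(input_ids, bbox, attention_mask) -> [[0, 1, 2, 2, ...]].
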

def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
                                    preprocessed_json_path: str,
                                    column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
    """Runs LayoutLMv3-CRF inference and returns the raw word-level predictions, grouped by page."""
    print("\n" + "=" * 80)
    print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE ---")
    print("=" * 80)

    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
        checkpoint = torch.load(model_path, map_location=device)
        model_state = checkpoint.get('model_state_dict', checkpoint)
        # Fix for a potential key mismatch between checkpoint and module names
        fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()}
        model.load_state_dict(fixed_state_dict)
        model.to(device)
        model.eval()
    except Exception as e:
        print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}")
        return []

    try:
        with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
            preprocessed_data = json.load(f)
    except Exception as e:
        print(f"❌ ERROR loading preprocessed JSON: {e}")
        return []

    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        print(f"❌ ERROR loading PDF file: {e}")
        return []

    final_page_predictions = []
    CHUNK_SIZE = 500

    for page_data in preprocessed_data:
        page_num_1_based = page_data['page_number']
        page_num_0_based = page_num_1_based - 1
        page_raw_predictions = []

        fitz_page = doc.load_page(page_num_0_based)
        page_width, page_height = fitz_page.rect.width, fitz_page.rect.height

        words, bboxes_raw_pdf_space, normalized_bboxes_list = [], [], []
        scale_factor = 2.0  # pages were rendered at 2x in pdf_to_images()

        for item in page_data['data']:
            word, raw_yolo_bbox = item['word'], item['bbox']

            bbox_pdf = [
                int(raw_yolo_bbox[0] / scale_factor), int(raw_yolo_bbox[1] / scale_factor),
                int(raw_yolo_bbox[2] / scale_factor), int(raw_yolo_bbox[3] / scale_factor)
            ]

            normalized_bbox = [
                max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))),
                max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[3] / page_height)))
            ]

            words.append(word)
            bboxes_raw_pdf_space.append(bbox_pdf)
            normalized_bboxes_list.append(normalized_bbox)

        if not words: continue

        column_detection_params = column_detection_params or {}
        column_separator_x = detect_column_gutters(pdf_path, page_num_0_based, **column_detection_params)

        word_chunks = _merge_integrity(words, bboxes_raw_pdf_space, column_separator_x)

        # Re-align each column chunk with the page-level word/bbox lists so that the
        # normalized and PDF-space bboxes stay in step with the words fed to the model.
        current_global_index = 0
        for chunk_words_original in word_chunks:
            if not chunk_words_original: continue

            # Reconstruct the aligned chunk of words and bboxes using the global list.
            # Greedy match: walk the page-level word list in order and pick the next word
            # expected by this chunk; this assumes the chunk preserves the original word
            # order from page_data['data'].
            chunk_words, chunk_normalized_bboxes, chunk_bboxes_pdf = [], [], []
            temp_global_index = current_global_index
            for i in range(len(words)):
                if temp_global_index <= i and words[i] in chunk_words_original:
                    if words[i] == chunk_words_original[len(chunk_words)]:
                        chunk_words.append(words[i])
                        chunk_normalized_bboxes.append(normalized_bboxes_list[i])
                        chunk_bboxes_pdf.append(bboxes_raw_pdf_space[i])
                        current_global_index = i + 1
                if len(chunk_words) == len(chunk_words_original):
                    break

            # --- Inference in sub-batches ---
            for i in range(0, len(chunk_words), CHUNK_SIZE):
                sub_words = chunk_words[i:i + CHUNK_SIZE]
                sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE]
                sub_bboxes_pdf = chunk_bboxes_pdf[i:i + CHUNK_SIZE]

                # Guard against an empty sub-batch (possible if the re-alignment above missed words)
                if not sub_words: continue

                encoded_input = tokenizer(
                    sub_words, boxes=sub_bboxes, truncation=True, padding="max_length",
                    max_length=512, return_tensors="pt"
                )

                input_ids = encoded_input['input_ids'].to(device)
                bbox = encoded_input['bbox'].to(device)
                attention_mask = encoded_input['attention_mask'].to(device)

                with torch.no_grad():
                    predictions_int_list = model(input_ids, bbox, attention_mask)

                if not predictions_int_list: continue

                predictions_int = predictions_int_list[0]
                word_ids = encoded_input.word_ids()
                word_idx_to_pred_id = {}

                for token_idx, word_idx in enumerate(word_ids):
                    if word_idx is not None and word_idx < len(sub_words):
                        # Use the prediction for the first token of a word
                        if word_idx not in word_idx_to_pred_id:
                            word_idx_to_pred_id[word_idx] = predictions_int[token_idx]

                for current_word_idx in range(len(sub_words)):
                    pred_id_or_tensor = word_idx_to_pred_id.get(current_word_idx, 0)
                    pred_id = pred_id_or_tensor.item() if torch.is_tensor(pred_id_or_tensor) else pred_id_or_tensor
                    predicted_label = ID_TO_LABEL[pred_id]

                    page_raw_predictions.append({
                        "word": sub_words[current_word_idx],
                        "bbox": sub_bboxes_pdf[current_word_idx],
                        "predicted_label": predicted_label,
                        "page_number": page_num_1_based
                    })

            # NOTE: current_global_index only advances past words that were matched above,
            # so any word the greedy re-alignment misses is skipped for this chunk.

        if page_raw_predictions:
            final_page_predictions.append({
                "page_number": page_num_1_based,
                "data": page_raw_predictions
            })

    doc.close()
    print(f"✅ LayoutLMv3 inference complete. Predicted tags for {len(final_page_predictions)} pages.")
    return final_page_predictions


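# Each record returned by Phase 2 has the shape
#   {"page_number": n, "data": [{"word", "bbox", "predicted_label", "page_number"}, ...]},
# where "bbox" is back in PDF point space and "predicted_label" is one of the BIO tags
# from ID_TO_LABEL. Phase 3 below flattens these per-page lists and decodes the tags.
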
# ============================================================================
# --- PHASE 3: BIO TO STRUCTURED JSON DECODER (Modified for In-Memory Return) ---
# ============================================================================

def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]:
    """
    Reads the page-grouped raw word predictions from input_path, flattens them, and converts
    the BIO tags into the structured JSON format. Returns the structured data.
    """
    print("\n" + "=" * 80)
    print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---")
    print("=" * 80)

    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            predictions_by_page = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError) as e:
        print(f"❌ Error loading raw prediction file '{input_path}': {e}")
        return None
    except Exception as e:
        print(f"❌ An unexpected error occurred during file loading: {e}")
        return None

    # FLATTEN THE LIST OF WORDS ACROSS ALL PAGES
    predictions = []
    for page_item in predictions_by_page:
        if isinstance(page_item, dict) and 'data' in page_item and isinstance(page_item['data'], list):
            predictions.extend(page_item['data'])

    if not predictions:
        print("❌ Error: No valid word data found in the input file after attempting to flatten pages.")
        return None

    # --- Core BIO parsing state ---
    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    current_text_buffer = []

    first_question_started = False
    last_entity_type = None

    just_finished_i_option = False
    is_in_new_passage = False

    def finalize_passage_to_item(item, passage_buffer):
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    for item in predictions:
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (label == 'B-PASSAGE' or label == 'I-PASSAGE')

        if not first_question_started and label != 'B-QUESTION' and not is_passage_label:
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if not first_question_started and is_passage_label:
            if label == 'B-PASSAGE' or label == 'I-PASSAGE' or not current_passage_buffer:
                current_passage_buffer.append(word)
            last_entity_type = 'PASSAGE'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if label == 'B-QUESTION':
            if not first_question_started:
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    if current_passage_buffer:
                        finalize_passage_to_item(metadata_item, current_passage_buffer)
                        if header_text:
                            metadata_item['text'] = header_text
                    elif header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]

            if current_item is not None:
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]

            current_item = {
                'question': word,
                'options': {},
                'answer': '',
                'passage': '',
                'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is not None:
            if is_in_new_passage:
                current_item['new_passage'] += f' {word}'
                if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                    is_in_new_passage = False
                    if label.startswith(('B-', 'I-')):
                        last_entity_type = entity_type
                continue

            is_in_new_passage = False
            if label.startswith('B-'):
                if entity_type != 'PASSAGE':
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []
                last_entity_type = entity_type

                if entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False
                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False

            elif label.startswith('I-'):
                if entity_type == 'QUESTION' and current_item.get('question'):
                    current_item['question'] += f' {word}'
                    last_entity_type = 'QUESTION'
                    just_finished_i_option = False
                elif entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        if last_entity_type == 'QUESTION' and current_item.get('question'):
                            last_entity_type = 'PASSAGE'
                        if last_entity_type == 'PASSAGE' or not current_passage_buffer:
                            current_passage_buffer.append(word)
                        last_entity_type = 'PASSAGE'
                    just_finished_i_option = False
                elif entity_type == 'OPTION' and last_entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                    just_finished_i_option = True
                elif entity_type == 'ANSWER' and last_entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'
                    just_finished_i_option = False
                else:
                    just_finished_i_option = False

            elif label == 'O':
                if last_entity_type == 'QUESTION' and current_item and 'question' in current_item:
                    current_item['question'] += f' {word}'
                just_finished_i_option = False

    # --- Finalize last item ---
    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)
    elif not structured_data and current_passage_buffer:
        metadata_item = {'type': 'METADATA', 'passage': ''}
        finalize_passage_to_item(metadata_item, current_passage_buffer)
        metadata_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(metadata_item)

    # --- FINAL CLEANUP ---
    for item in structured_data:
        item['text'] = re.sub(r'\s{2,}', ' ', item.get('text', '')).strip()  # METADATA items may lack 'text'
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()

    # --- SAVE INTERMEDIATE FILE (Optional for Debugging) ---
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structured_data, f, indent=2, ensure_ascii=False)
        print(f"✅ Decoding complete. Intermediate structured JSON saved to '{output_path}'.")
    except Exception as e:
        print(f"❌ Error saving intermediate output file: {e}. Returning data anyway.")

    # Return the structured data so the caller can keep working in memory.
    return structured_data
824
+
825
+ # ============================================================================
826
+ # --- PHASE 4: IMAGE EMBEDDING (Modified for In-Memory Return) ---
827
+ # ============================================================================
828
+
829
+ def get_base64_for_file(filepath: str) -> str:
830
+ """Reads a file and returns its Base64 encoded string."""
831
+ try:
832
+ with open(filepath, 'rb') as f:
833
+ return base64.b64encode(f.read()).decode('utf-8')
834
+ except Exception as e:
835
+ print(f" ❌ Error encoding file {filepath}: {e}")
836
+ return ""
837
+
838
+
839
+ def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[
840
+ Dict[str, Any]]:
841
+ """
842
+ Scans structured data for EQUATION/FIGURE tags, converts corresponding images
843
+ to Base64, and embeds them into the JSON entry in memory.
844
+ """
845
+ print("\n" + "=" * 80)
846
+ print("--- 4. STARTING IMAGE EMBEDDING (Base64) ---")
847
+ print("=" * 80)
848
+
849
+ if not structured_data:
850
+ print("❌ Error: No structured data provided for image embedding.")
851
+ return []
852
+
853
+ # Map image tags (e.g., EQUATION9) to their full file paths
854
+ image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png"))
855
+ image_lookup = {}
856
+ tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
857
+
858
+ for filepath in image_files:
859
+ filename = os.path.basename(filepath)
860
+ match = re.search(r'_(figure|equation)(\d+)\.png$', filename, re.IGNORECASE)
861
+ if match:
862
+ key = f"{match.group(1).upper()}{match.group(2)}"
863
+ image_lookup[key] = filepath
864
+
865
+ print(f" -> Found {len(image_lookup)} image components in the extraction directory.")
866
+
867
+ # 2. Iterate through structured data and embed images
868
+ final_structured_data = []
869
+
870
+ for item in structured_data:
871
+ text_fields = [item.get('question', ''), item.get('passage', '')]
872
+ if 'options' in item:
873
+ for opt_val in item['options'].values():
874
+ text_fields.append(opt_val)
875
+ if 'new_passage' in item:
876
+ text_fields.append(item['new_passage'])
877
+
878
+ unique_tags_to_embed = set()
879
+
880
+ for text in text_fields:
881
+ if not text: continue
882
+ for match in tag_regex.finditer(text):
883
+ tag = match.group(0).upper()
884
+ if tag in image_lookup:
885
+ unique_tags_to_embed.add(tag)
886
+
887
+ # 3. Embed the Base64 images
888
+ for tag in sorted(list(unique_tags_to_embed)):
889
+ filepath = image_lookup[tag]
890
+ base64_code = get_base64_for_file(filepath)
891
+ base_key = tag.replace(' ', '').lower()
892
+ item[base_key] = base64_code
893
+
894
+ final_structured_data.append(item)
895
+
896
+ print(f"✅ Image embedding complete. Returning final structured data.")
897
+ return final_structured_data
898
+
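# The Base64 payload for each referenced image is stored under a lowercase key derived
# from its tag (e.g. item["figure3"], item["equation7"]), alongside the text fields that
# still contain the FIGURE3/EQUATION7 placeholder words, so consumers can re-associate
# each image with its position in the question text.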

# ============================================================================
# --- MAIN FUNCTION (The Callable Interface) ---
# ============================================================================

def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str) -> Optional[List[Dict[str, Any]]]:
    """
    Executes the full document analysis pipeline: YOLO/OCR -> LayoutLMv3 -> Structured JSON -> Base64 Image Embed.

    Args:
        input_pdf_path: Path to the input PDF file.
        layoutlmv3_model_path: Path to the saved LayoutLMv3-CRF PyTorch model checkpoint.

    Returns:
        The final structured JSON data as a Python list of dictionaries, or None on failure.
    """
    if not os.path.exists(input_pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {input_pdf_path}.")
        return None
    if not os.path.exists(layoutlmv3_model_path):
        print(f"❌ FATAL ERROR: LayoutLMv3 Model checkpoint not found at {layoutlmv3_model_path}.")
        return None
    if not os.path.exists(WEIGHTS_PATH):
        print(f"❌ FATAL ERROR: YOLO Model weights not found at {WEIGHTS_PATH}. Update WEIGHTS_PATH in the script.")
        return None

    print("\n" + "#" * 80)
    print("### STARTING FULL DOCUMENT ANALYSIS PIPELINE ###")
    print("#" * 80)

    # --- Setup Temporary Directories ---
    # tempfile keeps the intermediate artifacts out of the working directory
    pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
    temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}")
    os.makedirs(temp_pipeline_dir, exist_ok=True)

    # Define intermediate file paths inside the temp directory
    preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
    raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
    structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json")

    # Column Detection Parameters
    column_params = {
        'top_margin_percent': 0.10, 'bottom_margin_percent': 0.10, 'cluster_prominence': 0.70,
        'cluster_bin_size': 5, 'cluster_smoothing': 2, 'cluster_threshold_percentile': 30,
        'cluster_min_width': 25,
    }
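    # These cluster_* values feed calculate_x_gutters(): bin_size is the histogram bin width
    # (PDF points), smoothing is the Gaussian sigma applied to the histogram,
    # threshold_percentile sets how empty a valley must be to count as a gutter,
    # prominence sets how dense the text on both sides of it must be, and min_width sets
    # both the minimum spacing between candidate gutters and the width of those side
    # windows; the margin percentages trim the top and bottom of the page from the sample.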

    final_result = None

    try:
        # --- A. PHASE 1: YOLO/OCR PREPROCESSING ---
        # Saves figure/equation images to FIGURE_EXTRACTION_DIR and OCR data to preprocessed_json_path
        preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)

        if not preprocessed_json_path_out:
            print("Pipeline aborted after Phase 1.")
            return None

        # --- B. PHASE 2: LAYOUTLMV3 INFERENCE (Raw Output) ---
        page_raw_predictions_list = run_inference_and_get_raw_words(
            input_pdf_path,
            layoutlmv3_model_path,
            preprocessed_json_path_out,
            column_detection_params=column_params
        )

        if not page_raw_predictions_list:
            print("Pipeline aborted: No raw predictions generated in Phase 2.")
            return None

        # Save raw predictions (required input for Phase 3 via file path)
        with open(raw_output_path, 'w', encoding='utf-8') as f:
            json.dump(page_raw_predictions_list, f, indent=4)

        # --- C. PHASE 3: BIO TO STRUCTURED JSON DECODING ---
        structured_data_list = convert_bio_to_structured_json_relaxed(
            raw_output_path,
            structured_intermediate_output_path
        )

        if not structured_data_list:
            print("Pipeline aborted: Failed to convert BIO tags to structured data in Phase 3.")
            return None

        # --- D. PHASE 4: IMAGE EMBEDDING (Base64) ---
        final_result = embed_images_as_base64_in_memory(
            structured_data_list,
            FIGURE_EXTRACTION_DIR
        )

    except Exception as e:
        print(f"❌ FATAL ERROR during pipeline execution: {e}", file=sys.stderr)
        return None

    finally:
        # --- E. Cleanup ---
        # Note: extracted figures in FIGURE_EXTRACTION_DIR are kept; only the temporary
        # page images and intermediate pipeline files are removed.
        try:
            # Clean up temp images from Phase 1
            for f in glob.glob(os.path.join(TEMP_IMAGE_DIR, '*')): os.remove(f)
            os.rmdir(TEMP_IMAGE_DIR)
        except Exception:
            pass  # Ignore cleanup errors

        try:
            # Clean up temporary pipeline directory
            for f in glob.glob(os.path.join(temp_pipeline_dir, '*')): os.remove(f)
            os.rmdir(temp_pipeline_dir)
        except Exception:
            pass

    # --- F. FINAL STATUS ---
    print("\n" + "#" * 80)
    print("### FULL PIPELINE EXECUTION COMPLETE ###")
    print(f"Returning final structured data for {pdf_name}.")
    print("#" * 80)

    return final_result


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Complete Document Analysis Pipeline (YOLO/OCR -> LayoutLMv3 -> Structured JSON -> Base64 Image Embed).")
    parser.add_argument("--input_pdf", type=str, required=True,
                        help="Path to the input PDF file for analysis.")
    parser.add_argument("--layoutlmv3_model_path", type=str,
                        default=DEFAULT_LAYOUTLMV3_MODEL_PATH,
                        help="Path to the saved LayoutLMv3-CRF PyTorch model checkpoint.")

    args = parser.parse_args()

    # --- Call the main function ---
    final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path)

    if final_json_data:
        # Example of what to do with the returned data: save it to a file
        output_file_name = os.path.splitext(os.path.basename(args.input_pdf))[0] + "_final_output_embedded.json"

        # Save the final output in the current working directory
        final_output_path = os.path.abspath(output_file_name)

        with open(final_output_path, 'w', encoding='utf-8') as f:
            json.dump(final_json_data, f, indent=2, ensure_ascii=False)

        print(f"\n✅ Final structured data successfully returned and saved to: {final_output_path}")