# feeedback / app.py
# Last commit: 46d219d "Update app.py" (author: heerjtdev)
import fitz # PyMuPDF
import numpy as np
import cv2
import torch
import torch.serialization
import os
import time
from typing import Optional, Tuple, List, Dict, Any
from ultralytics import YOLO
import logging
import gradio as gr
import shutil
import tempfile
import io
# ============================================================================
# --- Global Patches and Setup ---
# ============================================================================
# Patch torch.load to prevent weights_only errors with older models: newer
# PyTorch releases (2.6+) default to weights_only=True, which rejects
# checkpoints that pickle arbitrary Python objects.
# SECURITY NOTE: weights_only=False lets unpickling execute arbitrary code, so
# only trusted weight files should ever be loaded by this process.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    An explicit ``weights_only=`` supplied by the caller is honoured instead of
    being silently overridden, so code that opts into safe loading keeps it.
    """
    kwargs.setdefault("weights_only", False)
    return _original_torch_load(*args, **kwargs)


torch.load = patched_torch_load
logging.basicConfig(level=logging.WARNING)
# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================
# YOLO checkpoint loaded at the start of every document run.
WEIGHTS_PATH: str = "best.pt"
# Render scale for PDF pages (used for both axes of the page matrix).
SCALE_FACTOR: float = 2.0

# Detection parameters
CONF_THRESHOLD: float = 0.2  # minimum confidence passed to model.predict
TARGET_CLASSES: List[str] = ["figure", "equation"]  # class names that are counted
IOU_MERGE_THRESHOLD: float = 0.4  # same-class boxes above this IoU get merged
IOA_SUPPRESSION_THRESHOLD: float = 0.7  # boxes covered beyond this by a larger box are dropped

# Running totals for the document being processed; reset at the start of each run.
GLOBAL_FIGURE_COUNT: int = 0
GLOBAL_EQUATION_COUNT: int = 0
# ============================================================================
# --- BOX COMBINATION LOGIC (Retained for detection accuracy) ---
# ============================================================================
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    The overlap is clamped at zero per axis, so disjoint boxes give 0.0.
    When the union area is not positive (e.g. two zero-area boxes) the
    function returns 0 instead of dividing by zero.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = float(area_a + area_b - inter)

    if union > 0:
        return inter / union
    return 0
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """Drop detections that lie mostly inside a larger, already-kept detection.

    Boxes are visited from largest to smallest area. A box is suppressed when
    its intersection with some kept (larger or equal) box covers more than
    ``ioa_threshold`` of its own area (intersection-over-area). Class labels
    are not compared, so a small box of any class nested in a larger box of
    any class can be suppressed. Zero-area boxes are never suppressed.

    Args:
        detections: Dicts with at least a ``'coords'`` entry ``(x1, y1, x2, y2)``.
        ioa_threshold: Fraction of a box's own area that must be covered for
            it to be suppressed (strictly greater than).

    Returns:
        The kept detections ordered by descending area (ties keep input
        order). Each is a shallow copy of its input dict with an added
        ``'area'`` key; the caller's list and dicts are left unmodified.
    """
    if not detections:
        return []

    def _area(coords):
        x1, y1, x2, y2 = coords
        return (x2 - x1) * (y2 - y1)

    # Work on annotated copies so neither the caller's list order nor its
    # dicts are mutated (the previous version sorted and extended them in place).
    candidates = sorted(
        ({**det, 'area': _area(det['coords'])} for det in detections),
        key=lambda det: det['area'],
        reverse=True,
    )

    kept = []
    for cand in candidates:
        cx1, cy1, cx2, cy2 = cand['coords']
        cand_area = cand['area']
        suppressed = False
        if cand_area > 0:
            for big in kept:
                bx1, by1, bx2, by2 = big['coords']
                inter = (max(0, min(bx2, cx2) - max(bx1, cx1))
                         * max(0, min(by2, cy2) - max(by1, cy1)))
                if inter / cand_area > ioa_threshold:
                    suppressed = True
                    break
        if not suppressed:
            kept.append(cand)
    return kept
def merge_overlapping_boxes(detections, iou_threshold):
    """Greedily merge same-class detections that overlap a higher-confidence seed.

    Detections are visited in descending confidence. Each detection not yet
    absorbed becomes a seed; every later detection of the same class whose
    IoU with the seed's *original* box exceeds ``iou_threshold`` is absorbed,
    and the seed's output box grows to cover both.

    Args:
        detections: Dicts with ``'coords'`` ``(x1, y1, x2, y2)``, ``'class'``
            and ``'conf'`` entries.
        iou_threshold: IoU with the seed above which a box is merged.

    Returns:
        New dicts ``{'coords', 'y1', 'class', 'conf'}`` in seed order, where
        ``'conf'`` is the seed's confidence. The caller's list is not reordered.
    """
    if not detections:
        return []

    # sorted() rather than list.sort(): don't reorder the caller's list.
    ordered = sorted(detections, key=lambda d: d['conf'], reverse=True)
    is_merged = [False] * len(ordered)
    merged_detections = []

    for i, seed in enumerate(ordered):
        if is_merged[i]:
            continue
        seed_box = seed['coords']
        seed_class = seed['class']
        merged_x1, merged_y1, merged_x2, merged_y2 = seed_box
        for j in range(i + 1, len(ordered)):
            other = ordered[j]
            if is_merged[j] or other['class'] != seed_class:
                continue
            other_box = other['coords']
            # IoU is measured against the seed's original box, not the growing
            # merged box: a box overlapping only an absorbed box is not merged.
            if calculate_iou(seed_box, other_box) > iou_threshold:
                merged_x1 = min(merged_x1, other_box[0])
                merged_y1 = min(merged_y1, other_box[1])
                merged_x2 = max(merged_x2, other_box[2])
                merged_y2 = max(merged_y2, other_box[3])
                is_merged[j] = True
        merged_detections.append({
            'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
            'y1': merged_y1,
            'class': seed_class,
            'conf': seed['conf'],
        })
    return merged_detections
# ============================================================================
# --- UTILITY FUNCTIONS ---
# ============================================================================
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
    """Convert a PyMuPDF Pixmap into a uint8 NumPy array for OpenCV/YOLO.

    The raw ``pix.samples`` buffer is viewed as ``(pix.h, pix.w, pix.n)``.
    4-channel pixmaps are converted RGBA -> RGB and 1-channel pixmaps are
    expanded GRAY -> RGB; any other channel count is returned unchanged.

    NOTE(review): the 4-channel branch assumes RGB + alpha; a CMYK pixmap
    would be converted incorrectly -- confirm if other colorspaces are used.
    When no conversion applies, the result is ``np.frombuffer`` over the
    samples buffer (typically read-only).
    """
    channels = pix.n
    pixels = np.frombuffer(pix.samples, dtype=np.uint8).reshape((pix.h, pix.w, channels))
    conversion = {4: cv2.COLOR_RGBA2RGB, 1: cv2.COLOR_GRAY2RGB}.get(channels)
    if conversion is None:
        return pixels
    return cv2.cvtColor(pixels, conversion)
def run_yolo_detection_and_count(
    image: np.ndarray, model: YOLO, page_num: int
) -> Tuple[int, int]:
    """Detect target objects on one page image and add them to the global totals.

    Runs ``model.predict`` at ``CONF_THRESHOLD``, keeps detections whose class
    name is in ``TARGET_CLASSES``, merges overlapping same-class boxes
    (``merge_overlapping_boxes``) and drops boxes nested in larger ones
    (``filter_nested_boxes``). Every surviving detection increments
    ``GLOBAL_FIGURE_COUNT`` or ``GLOBAL_EQUATION_COUNT``.

    Args:
        image: Page image as produced by ``pixmap_to_numpy``.
        model: Loaded Ultralytics YOLO model.
        page_num: Page number used only in log messages.

    Returns:
        ``(page_equations, page_figures)`` for this page. If inference raises,
        the error is logged with its traceback, the global counters are left
        untouched and ``(0, 0)`` is returned.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    yolo_detections = []
    try:
        results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
        if results and results[0].boxes:
            # Each row of boxes.data is unpacked as [x1, y1, x2, y2, conf, cls_id].
            for x1, y1, x2, y2, conf, cls_id in results[0].boxes.data.tolist():
                cls_name = model.names[int(cls_id)]
                if cls_name in TARGET_CLASSES:
                    yolo_detections.append({
                        'coords': (x1, y1, x2, y2),
                        'class': cls_name,
                        'conf': conf,
                    })
    except Exception as e:
        # Best effort per page, but keep the traceback (logging.exception),
        # which the plain error message alone discarded.
        logging.exception(f"YOLO inference failed on page {page_num}: {e}")
        return 0, 0

    # Merge overlapping same-class boxes, then drop boxes nested in larger ones.
    merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
    final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)

    page_equations = sum(1 for det in final_detections if det['class'] == 'equation')
    page_figures = sum(1 for det in final_detections if det['class'] == 'figure')
    GLOBAL_EQUATION_COUNT += page_equations
    GLOBAL_FIGURE_COUNT += page_figures

    logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}")
    return page_equations, page_figures
# ============================================================================
# --- MAIN DOCUMENT PROCESSING FUNCTION (Fixed for JSON serialization) ---
# ============================================================================
# NOTE: The return signature now uses Dict[str, int] for the equation counts
def run_single_pdf_preprocessing(pdf_path: str) -> Tuple[int, int, int, str, float, Dict[str, int], List[str]]:
    """Count equations and figures on every page of a PDF with the YOLO model.

    Loads the model from ``WEIGHTS_PATH``, renders each page with a
    ``SCALE_FACTOR`` matrix and runs ``run_yolo_detection_and_count`` on it.
    The module-level ``GLOBAL_FIGURE_COUNT`` / ``GLOBAL_EQUATION_COUNT``
    counters are reset at the start of each call and hold the document
    totals afterwards.

    Args:
        pdf_path: Filesystem path of the PDF to process.

    Returns:
        ``(total_pages, total_equations, total_figures, report,
        total_execution_time, equation_counts_per_page, gallery)``:
        ``report`` is a Markdown summary with per-step timings;
        ``equation_counts_per_page`` maps the 1-based page number *as a str*
        (JSON object keys must be strings) to that page's equation count,
        with pages that failed to render omitted; ``gallery`` is always ``[]``.
        If the file is missing or the model/PDF cannot be loaded, the counts
        are 0, the dict is empty and ``report`` describes the error.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    start_time = time.perf_counter()
    log_messages: List[str] = []
    equation_counts_per_page: Dict[str, int] = {}

    # Reset the totals that run_yolo_detection_and_count accumulates into.
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0

    # 1. Validation and model loading
    t0 = time.perf_counter()
    if not os.path.exists(pdf_path):
        report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
        return 0, 0, 0, report, time.perf_counter() - start_time, {}, []
    try:
        model = YOLO(WEIGHTS_PATH)
        logging.warning(f"βœ… Loaded YOLO model from: {WEIGHTS_PATH}")
    except Exception as e:
        report = f"❌ ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)"
        return 0, 0, 0, report, time.perf_counter() - start_time, {}, []
    t1 = time.perf_counter()
    log_messages.append(f"Model Loading Time: {t1-t0:.4f}s")

    # 2. PDF loading
    t2 = time.perf_counter()
    doc = None
    try:
        doc = fitz.open(pdf_path)
        total_pages = doc.page_count
        logging.warning(f"βœ… Opened PDF with {doc.page_count} pages")
    except Exception as e:
        if doc is not None:
            doc.close()
        report = f"❌ ERROR loading PDF file: {e}"
        return 0, 0, 0, report, time.perf_counter() - start_time, {}, []
    t3 = time.perf_counter()
    log_messages.append(f"PDF Initialization Time: {t3-t2:.4f}s")

    # 3. Page processing and detection loop. try/finally guarantees the
    # document is closed even if a page fails to load or detection raises
    # (previously doc.close() was skipped on any exception).
    try:
        mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
        t4 = time.perf_counter()
        for page_index in range(total_pages):
            page_start_time = time.perf_counter()
            fitz_page = doc.load_page(page_index)
            page_num = page_index + 1

            # Render page to image for YOLO; an unrenderable page is skipped
            # and gets no entry in the per-page dict.
            try:
                pix_start = time.perf_counter()
                pix = fitz_page.get_pixmap(matrix=mat)
                original_img = pixmap_to_numpy(pix)
                pix_time = time.perf_counter() - pix_start
            except Exception as e:
                logging.error(f"Error converting page {page_num} to image: {e}. Skipping.")
                continue

            detect_start = time.perf_counter()
            page_equations, _ = run_yolo_detection_and_count(original_img, model, page_num)
            detect_time = time.perf_counter() - detect_start

            # String keys directly: the dict is rendered by a JSON component.
            equation_counts_per_page[str(page_num)] = page_equations
            page_total_time = time.perf_counter() - page_start_time
            log_messages.append(
                f"Page {page_num} Time: Total={page_total_time:.4f}s "
                f"(Render={pix_time:.4f}s, Detect={detect_time:.4f}s)"
            )
    finally:
        doc.close()
    t5 = time.perf_counter()
    log_messages.append(f"Total Detection Loop Time ({total_pages} pages): {t5 - t4:.4f}s")

    # 4. Final report generation
    total_execution_time = t5 - start_time
    report = (
        f"βœ… **YOLO Counting Complete!**\n\n"
        f"**1) Total Pages Detected in PDF:** **{total_pages}**\n"
        f"**2) Total Equations Detected:** **{GLOBAL_EQUATION_COUNT}**\n"
        f"**3) Total Figures Detected:** **{GLOBAL_FIGURE_COUNT}**\n"
        "---\n"
        f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n"
        "### Detailed Step Timing\n"
        "```\n"
        + "\n".join(log_messages)
        + "\n```"
    )
    return (
        total_pages,
        GLOBAL_EQUATION_COUNT,
        GLOBAL_FIGURE_COUNT,
        report,
        total_execution_time,
        equation_counts_per_page,
        [],
    )
# ============================================================================
# --- GRADIO INTERFACE FUNCTION (Updated) ---
# ============================================================================
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[str]]:
    """Gradio callback: process an uploaded PDF and format the six UI outputs.

    Args:
        pdf_file: The uploaded file from the ``gr.File`` input. It may be a
            path string / ``os.PathLike`` or an object exposing its path as
            ``.name``. ``None`` means nothing was uploaded.

    Returns:
        ``(pages, equations, figures, report, equation_counts_per_page,
        gallery)`` where the three counts are strings, ``report`` is the
        Markdown summary, the dict maps page numbers (str) to equation
        counts, and ``gallery`` is always ``[]``. Unexpected errors are
        logged and reported as ``"Error"`` values instead of propagating.
    """
    if pdf_file is None:
        return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []

    try:
        # Accept plain path strings as well as objects carrying the path in
        # .name; resolving inside the try lets any failure reach the UI.
        if isinstance(pdf_file, (str, os.PathLike)):
            pdf_path = os.fspath(pdf_file)
        else:
            pdf_path = pdf_file.name

        (num_pages, num_equations, num_figures, report,
         _total_time, equation_counts_per_page, _) = run_single_pdf_preprocessing(pdf_path)

        return str(num_pages), str(num_equations), str(num_figures), report, equation_counts_per_page, []
    except Exception as e:
        error_msg = f"An unexpected error occurred: {e}"
        logging.error(error_msg, exc_info=True)
        return "Error", "Error", "Error", error_msg, {}, []
# ============================================================================
# --- GRADIO INTERFACE DEFINITION (Updated) ---
# ============================================================================
if __name__ == "__main__":
    # Missing weights are reported up front, but the UI still starts; each
    # request will then show the model-loading error in its summary.
    if not os.path.exists(WEIGHTS_PATH):
        logging.error(f"❌ FATAL ERROR: YOLO weight file '{WEIGHTS_PATH}' not found. Cannot run live inference.")

    pdf_input = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"])

    # Order must match the 6-tuple returned by gradio_process_pdf.
    result_outputs = [
        gr.Textbox(label="Total Pages in PDF", interactive=False),
        gr.Textbox(label="Total Equations Detected", interactive=False),
        gr.Textbox(label="Total Figures Detected", interactive=False),
        gr.Markdown(label="Processing Summary and Timing"),
        # Structured per-page equation counts (page numbers as string keys).
        gr.JSON(label="Equation Count Per Page (Dictionary)"),
        # Kept in the layout, but the callback always sends an empty list.
        gr.Gallery(
            label="Detected Equations (Disabled for Speed)",
            columns=5,
            height="auto",
            object_fit="contain",
            allow_preview=False,
        ),
    ]

    app = gr.Interface(
        fn=gradio_process_pdf,
        inputs=pdf_input,
        outputs=result_outputs,
        title="πŸ“Š YOLO Counting with Per-Page Data & Timing",
        description=(
            "Upload a PDF to run YOLO detection. The results include total counts, a breakdown of "
            "equation counts per page (in JSON format), and detailed timing."
        ),
    )

    print("\nStarting Gradio application...")
    app.launch(inbrowser=True)