"""
MonkeyOCR 3B Gradio App for MacBook M4 Pro with MPS Acceleration

Optimized for local deployment with Apple Silicon GPU acceleration
"""

import os
import sys
import tempfile
import shutil
from pathlib import Path
import base64
import re
import uuid
import subprocess
from typing import Optional, Tuple

import gradio as gr
import torch
from PIL import Image
from pdf2image import convert_from_path
from loguru import logger
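
# torch_patch is a local helper module (not part of PyTorch); judging by its name,
# it patches torch.load, and it is applied here so the patch is in place before any
# MonkeyOCR checkpoints are loaded below.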
from torch_patch import patch_torch_load

patch_torch_load()

# Make the local ./MonkeyOCR directory importable
sys.path.append("./MonkeyOCR")

try:
    from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
    from magic_pdf.data.dataset import PymuDocDataset, ImageDataset
    from magic_pdf.model.doc_analyze_by_custom_model_llm import doc_analyze_llm
    from magic_pdf.model.custom_model import MonkeyOCR
except ImportError as e:
    logger.error(f"Failed to import MonkeyOCR modules: {e}")
    logger.info("Please ensure MonkeyOCR is properly installed")
    sys.exit(1)

# Lazily initialized global model instance (loaded on first use)
model_instance = None


def initialize_model(config_path: str = "model_configs_mps.yaml") -> MonkeyOCR:
    """Initialize MonkeyOCR model with MPS optimization"""
    global model_instance

    if model_instance is None:
        logger.info("Initializing MonkeyOCR model with MPS acceleration...")

        if not torch.backends.mps.is_available():
            logger.warning("MPS not available, falling back to CPU")

            # Rewrite the config file on disk so the model is loaded on the CPU instead
            import yaml
            with open(config_path, 'r') as f:
                config = yaml.safe_load(f)
            config['device'] = 'cpu'
            with open(config_path, 'w') as f:
                yaml.dump(config, f)
        else:
            logger.info("MPS is available and will be used for acceleration")
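
        # Setting PYTORCH_MPS_HIGH_WATERMARK_RATIO to 0.0 disables the MPS
        # allocator's upper memory limit so the model can use all of the
        # available unified memory.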
        os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'

        try:
            model_instance = MonkeyOCR(config_path)
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            raise

    return model_instance


def render_latex_table_to_image(latex_content: str, temp_dir: str) -> str:
    """Render LaTeX table to image and return HTML img tag"""
    try:
        # Pull out the tabular environment if one is present
        pattern = r"(\\begin\{tabular\}.*?\\end\{tabular\})"
        matches = re.findall(pattern, latex_content, re.DOTALL)

        if matches:
            table_content = matches[0]
        elif '\\begin{tabular}' in latex_content:
            # Close an unterminated tabular environment
            if '\\end{tabular}' not in latex_content:
                table_content = latex_content + '\n\\end{tabular}'
            else:
                table_content = latex_content
        else:
            # No table found, return the content unchanged
            return latex_content
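
        # Wrap the table in a minimal standalone document; the preview package
        # with the [active,tightpage] options crops the output PDF to the
        # tabular environment itself.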
        full_latex = r"""
\documentclass{article}
\usepackage[utf8]{inputenc}
\usepackage{booktabs}
\usepackage{bm}
\usepackage{multirow}
\usepackage{array}
\usepackage{colortbl}
\usepackage[table]{xcolor}
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{graphicx}
\usepackage{geometry}
\usepackage{makecell}
\usepackage[active,tightpage]{preview}
\PreviewEnvironment{tabular}
\begin{document}
""" + table_content + r"""
\end{document}
"""

        unique_id = str(uuid.uuid4())[:8]
        tex_path = os.path.join(temp_dir, f"table_{unique_id}.tex")
        pdf_path = os.path.join(temp_dir, f"table_{unique_id}.pdf")
        png_path = os.path.join(temp_dir, f"table_{unique_id}.png")

        with open(tex_path, "w", encoding="utf-8") as f:
            f.write(full_latex)
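
        # Compile with pdflatex (must be available on PATH, e.g. from a MacTeX
        # install); -interaction=nonstopmode stops LaTeX from waiting for user
        # input on errors, and the timeout bounds runaway compilations.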
        result = subprocess.run(
            ["pdflatex", "-interaction=nonstopmode", "-output-directory", temp_dir, tex_path],
            timeout=20,
            capture_output=True,
            text=True
        )

        if result.returncode != 0 or not os.path.exists(pdf_path):
            logger.warning("LaTeX compilation failed, returning original content")
            return f"<pre>{latex_content}</pre>"
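
        # Rasterize the first page at 300 DPI; pdf2image requires the poppler
        # utilities (pdftoppm) to be installed, e.g. via `brew install poppler`.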
        images = convert_from_path(pdf_path, dpi=300)
        images[0].save(png_path, "PNG")

        with open(png_path, "rb") as f:
            img_data = f.read()
        img_base64 = base64.b64encode(img_data).decode("utf-8")

        # Clean up the intermediate files
        for file_path in [tex_path, pdf_path, png_path]:
            if os.path.exists(file_path):
                os.remove(file_path)

        return f'<img src="data:image/png;base64,{img_base64}" style="max-width:100%;height:auto;">'

    except Exception as e:
        logger.warning(f"LaTeX rendering error: {e}")
        return f"<pre>{latex_content}</pre>"


def process_document(file_path: str) -> Tuple[str, str]:
    """Process document and return markdown content and layout PDF path"""
    if not file_path:
        return "", ""

    try:
        model = initialize_model()

        parent_path = os.path.dirname(file_path)
        full_name = os.path.basename(file_path)
        name = '.'.join(full_name.split(".")[:-1])

        # Output locations for the generated markdown and extracted images
        local_image_dir = os.path.join(parent_path, "markdown", "images")
        local_md_dir = os.path.join(parent_path, "markdown")
        os.makedirs(local_image_dir, exist_ok=True)
        os.makedirs(local_md_dir, exist_ok=True)

        image_dir = os.path.basename(local_image_dir)
        image_writer = FileBasedDataWriter(local_image_dir)
        md_writer = FileBasedDataWriter(local_md_dir)
        reader = FileBasedDataReader(parent_path)

        data_bytes = reader.read(full_name)

        # Images go through ImageDataset; everything else is treated as a PDF
        if full_name.split(".")[-1].lower() in ['jpg', 'jpeg', 'png']:
            ds = ImageDataset(data_bytes)
        else:
            ds = PymuDocDataset(data_bytes)

        logger.info("Processing document with MonkeyOCR...")

        import threading
        import time
        from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError

        def process_with_model():
            overall_start_time = time.time()

            analysis_start_time = time.time()
            logger.info("Starting document analysis...")
            infer_result = ds.apply(doc_analyze_llm, MonkeyOCR_model=model)
            logger.info(f"PROFILE: Document analysis (doc_analyze_llm) took {time.time() - analysis_start_time:.2f}s")

            ocr_start_time = time.time()
            logger.info("Starting OCR and layout processing...")
            pipe_result = infer_result.pipe_ocr_mode(image_writer, MonkeyOCR_model=model)
            logger.info(f"PROFILE: OCR/Layout (pipe_ocr_mode) took {time.time() - ocr_start_time:.2f}s")

            logger.info(f"PROFILE: Total model processing took {time.time() - overall_start_time:.2f}s")
            return infer_result, pipe_result
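
        # Run inference in a single worker thread so a wall-clock timeout can be
        # enforced; note that future.result(timeout=...) only stops waiting here,
        # the worker thread keeps running until the model call returns.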
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(process_with_model)
            try:
                infer_result, pipe_result = future.result(timeout=300)
            except FutureTimeoutError:
                logger.error("Processing timed out after 5 minutes")
                raise TimeoutError("Document processing timed out. Please try with a smaller document or simpler layout.")

        # Render the detected layout onto a PDF for inspection
        layout_pdf_path = os.path.join(parent_path, f"{name}_layout.pdf")
        pipe_result.draw_layout(layout_pdf_path)

        # Write the markdown output and read it back for post-processing
        pipe_result.dump_md(md_writer, f"{name}.md", image_dir)
        md_content_ori = FileBasedDataReader(local_md_dir).read(f"{name}.md").decode("utf-8")
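
        # Post-process the markdown for display in Gradio: render LaTeX tables
        # as images and inline referenced images as base64 data URIs.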
        temp_dir = tempfile.mkdtemp()
        try:
            def replace_html_latex_table(match):
                html_content = match.group(1)
                if '\\begin{tabular}' in html_content:
                    return render_latex_table_to_image(html_content, temp_dir)
                else:
                    return match.group(0)

            md_content = re.sub(r'<html>(.*?)</html>', replace_html_latex_table, md_content_ori, flags=re.DOTALL)

            def replace_image_with_base64(match):
                img_path = match.group(1)
                if not os.path.isabs(img_path):
                    full_img_path = os.path.join(local_md_dir, img_path)
                else:
                    full_img_path = img_path

                try:
                    if os.path.exists(full_img_path):
                        with open(full_img_path, "rb") as f:
                            img_data = f.read()
                        img_base64 = base64.b64encode(img_data).decode("utf-8")
                        ext = os.path.splitext(full_img_path)[1].lower()
                        mime_type = "image/jpeg" if ext in ['.jpg', '.jpeg'] else f"image/{ext[1:]}"
                        return f'<img src="data:{mime_type};base64,{img_base64}" style="max-width:100%;height:auto;">'
                    else:
                        return match.group(0)
                except Exception:
                    return match.group(0)

            md_content = re.sub(r'!\[.*?\]\(([^)]+)\)', replace_image_with_base64, md_content)

        finally:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir, ignore_errors=True)

        logger.info("Document processing completed successfully")
        return md_content, layout_pdf_path

    except Exception as e:
        logger.error(f"Error processing document: {e}")
        return f"Error processing document: {str(e)}", ""


def parse_document(file) -> Tuple[str, Optional[str]]:
    """Parse uploaded document and return results"""
    if file is None:
        return "Please upload a document first.", None

    try:
        # gr.File(type="filepath") passes a plain path string; fall back to
        # .name for file-like objects from older Gradio versions
        file_path = file if isinstance(file, str) else file.name
        markdown_content, layout_pdf_path = process_document(file_path)

        if not markdown_content:
            return "Failed to process document.", None

        return markdown_content, layout_pdf_path if os.path.exists(layout_pdf_path) else None

    except Exception as e:
        logger.error(f"Error in parse_document: {e}")
        return f"Error: {str(e)}", None


def create_gradio_interface():
    """Create and configure Gradio interface"""

    css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .markdown-content {
        max-height: 600px;
        overflow-y: auto;
        border: 1px solid #ddd;
        padding: 10px;
        border-radius: 5px;
    }
    """

    with gr.Blocks(
        title="MonkeyOCR 3B - Local MPS Demo",
        css=css,
        theme=gr.themes.Soft()
    ) as demo:

        gr.Markdown("""
        # MonkeyOCR 3B - Local Demo (Apple Silicon MPS)

        **Optimized for MacBook M4 Pro with 48GB RAM**

        Upload a PDF or image document to extract structured content with state-of-the-art accuracy.
        The model runs locally using Apple's Metal Performance Shaders for GPU acceleration.

        **Supported formats:** PDF, PNG, JPG, JPEG
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Document",
                    file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                    type="filepath"
                )

                parse_btn = gr.Button(
                    "Parse Document",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown("""
                **Tips:**
                - Larger documents may take a few minutes to process
                - The model excels at formulas, tables, and complex layouts
                - Processing speed: ~0.84 pages/second on M4 Pro
                """)

            with gr.Column(scale=2):
                markdown_output = gr.Markdown(
                    label="Extracted Content",
                    elem_classes=["markdown-content"]
                )

                layout_pdf_output = gr.File(
                    label="Layout Analysis (PDF)",
                    visible=False
                )

        parse_btn.click(
            fn=parse_document,
            inputs=[file_input],
            outputs=[markdown_output, layout_pdf_output],
            show_progress=True
        )
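
        # Reveal the layout-PDF download once parse_document returns a valid
        # path, and keep the component hidden otherwise.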
        def show_layout_pdf(pdf_path):
            if pdf_path:
                return gr.update(visible=True, value=pdf_path)
            return gr.update(visible=False)

        layout_pdf_output.change(
            fn=show_layout_pdf,
            inputs=[layout_pdf_output],
            outputs=[layout_pdf_output]
        )

    return demo


def main():
    """Main function to run the Gradio app"""
    logger.info("Starting MonkeyOCR 3B Gradio App...")

    if not torch.backends.mps.is_available():
        logger.warning("MPS not available. The app will run on CPU, which may be slower.")
    else:
        logger.info("MPS is available. GPU acceleration enabled.")

    demo = create_gradio_interface()
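
    # server_name="0.0.0.0" exposes the app on the local network; change it to
    # "127.0.0.1" to restrict access to this machine only.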
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,
        share=False,
        show_error=True,
        quiet=False
    )


if __name__ == "__main__":
    main()