trevorpfiz
fix: unexpected keyword argument 'file_name'
4aa9a45
import gc
import hashlib
import json
import math
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import fitz # PyMuPDF
import gradio as gr
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor
from .utils.constants import IMAGE_FACTOR, MAX_PIXELS, MIN_PIXELS
from .utils.prompts import dict_promptmode_to_prompt
# ============================
# Constants and configuration
# ============================
APP_TITLE = "PreviewSpace — VLM Playground"
# Scratch space for downloaded model weights and generated export files.
TMP_DIR = "/tmp/previewspace"
MODELS_DIR = os.path.join(TMP_DIR, "models")
DOTS_REPO_ID = "rednote-hilab/dots.ocr"
DOTS_LOCAL_DIR = os.path.join(MODELS_DIR, "dots.ocr")

# The fallback prompt is only used when the prompt table lacks
# "prompt_layout_all_en". Its category names are the ones this module's
# renderers understand (the colour map in draw_layout_on_image and the
# heading/list/table handling in layoutjson2md), so a fallback run still gets
# coloured boxes and structured Markdown instead of plain black boxes/text.
DEFAULT_PROMPT = dict_promptmode_to_prompt.get(
    "prompt_layout_all_en",
    (
        "Please output the layout information from the PDF page image. For each element, return: "
        "bbox: [x1, y1, x2, y2], category from "
        '{"Caption","Footnote","Formula","List-item","Page-footer","Page-header",'
        '"Picture","Section-header","Table","Text","Title"}, and text. '
        'Return JSON: {"elements": [{"bbox": [..], "category": "..", "text": ".."}], "page": <number>}'
    ),
)

# MODELS_DIR is nested inside TMP_DIR; both calls are idempotent.
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)
# ===========
# Utilities
# ===========
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of ``factor`` closest to ``number``.

    Ties follow the built-in ``round`` applied to ``number / factor``
    (round-half-to-even). ``number`` may also be a float, e.g. a scaled size.
    """
    quotient = number / factor
    return factor * round(quotient)
def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Tuple[int, int]:
    """Compute a model-friendly ``(height, width)`` close to the input size.

    Both returned sides are multiples of ``factor`` (and at least ``factor``),
    and the aspect ratio is approximately preserved. If the nearest-multiple
    size is too large, the image is scaled down and each side is rounded
    *down*, so the area does not exceed ``max_pixels``. If it is too small,
    the image is scaled up and each side is rounded *up*, so the area is at
    least ``min_pixels``. The only exception is when a side has to be clamped
    up to ``factor`` because the aspect ratio is extreme.

    Raises:
        ValueError: if a side is not positive, or the aspect ratio exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got {height}x{width}"
        )
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Floor (not round): rounding up here could push the area back
        # above max_pixels.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        # Ceil (not round): rounding down here could leave the area below
        # min_pixels.
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return int(h_bar), int(w_bar)
def fetch_image(
    image_input: Any,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> Image.Image:
    """Load an image as RGB and optionally resize it with ``smart_resize``.

    Args:
        image_input: an ``http://``/``https://`` URL, a local file path, or a
            PIL image.
        min_pixels: lower area bound for resizing.
        max_pixels: upper area bound for resizing. When either bound is
            given, the image is resized. A missing (or falsy) bound falls
            back to ``MIN_PIXELS`` / ``MAX_PIXELS``. When both are None, the
            image keeps its size.

    Returns:
        A new RGB ``PIL.Image.Image``; a PIL input is never modified.

    Raises:
        ValueError: for unsupported input types.
        requests.HTTPError: when a URL answers with an error status.
    """
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=60)
            # Surface the HTTP error instead of a confusing "cannot identify
            # image file" raised while decoding an error page body.
            response.raise_for_status()
            source: Any = BytesIO(response.content)
        else:
            source = image_input
        # The context manager closes the underlying file handle; convert()
        # returns an independent, fully loaded copy.
        with Image.open(source) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input.convert("RGB")
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")
    if min_pixels is not None or max_pixels is not None:
        min_pixels = min_pixels or MIN_PIXELS
        max_pixels = max_pixels or MAX_PIXELS
        new_h, new_w = smart_resize(
            image.height,
            image.width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
        image = image.resize((new_w, new_h), Image.LANCZOS)
    return image
def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    """Rasterize every page of the PDF at ``pdf_path`` into an RGB PIL image.

    Pages are rendered with a 2x zoom matrix on both axes and returned in
    page order. The document is closed when rendering ends, including on
    error.
    """
    zoom = fitz.Matrix(2.0, 2.0)
    rendered: List[Image.Image] = []
    with fitz.open(pdf_path) as document:
        for page in document:
            ppm_bytes = page.get_pixmap(matrix=zoom).tobytes("ppm")
            rendered.append(Image.open(BytesIO(ppm_bytes)).convert("RGB"))
    return rendered
def file_checksum(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is streamed in ``chunk_size``-byte reads (1 MiB by default), so
    large uploads are never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def _load_label_font() -> Any:
    """Return a 12 px bold TrueType label font, or Pillow's default font.

    Tries a macOS-style Arial Bold path, then a Linux DejaVu Sans Bold path.
    """
    for font_path in (
        "/System/Library/Fonts/Supplemental/Arial Bold.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    ):
        try:
            return ImageFont.truetype(font_path, 12)
        except (OSError, ImportError):
            continue
    return ImageFont.load_default()


def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    """Return a copy of ``image`` with each layout element outlined and labelled.

    Each element is a dict with ``bbox`` (``[x1, y1, x2, y2]``, assumed to be
    in ``image``'s pixel coordinates; TODO confirm against the model output)
    and ``category``. The outline colour comes from the category (black when
    unknown). The category name is drawn on a filled tag just above the box's
    top-left corner.

    Elements that are not dicts, or lack a bbox or category, are skipped. An
    element that cannot be drawn (e.g. a wrong-length, non-numeric, or
    inverted bbox) is skipped on its own, so one malformed box no longer
    hides every box after it. ``image`` itself is not modified.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    colors = {
        "Caption": "#FF6B6B",
        "Footnote": "#4ECDC4",
        "Formula": "#45B7D1",
        "List-item": "#96CEB4",
        "Page-footer": "#FFEAA7",
        "Page-header": "#DDA0DD",
        "Picture": "#FFD93D",
        "Section-header": "#6C5CE7",
        "Table": "#FD79A8",
        "Text": "#74B9FF",
        "Title": "#E17055",
    }
    font = _load_label_font()
    for item in layout_data or []:
        if not isinstance(item, dict):
            continue
        bbox = item.get("bbox")
        category = item.get("category")
        if not bbox or not category:
            continue
        try:
            color = colors.get(category, "#000000")
            draw.rectangle(bbox, outline=color, width=2)
            label = str(category)
            left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
            label_w = right - left
            label_h = bottom - top
            lx = int(bbox[0])
            # Put the tag above the box, clamped to the image's top edge.
            ly = max(0, int(bbox[1]) - label_h - 2)
            draw.rectangle([lx, ly, lx + label_w + 4, ly + label_h + 2], fill=color)
            draw.text((lx + 2, ly + 1), label, fill="white", font=font)
        except (TypeError, ValueError, IndexError):
            # Malformed element: skip it but keep drawing the others.
            continue
    return img
def is_arabic_text(text: str) -> bool:
    """Heuristically decide whether Markdown ``text`` is mostly Arabic script.

    Only heading text (after the ``#`` marks) and plain paragraph lines are
    inspected. Images, code fences, table rows, and bulleted/numbered list
    items are ignored. The result is True when more than half of the
    alphabetic characters in the inspected lines fall in the Arabic
    (U+0600-U+06FF), Arabic Supplement (U+0750-U+077F), or Arabic
    Extended-A (U+08A0-U+08FF) blocks. Empty input, or input with no
    qualifying alphabetic characters, gives False.
    """
    if not text:
        return False

    heading_re = re.compile(r"^#{1,6}\s+(.+)$")
    prose_re = re.compile(r"^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$")

    kept: List[str] = []
    for raw_line in text.split("\n"):
        stripped = raw_line.strip()
        if not stripped:
            continue
        heading = heading_re.match(stripped)
        if heading is not None:
            kept.append(heading.group(1))
        elif prose_re.match(stripped) is not None:
            kept.append(stripped)

    if not kept:
        return False

    letters = [ch for ch in " ".join(kept) if ch.isalpha()]
    if not letters:
        return False
    arabic_count = sum(
        1
        for ch in letters
        if "\u0600" <= ch <= "\u06ff"
        or "\u0750" <= ch <= "\u077f"
        or "\u08a0" <= ch <= "\u08ff"
    )
    return arabic_count / len(letters) > 0.5
def extract_json(text: str) -> Any:
    """Extract a JSON object or array from model output, on a best-effort basis.

    Candidates are tried in this order, and the first one that decodes to a
    ``dict`` or ``list`` is returned:

    1. the whole text;
    2. the span from the first ``{`` to the last ``}``;
    3. the span from the first ``[`` to the last ``]``. Layout prompts may
       answer with a bare array of elements wrapped in prose, which step 2
       cannot recover;
    4. each ```` ```json ```` fenced block.

    Returns None when nothing decodes to an object or array. A bare scalar
    such as ``42`` does not count as a successful extraction.
    """
    if not text:
        return None

    candidates: List[str] = [text]
    for opener, closer in (("{", "}"), ("[", "]")):
        start = text.find(opener)
        end = text.rfind(closer)
        if 0 <= start < end:
            candidates.append(text[start : end + 1])
    candidates.extend(re.findall(r"```json\s*([\s\S]*?)\s*```", text))

    for candidate in candidates:
        try:
            value = json.loads(candidate)
        except (ValueError, RecursionError):
            # json.JSONDecodeError subclasses ValueError; RecursionError
            # covers pathologically nested input.
            continue
        if isinstance(value, (dict, list)):
            return value
    return None
def layoutjson2md(
    image: Image.Image, layout_data: List[Dict], text_key: str = "text"
) -> str:
    """Render layout elements as Markdown in top-to-bottom, left-to-right order.

    Elements are sorted by the top-left corner of their ``bbox`` (y, then x).
    Elements without a usable bbox sort as if at (0, 0). Rendering depends on
    ``category``:

    - ``Title`` becomes ``#`` and ``Section-header`` becomes ``##``.
    - ``List-item`` becomes a ``-`` bullet.
    - ``Table`` text starting with ``<`` (presumably HTML) is emitted
      verbatim; other table text becomes a ``**Table:**`` line.
    - ``Formula`` text starting with ``$`` or containing a backslash becomes
      a ``$$`` block; other formulas become a ``**Formula:**`` line.
    - ``Caption`` is italicised.
    - ``Page-header``, ``Page-footer`` and ``Picture`` are omitted.
    - Any other element with text becomes a paragraph.

    The text is read from ``text_key``. Non-dict entries are skipped.
    Non-string text is converted with ``str``, and ``None`` counts as empty.

    ``image`` is currently unused and is accepted for call-site compatibility.

    If rendering still fails unexpectedly, the layout data is returned as a
    JSON string, so the caller always gets displayable text.
    """

    def _position_key(item: Dict) -> Tuple[float, float]:
        # (top, left) of the bbox; (0, 0) when it is missing or malformed.
        bbox = item.get("bbox")
        try:
            return float(bbox[1]), float(bbox[0])
        except (TypeError, ValueError, IndexError, KeyError):
            return 0.0, 0.0

    lines: List[str] = []
    try:
        items = sorted(
            (item for item in layout_data if isinstance(item, dict)),
            key=_position_key,
        )
        for item in items:
            category = item.get("category", "")
            raw_text = item.get(text_key, "")
            text = "" if raw_text is None else str(raw_text)
            if category == "Title" and text:
                lines.append(f"# {text}\n")
            elif category == "Section-header" and text:
                lines.append(f"## {text}\n")
            elif category == "List-item" and text:
                lines.append(f"- {text}\n")
            elif category == "Table" and text:
                if text.strip().startswith("<"):
                    lines.append(text + "\n")
                else:
                    lines.append(f"**Table:** {text}\n")
            elif category == "Formula" and text:
                if text.strip().startswith("$") or "\\" in text:
                    lines.append(f"$$\n{text}\n$$\n")
                else:
                    lines.append(f"**Formula:** {text}\n")
            elif category == "Caption" and text:
                lines.append(f"*{text}*\n")
            elif category in ["Page-header", "Page-footer"]:
                continue
            elif category == "Picture":
                # Skip embedding image fragments in markdown for now
                continue
            elif text:
                lines.append(f"{text}\n")
            lines.append("")
    except Exception:
        # Deliberate best-effort boundary: show the raw layout rather than fail.
        return json.dumps(layout_data, ensure_ascii=False)
    return "\n".join(lines)
# =====================
# Model initialization
# =====================
# Lazily populated by ensure_model_loaded(); None until the first inference.
model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None

# Preferred accelerator: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
def get_torch_dtype() -> torch.dtype:
    """Return the weight dtype for the module-level ``device``.

    bfloat16 on ``"cuda"``, float16 on ``"mps"``, float32 for anything else.
    """
    per_device = {"cuda": torch.bfloat16, "mps": torch.float16}
    return per_device.get(device, torch.float32)
def ensure_model_loaded() -> Tuple[AutoModelForCausalLM, AutoProcessor]:
    """Download (if needed) and load the dots.ocr model and processor once.

    The repository ``DOTS_REPO_ID`` is snapshotted into ``DOTS_LOCAL_DIR``.
    Both objects are then loaded from that directory with
    ``trust_remote_code=True``, which executes Python code shipped in the
    downloaded repository. The pair is cached in the module-level ``model``
    and ``processor`` globals, so later calls return immediately.

    The globals are only assigned after *both* loads succeed. A failure while
    loading the processor therefore no longer leaves a half-initialised cache
    that would trigger a second full model load on the next call.
    """
    global model, processor
    if model is not None and processor is not None:
        return model, processor

    os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
    # `local_dir_use_symlinks` is deprecated (and ignored) in recent
    # huggingface_hub releases, where files in `local_dir` are always real
    # files, so it is no longer passed.
    snapshot_download(repo_id=DOTS_REPO_ID, local_dir=DOTS_LOCAL_DIR)

    loaded_model = AutoModelForCausalLM.from_pretrained(
        DOTS_LOCAL_DIR,
        torch_dtype=get_torch_dtype(),
        device_map="auto",
        trust_remote_code=True,
    )
    loaded_processor = AutoProcessor.from_pretrained(
        DOTS_LOCAL_DIR, trust_remote_code=True
    )
    model, processor = loaded_model, loaded_processor
    return model, processor
def run_inference(
    image: Image.Image, prompt_text: str, max_new_tokens: int = 24000
) -> str:
    """Run one greedy generation of the VLM on a single image and prompt.

    Args:
        image: the page image, already resized by the caller if desired.
        prompt_text: the instruction sent alongside the image.
        max_new_tokens: generation budget (coerced to ``int``).

    Returns:
        The decoded completion with the prompt tokens removed, or ``""`` if
        decoding produced nothing.
    """
    mdl, proc = ensure_model_loaded()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text = proc.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = proc(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # The model was loaded with device_map="auto", so send inputs to wherever
    # its parameters actually live, rather than assuming the module `device`.
    target_device = mdl.device
    inputs = {
        k: v.to(target_device) if hasattr(v, "to") else v for k, v in inputs.items()
    }
    with torch.no_grad():
        # Greedy decoding; `temperature` is ignored when do_sample=False and
        # only produced a warning, so it is not passed.
        generated_ids = mdl.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )
    # Drop the echoed prompt tokens from each generated sequence.
    trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_text = proc.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0] if output_text else ""
def process_single_image(
    image: Image.Image,
    prompt_text: str,
    min_pixels: Optional[int],
    max_pixels: Optional[int],
    max_new_tokens: int,
) -> Dict[str, Any]:
    """Run the model on one page and build every view the UI displays.

    Returns a dict with these keys:
        original_image: the (optionally resized) RGB image sent to the model.
        processed_image: that image with layout boxes drawn. It is the same
            image when no layout elements could be parsed.
        raw_output: the model's raw text.
        layout_result: the parsed JSON (an object or an array), or None.
        markdown: Markdown rendered from the layout elements. It falls back
            to the raw output when nothing was parsed or the render is empty.

    Layout elements are taken from a JSON object's ``elements``,
    ``elements_list`` or ``content`` key (first present wins), or directly
    from a JSON array, since layout prompts may answer with a bare array of
    elements. Non-dict entries are ignored.
    """
    img = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
    raw = run_inference(img, prompt_text, max_new_tokens=max_new_tokens)
    result: Dict[str, Any] = {
        "original_image": img,
        "processed_image": img,
        "raw_output": raw,
        "layout_result": None,
        "markdown": None,
    }
    data = extract_json(raw)
    items: Any = None
    if isinstance(data, dict):
        result["layout_result"] = data
        items = data.get("elements", data.get("elements_list", data.get("content", [])))
    elif isinstance(data, list):
        result["layout_result"] = data
        items = data
    if isinstance(items, list):
        elements = [item for item in items if isinstance(item, dict)]
        result["processed_image"] = draw_layout_on_image(img, elements)
        result["markdown"] = layoutjson2md(img, elements)
    if not result["markdown"]:
        # Nothing usable was rendered: show the raw model output instead of
        # an empty pane.
        result["markdown"] = raw
    return result
# =================
# Gradio Interface
# =================
_NO_FILE_HTML = '<div class="page-info">No file</div>'


def _empty_doc_state() -> Dict[str, Any]:
    """Return a fresh document state with no file loaded."""
    return {
        "images": [],
        "current_page": 0,
        "total_pages": 0,
        "file_type": None,
        "checksum": None,
        "results": [],
        "parsed": False,
    }


def _page_info_html(index: int, total: int) -> str:
    """Return the "Page i / n" badge for the zero-based page ``index``."""
    return f'<div class="page-info">Page {index + 1} / {total}</div>'


def _markdown_update(md: Any) -> Any:
    """Build a Markdown update that enables right-to-left layout for Arabic text.

    ``rtl`` is always set explicitly, so a non-Arabic page shown after an
    Arabic one switches back to left-to-right.
    """
    return gr.update(value=md, rtl=is_arabic_text(md))


def _parse_page_selection(
    selection: Optional[str], current: int, total: int
) -> List[int]:
    """Turn a page-selection string into sorted, unique zero-based page indices.

    Blank input selects the current page, and "all" (case-insensitive)
    selects every page. Otherwise the string is a comma-separated list of
    1-based pages and inclusive ranges, e.g. "1-3,5". Numbers are clamped to
    the document, reversed ranges such as "5-3" are accepted, and parts that
    do not parse are ignored.
    """
    if not selection or not selection.strip():
        return [current]
    if selection.strip().lower() == "all":
        return list(range(total))
    picked = set()
    for part in (p.strip() for p in selection.split(",")):
        if not part:
            continue
        try:
            if "-" in part:
                lo_text, hi_text = part.split("-", 1)
                lo, hi = sorted((int(lo_text), int(hi_text)))
                picked.update(range(max(1, lo) - 1, min(total, hi)))
            else:
                picked.add(max(1, min(total, int(part))) - 1)
        except ValueError:
            continue
    return sorted(i for i in picked if 0 <= i < total)


def _write_download(content: str, filename: str) -> str:
    """Write ``content`` to ``filename`` under TMP_DIR and return the path.

    The file is placed in a sub-directory named after the content's SHA-256.
    Concurrent sessions therefore never overwrite each other's exports, and
    the downloaded file keeps the plain ``filename``.
    """
    digest = hashlib.sha256(content.encode("utf-8")).hexdigest()[:16]
    out_dir = os.path.join(TMP_DIR, "downloads", digest)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, filename)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(content)
    return out_path


def create_blocks_app():
    """Build (but do not launch) the PreviewSpace Gradio Blocks UI.

    The layout has three columns:
    - left: upload, prompt template/text, Parse/Clear, and advanced
      generation settings;
    - centre: page preview with prev/next/jump navigation;
    - right: result tabs (rendered/raw Markdown, page JSON, annotated image)
      and download buttons.

    Per-session state:
        doc_state: the loaded document (page images, checksum, per-page
            results, current page).
        cache_state: parse results keyed by (checksum, page index, prompt
            hash, min pixels, max pixels, max new tokens), so re-parsing an
            unchanged page with unchanged settings skips inference. It is
            threaded through event inputs/outputs rather than mutated via
            ``cache_state.value``. That attribute is the component's initial
            value, one object shared by every session, so mutating it made the
            cache global and unbounded.

    Returns:
        The ``gr.Blocks`` app.
    """
    css = """
    .main-container { max-width: 1500px; margin: 0 auto; }
    .header-text { text-align: center; color: #1f2937; margin-bottom: 12px; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: 600; }
    .process-button { border: none !important; color: white !important; font-weight: 700 !important; }
    """
    with gr.Blocks(theme=gr.themes.Soft(), css=css, title=APP_TITLE) as demo:
        # App state
        doc_state = gr.State(_empty_doc_state())
        cache_state = gr.State({})  # per-session: cache key -> result dict
        gr.HTML(
            """
            <div class=\"header-text\">
            <h2>VLM Playground — dots.ocr</h2>
            <p>Upload a PDF or image, preview pages, and parse with a layout-extraction prompt.</p>
            </div>
            """
        )
        with gr.Row(elem_classes=["main-container"]):
            # Left: upload + controls
            with gr.Column(scale=4):
                file_input = gr.File(
                    label="Upload PDF or Image",
                    file_types=[
                        ".pdf",
                        ".png",
                        ".jpg",
                        ".jpeg",
                        ".bmp",
                        ".tiff",
                        ".webp",
                    ],
                    type="filepath",
                )
                with gr.Group():
                    template = gr.Dropdown(
                        label="Prompt Template",
                        choices=["Layout Extraction"],
                        value="Layout Extraction",
                    )
                    prompt_text = gr.Textbox(
                        label="Current Prompt",
                        value=DEFAULT_PROMPT,
                        lines=6,
                    )
                with gr.Row():
                    parse_button = gr.Button(
                        "Parse", variant="primary", elem_classes=["process-button"]
                    )
                    clear_button = gr.Button("Clear")
                with gr.Accordion("Advanced", open=False):
                    max_new_tokens = gr.Slider(
                        minimum=512,
                        maximum=32000,
                        value=24000,
                        step=256,
                        label="Max new tokens",
                    )
                    min_pixels_in = gr.Number(value=MIN_PIXELS, label="Min pixels")
                    max_pixels_in = gr.Number(value=MAX_PIXELS, label="Max pixels")
                    page_range = gr.Textbox(
                        label="Page selection",
                        placeholder="e.g., 1-3,5 (blank = current page, 'all' = all pages)",
                    )
            # Center: page preview + nav
            with gr.Column(scale=5):
                preview_image = gr.Image(label="Page Preview", type="pil", height=520)
                with gr.Row():
                    prev_btn = gr.Button("◀ Prev")
                    page_info = gr.HTML(_NO_FILE_HTML)
                    next_btn = gr.Button("Next ▶")
                with gr.Row():
                    page_jump = gr.Number(value=1, label="Page #", precision=0)
                    jump_btn = gr.Button("Go")
            # Right: results
            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.Tab("Markdown Render"):
                        md_render = gr.Markdown(
                            value="Upload and parse to view results", height=520
                        )
                    with gr.Tab("Raw Markdown"):
                        md_raw = gr.Textbox(value="", lines=20)
                    with gr.Tab("Current Page JSON"):
                        json_view = gr.JSON(value=None)
                    with gr.Tab("Processed Image"):
                        processed_view = gr.Image(type="pil", height=520)
                with gr.Row():
                    download_jsonl = gr.DownloadButton(label="Download JSONL")
                    download_markdown = gr.DownloadButton(label="Download Markdown")

        # ===== Handlers =====
        def render_current(state: Dict[str, Any], missing_md: str):
            """Return the 7 page-view outputs for ``state['current_page']``."""
            idx = state["current_page"]
            results = state.get("results") or []
            result = (
                results[idx] if state.get("parsed") and idx < len(results) else None
            )
            md = result.get("markdown") if result else missing_md
            return (
                state,
                state["images"][idx],
                _page_info_html(idx, state["total_pages"]),
                _markdown_update(md),
                md,
                result.get("processed_image") if result else None,
                result.get("layout_result") if result else None,
            )

        def on_template_change(choice: str) -> str:
            return DEFAULT_PROMPT

        def on_file_change(path: Optional[str]):
            """Load the uploaded file and reset every result view and export."""
            if not path or not os.path.exists(path):
                return (
                    _empty_doc_state(),
                    None,
                    _NO_FILE_HTML,
                    _markdown_update("Upload and parse to view results"),
                    "",
                    None,
                    None,
                    None,
                    None,
                )
            checksum = file_checksum(path)
            if os.path.splitext(path)[1].lower() == ".pdf":
                images = load_images_from_pdf(path)
                file_type = "pdf"
            else:
                # Close the file handle; convert() yields an in-memory copy.
                with Image.open(path) as opened:
                    images = [opened.convert("RGB")]
                file_type = "image"
            state = {
                "images": images,
                "current_page": 0,
                "total_pages": len(images),
                "file_type": file_type,
                "checksum": checksum,
                "results": [None] * len(images),
                "parsed": False,
            }
            return (
                state,
                images[0] if images else None,
                _page_info_html(0, len(images)) if images else _NO_FILE_HTML,
                _markdown_update("Page not processed yet"),
                "",
                None,
                None,
                None,
                None,
            )

        def nav_page(state: Dict[str, Any], direction: str):
            """Move to the previous/next page ("stay" keeps the current one)."""
            if not state.get("images"):
                return (
                    state,
                    None,
                    _NO_FILE_HTML,
                    _markdown_update("No results"),
                    "",
                    None,
                    None,
                )
            if direction == "prev":
                state["current_page"] = max(0, state["current_page"] - 1)
            elif direction == "next":
                state["current_page"] = min(
                    state["total_pages"] - 1, state["current_page"] + 1
                )
            return render_current(state, "Page not processed yet")

        def jump_to_page(state: Dict[str, Any], page_num: Any):
            """Jump to the 1-based ``page_num`` (clamped; invalid input -> 1)."""
            if not state.get("images"):
                return nav_page(state, direction="stay")
            try:
                n = int(page_num)
            except (TypeError, ValueError):
                n = 1
            n = max(1, min(state["total_pages"], n))
            state["current_page"] = n - 1
            return nav_page(state, direction="stay")

        def parse_pages(
            state: Dict[str, Any],
            cache: Dict[Any, Any],
            prompt: str,
            max_tokens: int,
            min_pix: Optional[float],
            max_pix: Optional[float],
            selection: Optional[str],
        ):
            """Parse the selected pages (reusing cached results) and show the current one."""
            if not state.get("images"):
                return (
                    state,
                    cache,
                    None,
                    _NO_FILE_HTML,
                    _markdown_update("No content"),
                    "",
                    None,
                    None,
                )
            total = state["total_pages"]
            indices = _parse_page_selection(selection, state["current_page"], total)
            prompt = prompt or ""
            # Loop-invariant parts of the cache key.
            prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
            min_pixels = int(min_pix) if min_pix else None
            max_pixels = int(max_pix) if max_pix else None
            token_budget = int(max_tokens)
            # Process sequentially for stability
            results = state.get("results") or [None] * total
            for i in indices:
                cache_key = (
                    state["checksum"],
                    i,
                    prompt_hash,
                    min_pixels or 0,
                    max_pixels or 0,
                    token_budget,
                )
                cached = cache.get(cache_key)
                if cached is None:
                    cached = process_single_image(
                        state["images"][i],
                        prompt_text=prompt,
                        min_pixels=min_pixels,
                        max_pixels=max_pixels,
                        max_new_tokens=token_budget,
                    )
                    cache[cache_key] = cached
                results[i] = cached
            state["results"] = results
            state["parsed"] = True
            page = render_current(state, "No content")
            return (page[0], cache, *page[1:])

        def export_downloads(state: Dict[str, Any]):
            """Write the JSONL and Markdown exports; (None, None) before parsing."""
            if not state.get("parsed"):
                return None, None
            results = state.get("results") or []
            jsonl_lines = [
                json.dumps({"page": i + 1, "layout": res["layout_result"]}, ensure_ascii=False)
                for i, res in enumerate(results)
                if res and res.get("layout_result") is not None
            ]
            md_chunks = [
                f"## Page {i + 1}\n\n{res['markdown']}"
                for i, res in enumerate(results)
                if res and res.get("markdown")
            ]
            return (
                _write_download("\n".join(jsonl_lines), "results.jsonl"),
                _write_download("\n\n---\n\n".join(md_chunks), "results.md"),
            )

        def clear_all():
            """Reset the document, the session cache, the upload and the exports."""
            gc.collect()
            return (
                _empty_doc_state(),
                None,
                _NO_FILE_HTML,
                _markdown_update("Upload and parse to view results"),
                "",
                None,
                None,
                {},
                None,
                None,
                None,
            )

        # Wire events
        page_outputs = [
            doc_state,
            preview_image,
            page_info,
            md_render,
            md_raw,
            processed_view,
            json_view,
        ]
        download_outputs = [download_jsonl, download_markdown]

        template.change(on_template_change, inputs=[template], outputs=[prompt_text])
        file_input.change(
            on_file_change,
            inputs=[file_input],
            outputs=page_outputs + download_outputs,
        )
        prev_btn.click(
            lambda s: nav_page(s, "prev"), inputs=[doc_state], outputs=page_outputs
        )
        next_btn.click(
            lambda s: nav_page(s, "next"), inputs=[doc_state], outputs=page_outputs
        )
        jump_btn.click(
            jump_to_page, inputs=[doc_state, page_jump], outputs=page_outputs
        )
        # DownloadButton downloads its *current* value when clicked, so the
        # export files are (re)built right after every parse.
        parse_button.click(
            parse_pages,
            inputs=[
                doc_state,
                cache_state,
                prompt_text,
                max_new_tokens,
                min_pixels_in,
                max_pixels_in,
                page_range,
            ],
            outputs=[doc_state, cache_state] + page_outputs[1:],
        ).then(export_downloads, inputs=[doc_state], outputs=download_outputs)
        clear_button.click(
            clear_all,
            outputs=page_outputs + [cache_state, file_input] + download_outputs,
        )
    return demo