VLM-playground/src/vlm_playground/preview_app_local.py
trevorpfiz
fix: unexpected keyword argument 'file_name'
4aa9a45
import gc
import types
import sys
import hashlib
import json
import math
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import fitz # PyMuPDF
import gradio as gr
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor
from .utils.constants import IMAGE_FACTOR, MAX_PIXELS, MIN_PIXELS
from .utils.prompts import dict_promptmode_to_prompt
APP_TITLE = "PreviewSpace — VLM Playground (Local)"

# Scratch area for downloaded weights and exported results.
TMP_DIR = "/tmp/previewspace"
MODELS_DIR = os.path.join(TMP_DIR, "models")

# Hugging Face repo of the dots.ocr checkpoint and the local directory it is mirrored into.
DOTS_REPO_ID = "rednote-hilab/dots.ocr"
DOTS_LOCAL_DIR = os.path.join(MODELS_DIR, "dots.ocr")

# Default generation budget used by the local inference path and the UI slider.
LOCAL_DEFAULT_MAX_NEW_TOKENS = 2048

# Make sure the scratch directories exist as soon as the module is imported.
for _required_dir in (TMP_DIR, MODELS_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of *factor* closest to *number*.

    Uses Python's built-in ``round`` on ``number / factor``, so exact halves
    round to the even multiple count (e.g. 42 with factor 28 -> 56).
    """
    multiples = round(number / factor)
    return multiples * factor
def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Tuple[int, int]:
    """Choose a (height, width) near the input whose sides are multiples of *factor*.

    The aspect ratio is approximately preserved. If the size rounded to the nearest
    multiple of *factor* has more than *max_pixels* pixels, the image is scaled down
    and each side is rounded *down* to a multiple of *factor* (never below one
    factor). If it has fewer than *min_pixels*, the image is scaled up and each side
    is rounded *up*. Rounding to the nearest multiple in those branches (as before)
    could push the area back across the bound being enforced.

    Args:
        height: Source height in pixels; must be positive.
        width: Source width in pixels; must be positive.
        factor: Granularity each output side must be a multiple of.
        min_pixels: Lower bound on the output area.
        max_pixels: Upper bound on the output area.

    Returns:
        ``(new_height, new_width)`` as ints.

    Raises:
        ValueError: if a side is not positive or the aspect ratio exceeds 200.
    """
    if height <= 0 or width <= 0:
        # Guard explicitly; otherwise the ratio check below divides by zero.
        raise ValueError(f"height and width must be positive, got {height}x{width}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Floor so rounding cannot overshoot max_pixels again.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        # Ceil so rounding cannot fall short of min_pixels again.
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return int(h_bar), int(w_bar)
def fetch_image(image_input: Any) -> Image.Image:
    """Load *image_input* and return it as an RGB PIL image.

    Args:
        image_input: An ``http://``/``https://`` URL, a local file path, or an
            existing ``PIL.Image.Image``.

    Returns:
        A new RGB image.

    Raises:
        ValueError: if *image_input* is of any other type.
        requests.HTTPError: if a URL answers with an error status.
    """
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")
    if not isinstance(image_input, str):
        raise ValueError(f"Invalid image input type: {type(image_input)}")
    if image_input.startswith(("http://", "https://")):
        response = requests.get(image_input, timeout=60)
        # Fail clearly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        source: Any = BytesIO(response.content)
    else:
        source = image_input
    # Context manager releases the file handle PIL may keep for lazily-read images.
    with Image.open(source) as opened:
        return opened.convert("RGB")
def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    """Render every page of the PDF at *pdf_path* to an RGB PIL image.

    Pages are rasterised with a 2x zoom matrix (PPM round-trip into PIL), in
    document order. The document is closed before returning.
    """
    zoom = fitz.Matrix(2.0, 2.0)
    with fitz.open(pdf_path) as document:
        return [
            Image.open(BytesIO(page.get_pixmap(matrix=zoom).tobytes("ppm"))).convert("RGB")
            for page in document
        ]
def file_checksum(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of the file at *path*.

    The file is streamed in *chunk_size*-byte blocks (1 MiB by default), so large
    uploads are never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def _load_label_font(size: int = 12) -> Any:
    """Return a bold TrueType font for overlay labels, or PIL's default font if none loads."""
    for font_path in (
        "/System/Library/Fonts/Supplemental/Arial Bold.ttf",  # typical macOS location
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",  # typical Debian/Ubuntu location
    ):
        try:
            return ImageFont.truetype(font_path, size)
        except (OSError, ImportError):
            continue
    return ImageFont.load_default()


def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    """Return a copy of *image* with each layout element's box and category drawn on it.

    Each element should be a dict with a ``bbox`` (``[x1, y1, x2, y2]``) and a
    ``category``. The box is outlined in a per-category colour (black for unknown
    categories), and the category name is drawn in a filled tag just above the
    box's top-left corner.

    Drawing is best-effort per element. Entries that are not dicts, lack a bbox or
    category, do not carry four numeric coordinates, or fail to draw are skipped
    individually rather than aborting the rest of the overlay. *image* itself is
    not modified.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    colors = {
        "Caption": "#FF6B6B",
        "Footnote": "#4ECDC4",
        "Formula": "#45B7D1",
        "List-item": "#96CEB4",
        "Page-footer": "#FFEAA7",
        "Page-header": "#DDA0DD",
        "Picture": "#FFD93D",
        "Section-header": "#6C5CE7",
        "Table": "#FD79A8",
        "Text": "#74B9FF",
        "Title": "#E17055",
    }
    font = _load_label_font(12)
    for item in layout_data or []:
        if not isinstance(item, dict):
            continue
        bbox = item.get("bbox")
        category = item.get("category")
        if not bbox or not category or not isinstance(bbox, (list, tuple)):
            continue
        try:
            x1, y1, x2, y2 = (float(v) for v in bbox[:4])
        except (TypeError, ValueError):
            continue  # not four numeric coordinates
        # Recent Pillow versions reject boxes whose corners are given in reverse order.
        left, right = sorted((x1, x2))
        top, bottom = sorted((y1, y2))
        label = str(category)
        color = colors.get(label, "#000000")
        try:
            draw.rectangle([left, top, right, bottom], outline=color, width=2)
            tx0, ty0, tx1, ty1 = draw.textbbox((0, 0), label, font=font)
        except (TypeError, ValueError, OSError):
            continue  # keep going with the remaining elements
        label_w, label_h = tx1 - tx0, ty1 - ty0
        lx = int(left)
        ly = max(0, int(top) - label_h - 2)  # clamp the tag inside the image's top edge
        draw.rectangle([lx, ly, lx + label_w + 4, ly + label_h + 2], fill=color)
        draw.text((lx + 2, ly + 1), label, fill="white", font=font)
    return img
def is_arabic_text(text: str) -> bool:
    """Heuristically decide whether Markdown *text* is predominantly Arabic.

    Only prose is considered: heading text (without the leading ``#``) and plain
    paragraph lines. Image links, code fences, table rows and bullet/numbered
    list lines are ignored. Returns True when more than half of the alphabetic
    characters in that prose fall in the Arabic, Arabic Supplement or Arabic
    Extended-A Unicode blocks.
    """
    if not text:
        return False
    heading_re = re.compile(r"^#{1,6}\s+(.+)$")
    prose_re = re.compile(r"^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$")

    prose: List[str] = []
    for raw_line in text.split("\n"):
        stripped = raw_line.strip()
        if not stripped:
            continue
        heading = heading_re.match(stripped)
        if heading is not None:
            prose.append(heading.group(1))
        elif prose_re.match(stripped):
            prose.append(stripped)
    if not prose:
        return False

    letters = [c for c in " ".join(prose) if c.isalpha()]
    if not letters:
        return False
    arabic_letters = sum(
        1
        for c in letters
        if "\u0600" <= c <= "\u06ff"
        or "\u0750" <= c <= "\u077f"
        or "\u08a0" <= c <= "\u08ff"
    )
    return arabic_letters / len(letters) > 0.5
def extract_json(text: str) -> Optional[Any]:
    """Best-effort extraction of a JSON value from free-form model output.

    Candidates are tried in order and the first one that parses is returned:

    1. the whole string;
    2. the span from the first opening delimiter to the last matching closing
       delimiter. Whichever of ``{`` / ``[`` appears first is tried first, so
       a prose-wrapped top-level array of objects comes back as a list rather
       than as its first element;
    3. each fenced ```json block.

    Returns:
        The parsed value (typically a dict or a list), or ``None`` when *text*
        is empty or nothing parses.
    """
    if not text:
        return None
    not_found = object()

    def _loads(candidate: str) -> Any:
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            return not_found

    spans: List[Tuple[int, str]] = []
    for opener, closer in (("{", "}"), ("[", "]")):
        start, end = text.find(opener), text.rfind(closer)
        if 0 <= start < end:
            spans.append((start, text[start : end + 1]))

    candidates = [text]
    candidates.extend(snippet for _, snippet in sorted(spans))
    candidates.extend(re.findall(r"```json\s*([\s\S]*?)\s*```", text))
    for candidate in candidates:
        value = _loads(candidate)
        if value is not not_found:
            return value
    return None
# Process-wide singletons populated lazily by ensure_model_loaded().
model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None


def ensure_model_loaded() -> Tuple[AutoModelForCausalLM, AutoProcessor]:
    """Return the cached dots.ocr ``(model, processor)`` pair, loading it on first use.

    On the first call this mirrors ``DOTS_REPO_ID`` into ``DOTS_LOCAL_DIR``, loads
    model and processor with ``trust_remote_code=True``, and places the model on
    the best available device: MPS (float16), then CUDA (bfloat16), else CPU
    (float32). The module-level ``model``/``processor`` globals are assigned
    together only after both have loaded. Once both are set, later calls return
    them immediately.
    """
    global model, processor
    if model is not None and processor is not None:
        return model, processor

    os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
    snapshot_download(
        repo_id=DOTS_REPO_ID,
        local_dir=DOTS_LOCAL_DIR,
        local_dir_use_symlinks=False,
    )

    # Work around transformers' dynamic-module parent-package issue for a repo dir
    # whose name contains a dot: pre-register 'transformers_modules' and
    # 'transformers_modules.dots' as (empty) packages.
    for package_name in ("transformers_modules", "transformers_modules.dots"):
        if package_name not in sys.modules:
            package = types.ModuleType(package_name)
            package.__path__ = []  # type: ignore[attr-defined]
            sys.modules[package_name] = package

    if torch.backends.mps.is_available():
        device, dtype = "mps", torch.float16
    elif torch.cuda.is_available():
        device, dtype = "cuda", torch.bfloat16
    else:
        device, dtype = "cpu", torch.float32

    loaded_model = AutoModelForCausalLM.from_pretrained(
        DOTS_LOCAL_DIR,
        torch_dtype=dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    if device != "cpu":
        # Move to the accelerator for CUDA as well as MPS; leaving the weights on the
        # CPU while inputs go to the GPU causes a device mismatch at generate time.
        loaded_model.to(device)
    loaded_processor = AutoProcessor.from_pretrained(
        DOTS_LOCAL_DIR, trust_remote_code=True
    )
    model, processor = loaded_model, loaded_processor
    return model, processor
def run_inference(
    image: Image.Image,
    prompt_text: str,
    max_new_tokens: int = LOCAL_DEFAULT_MAX_NEW_TOKENS,
) -> str:
    """Run one greedy dots.ocr generation for *image* + *prompt_text* and return the reply.

    The prompt goes through the processor's chat template. Vision inputs come from
    ``qwen_vl_utils.process_vision_info``. Only the newly generated tokens are
    decoded (the prompt tokens are sliced off), with special tokens removed.

    Args:
        image: Page image to read.
        prompt_text: Instruction text sent alongside the image.
        max_new_tokens: Generation budget (coerced to int).

    Returns:
        The decoded model output, or ``""`` if nothing was decoded.
    """
    mdl, proc = ensure_model_loaded()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text = proc.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = proc(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's actual placement rather than re-deriving a device here;
    # they disagreed on CUDA hosts, where the model could still be on the CPU.
    device = mdl.device
    inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
    with torch.no_grad():
        # Greedy decoding; a temperature would be ignored with do_sample=False.
        generated_ids = mdl.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )
    trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    decoded = proc.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0] if decoded else ""
def process_single_image(
    image: Image.Image,
    prompt_text: str,
    max_new_tokens: int,
) -> Dict[str, Any]:
    """Run dots.ocr on one page and package the outputs for the UI.

    Returns a dict with:
      * ``original_image`` – the RGB page that was sent to the model;
      * ``processed_image`` – the page with layout boxes drawn, or the original
        if no element list was found;
      * ``raw_output`` – the model's decoded text;
      * ``layout_result`` – the parsed JSON (dict or list), or ``None``;
      * ``markdown`` – Markdown built from the elements, falling back to the raw
        output.

    The reply may be a JSON array of element dicts or a JSON object wrapping the
    list under ``elements``, ``elements_list`` or ``content``; both are rendered.
    NOTE(review): layout prompts are expected to yield the array form — confirm
    against the prompt in use.
    """
    img = fetch_image(image)
    raw = run_inference(img, prompt_text, max_new_tokens=max_new_tokens)
    result: Dict[str, Any] = {
        "original_image": img,
        "processed_image": img,
        "raw_output": raw,
        "layout_result": None,
        "markdown": None,
    }
    data = extract_json(raw)
    items: Any = None
    if isinstance(data, list):
        items = data
    elif isinstance(data, dict):
        items = data.get("elements", data.get("elements_list", data.get("content", [])))
    if isinstance(data, (dict, list)):
        result["layout_result"] = data
    if isinstance(items, list):
        # Drop anything that is not an element dict so one stray value cannot
        # derail drawing or Markdown conversion for the whole page.
        elements = [element for element in items if isinstance(element, dict)]
        result["processed_image"] = draw_layout_on_image(img, elements)
        result["markdown"] = layoutjson2md(img, elements)
    if result["markdown"] is None:
        result["markdown"] = raw
    return result
def layoutjson2md(
    image: Image.Image, layout_data: List[Dict], text_key: str = "text"
) -> str:
    """Render a list of layout elements as a Markdown document.

    Elements are emitted top-to-bottom, then left-to-right, by their bbox's
    top-left corner. Elements without a usable bbox sort as if at (0, 0), and
    non-dict entries are ignored. The element text is read from *text_key*.
    Per category:

    * ``Title`` → ``# text``; ``Section-header`` → ``## text``;
      ``List-item`` → ``- text``;
    * ``Table`` → the text as-is if it starts with ``<`` (HTML), else
      ``**Table:** text``;
    * ``Formula`` → a ``$$`` display block when it starts with ``$`` or contains
      a backslash (existing ``$``/``$$`` delimiters are not doubled), else
      ``**Formula:** text``;
    * ``Caption`` → ``*text*``;
    * ``Page-header``, ``Page-footer``, ``Picture`` and empty-text elements are
      omitted; anything else becomes a plain paragraph.

    NOTE(review): the positional sort overrides the model's own element order;
    if the prompt asks for reading order, multi-column pages may read better
    unsorted — confirm intent.

    Args:
        image: Accepted for interface compatibility; not used.
        layout_data: Element dicts (``bbox``, ``category``, text field).
        text_key: Key holding each element's text.

    Returns:
        The Markdown string. If *layout_data* is not a list/tuple, its JSON dump.
    """
    if not isinstance(layout_data, (list, tuple)):
        return json.dumps(layout_data, ensure_ascii=False)

    def _top_left(item: Dict) -> Tuple[float, float]:
        bbox = item.get("bbox")
        try:
            return float(bbox[1]), float(bbox[0])
        except (TypeError, ValueError, IndexError, KeyError):
            return 0.0, 0.0

    def _strip_math_delimiters(formula: str) -> str:
        # Remove one enclosing $$…$$ or $…$ pair so it is not wrapped twice.
        s = formula.strip()
        for delim in ("$$", "$"):
            n = len(delim)
            if (
                len(s) > 2 * n
                and s.startswith(delim)
                and s.endswith(delim)
                and delim not in s[n:-n]
            ):
                return s[n:-n].strip()
        return s

    prefixes = {"Title": "# ", "Section-header": "## ", "List-item": "- "}
    lines: List[str] = []
    elements = [el for el in layout_data if isinstance(el, dict)]
    for item in sorted(elements, key=_top_left):
        category = item.get("category", "")
        raw_text = item.get(text_key, "")
        text = raw_text if isinstance(raw_text, str) else (str(raw_text) if raw_text else "")
        if category in ("Page-header", "Page-footer", "Picture") or not text:
            continue
        if category in prefixes:
            lines.append(f"{prefixes[category]}{text}\n")
        elif category == "Table":
            if text.strip().startswith("<"):
                lines.append(text + "\n")
            else:
                lines.append(f"**Table:** {text}\n")
        elif category == "Formula":
            if text.strip().startswith("$") or "\\" in text:
                lines.append(f"$$\n{_strip_math_delimiters(text)}\n$$\n")
            else:
                lines.append(f"**Formula:** {text}\n")
        elif category == "Caption":
            lines.append(f"*{text}*\n")
        else:
            lines.append(f"{text}\n")
        lines.append("")
    return "\n".join(lines)
def create_blocks_app():
    """Build the Gradio UI for running dots.ocr locally on PDFs and images.

    Left column: upload, prompt template/text, Parse/Clear, and advanced options
    (max new tokens, page selection). Middle column: page preview and navigation.
    Right column: rendered Markdown, raw Markdown, current-page JSON, the
    annotated image, and JSONL / Markdown download buttons. The download buttons
    are refreshed after every upload, parse and clear, so a click downloads the
    current results.

    Per-session document state is a ``gr.State`` dict with keys ``images``,
    ``current_page``, ``total_pages``, ``file_type``, ``checksum``, ``results``
    and ``parsed``. Parse results are also memoised in a bounded cache shared by
    every session of this app, keyed on (file checksum, page index, prompt hash,
    max new tokens).

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    css = """
    .main-container { max-width: 1500px; margin: 0 auto; }
    .header-text { text-align: center; color: #1f2937; margin-bottom: 12px; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: 600; }
    .process-button { border: none !important; color: white !important; font-weight: 700 !important; }
    """
    no_file_html = '<div class="page-info">No file</div>'
    placeholder_md = "Upload and parse to view results"
    # Shared by all sessions; bounded because every entry holds full-page PIL images.
    page_cache: Dict[Tuple[Any, ...], Dict[str, Any]] = {}
    page_cache_limit = 32

    def empty_doc_state() -> Dict[str, Any]:
        """Return a fresh document state with no file loaded."""
        return {
            "images": [],
            "current_page": 0,
            "total_pages": 0,
            "file_type": None,
            "checksum": None,
            "results": [],
            "parsed": False,
        }

    def markdown_update(md: Optional[str]) -> Any:
        """Update for the Markdown view that always sets ``rtl`` explicitly.

        This keeps a right-to-left page from leaving the widget right-to-left
        for the next page shown.
        """
        return gr.update(value=md, rtl=is_arabic_text(md))

    def page_outputs(state: Dict[str, Any], no_file_md: str, missing_md: str):
        """Return the seven values wired to the view outputs for the current page.

        *no_file_md* is shown when no document is loaded, *missing_md* when the
        current page has no parse result yet.
        """
        if not state.get("images"):
            return state, None, no_file_html, markdown_update(no_file_md), "", None, None
        idx = state["current_page"]
        results = state.get("results") or []
        result = results[idx] if state.get("parsed") and idx < len(results) else None
        md = result.get("markdown") if result else missing_md
        info = f'<div class="page-info">Page {idx + 1} / {state["total_pages"]}</div>'
        processed = result.get("processed_image") if result else None
        layout = result.get("layout_result") if result else None
        return state, state["images"][idx], info, markdown_update(md), md, processed, layout

    def parse_page_selection(
        selection: Optional[str], current: int, total: int
    ) -> List[int]:
        """Turn 1-based selection text into sorted, de-duplicated 0-based page indices.

        Blank selects the current page and ``all`` selects every page. Otherwise
        the text is comma-separated page numbers and ``a-b`` ranges (reversed
        ranges are swapped); values are clamped into range and malformed parts
        are ignored.
        """
        spec = (selection or "").strip()
        if not spec:
            return [current]
        if spec.lower() == "all":
            return list(range(total))
        picked = set()
        for part in (p.strip() for p in spec.split(",")):
            if not part:
                continue
            try:
                if "-" in part:
                    lo_text, hi_text = part.split("-", 1)
                    lo, hi = int(lo_text), int(hi_text)
                    if lo > hi:
                        lo, hi = hi, lo
                    picked.update(range(max(1, lo) - 1, min(total, hi)))
                else:
                    picked.add(max(1, min(total, int(part))) - 1)
            except ValueError:
                continue
        return sorted(i for i in picked if 0 <= i < total)

    def cache_put(key: Tuple[Any, ...], value: Dict[str, Any]) -> None:
        """Store *value*, evicting the oldest entries once over the limit."""
        page_cache[key] = value
        while len(page_cache) > page_cache_limit:
            # dicts keep insertion order, so the first key is the oldest entry
            del page_cache[next(iter(page_cache))]

    def export_path(state: Dict[str, Any], suffix: str) -> str:
        # Name exports after the document checksum so different documents
        # parsed in other sessions do not overwrite each other's files.
        tag = (state.get("checksum") or "results")[:16]
        return os.path.join(TMP_DIR, f"results-{tag}{suffix}")

    def on_template_change(choice: str) -> str:
        # Only one template exists, so every choice maps to the layout prompt.
        return dict_promptmode_to_prompt.get("prompt_layout_all_en", "")

    def on_file_change(path: Optional[str]):
        """Load an uploaded PDF/image into a fresh state and show its first page."""
        if not path or not os.path.exists(path):
            return page_outputs(empty_doc_state(), placeholder_md, placeholder_md)
        state = empty_doc_state()
        state["checksum"] = file_checksum(path)
        if os.path.splitext(path)[1].lower() == ".pdf":
            images = load_images_from_pdf(path)
            state["file_type"] = "pdf"
        else:
            with Image.open(path) as opened:
                images = [opened.convert("RGB")]
            state["file_type"] = "image"
        state.update(
            images=images, total_pages=len(images), results=[None] * len(images)
        )
        # Also resets the result views so the previous document's output is not shown.
        return page_outputs(state, placeholder_md, "Page not processed yet")

    def nav_page(state: Dict[str, Any], direction: str):
        """Step to the previous/next page (other directions stay) and refresh the views."""
        if state.get("images"):
            if direction == "prev":
                state["current_page"] = max(0, state["current_page"] - 1)
            elif direction == "next":
                state["current_page"] = min(
                    state["total_pages"] - 1, state["current_page"] + 1
                )
        return page_outputs(state, "No results", "Page not processed yet")

    def jump_to_page(state: Dict[str, Any], page_num: Any):
        """Go to 1-based page *page_num*, clamped into range (non-numbers mean page 1)."""
        if state.get("images"):
            try:
                requested = int(page_num)
            except (TypeError, ValueError, OverflowError):
                requested = 1
            state["current_page"] = max(1, min(state["total_pages"], requested)) - 1
        return nav_page(state, direction="stay")

    def parse_pages(
        state: Dict[str, Any],
        prompt: str,
        max_tokens: int,
        selection: Optional[str],
    ):
        """Run (or reuse cached) inference for the selected pages, then show the current one."""
        if not state.get("images"):
            return page_outputs(state, "No content", "No content")
        prompt = prompt or ""
        tokens = int(max_tokens)
        # Identical for every page in this run, so hash once.
        prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
        total = state["total_pages"]
        results = state.get("results") or [None] * total
        for i in parse_page_selection(selection, state["current_page"], total):
            key = (state["checksum"], i, prompt_hash, tokens)
            res = page_cache.get(key)
            if res is None:
                res = process_single_image(
                    state["images"][i],
                    prompt_text=prompt,
                    max_new_tokens=tokens,
                )
                cache_put(key, res)
            results[i] = res
        state["results"] = results
        state["parsed"] = True
        return page_outputs(state, "No content", "No content")

    def clear_all():
        """Reset the session to the empty state and blank every result view."""
        gc.collect()
        return page_outputs(empty_doc_state(), placeholder_md, placeholder_md)

    def download_current_jsonl(state: Dict[str, Any]) -> Any:
        """Write one ``{"page", "layout"}`` JSON line per parsed page; return a button update."""
        if not state.get("parsed"):
            return gr.update(value=None)
        lines = [
            json.dumps({"page": i + 1, "layout": res["layout_result"]}, ensure_ascii=False)
            for i, res in enumerate(state.get("results", []))
            if res and res.get("layout_result") is not None
        ]
        out_path = export_path(state, ".jsonl")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
        return gr.update(value=out_path)

    def download_current_markdown(state: Dict[str, Any]) -> Any:
        """Write every parsed page's Markdown, separated by rules; return a button update."""
        if not state.get("parsed"):
            return gr.update(value=None)
        chunks = [
            f"## Page {i + 1}\n\n{res['markdown']}"
            for i, res in enumerate(state.get("results", []))
            if res and res.get("markdown")
        ]
        out_path = export_path(state, ".md")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write("\n\n---\n\n".join(chunks))
        return gr.update(value=out_path)

    def export_downloads(state: Dict[str, Any]):
        """Refresh both download buttons so a click downloads the current results."""
        return download_current_jsonl(state), download_current_markdown(state)

    with gr.Blocks(theme=gr.themes.Soft(), css=css, title=APP_TITLE) as demo:
        doc_state = gr.State(empty_doc_state())
        gr.HTML(
            """
            <div class="header-text">
                <h2>VLM Playground — dots.ocr (Local)</h2>
                <p>Optimized defaults for Apple Silicon / CPU dev.</p>
            </div>
            """
        )
        with gr.Row(elem_classes=["main-container"]):
            with gr.Column(scale=4):
                file_input = gr.File(
                    label="Upload PDF or Image",
                    file_types=[
                        ".pdf",
                        ".png",
                        ".jpg",
                        ".jpeg",
                        ".bmp",
                        ".tiff",
                        ".webp",
                    ],
                    type="filepath",
                )
                with gr.Group():
                    template = gr.Dropdown(
                        label="Prompt Template",
                        choices=["Layout Extraction"],
                        value="Layout Extraction",
                    )
                    prompt_text = gr.Textbox(
                        label="Current Prompt",
                        value=dict_promptmode_to_prompt.get("prompt_layout_all_en", ""),
                        lines=6,
                    )
                with gr.Row():
                    parse_button = gr.Button(
                        "Parse", variant="primary", elem_classes=["process-button"]
                    )
                    clear_button = gr.Button("Clear")
                with gr.Accordion("Advanced", open=False):
                    max_new_tokens = gr.Slider(
                        minimum=256,
                        maximum=8192,
                        value=LOCAL_DEFAULT_MAX_NEW_TOKENS,
                        step=128,
                        label="Max new tokens",
                    )
                    page_range = gr.Textbox(
                        label="Page selection",
                        placeholder="e.g., 1-3,5 (blank = current page, 'all' = all pages)",
                    )
            with gr.Column(scale=5):
                preview_image = gr.Image(label="Page Preview", type="pil", height=520)
                with gr.Row():
                    prev_btn = gr.Button("◀ Prev")
                    page_info = gr.HTML(no_file_html)
                    next_btn = gr.Button("Next ▶")
                with gr.Row():
                    page_jump = gr.Number(value=1, label="Page #", precision=0)
                    jump_btn = gr.Button("Go")
            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.Tab("Markdown Render"):
                        md_render = gr.Markdown(value=placeholder_md, height=520)
                    with gr.Tab("Raw Markdown"):
                        md_raw = gr.Textbox(value="", lines=20)
                    with gr.Tab("Current Page JSON"):
                        json_view = gr.JSON(value=None)
                    with gr.Tab("Processed Image"):
                        processed_view = gr.Image(type="pil", height=520)
                with gr.Row():
                    download_jsonl = gr.DownloadButton(label="Download JSONL")
                    download_markdown = gr.DownloadButton(label="Download Markdown")

        view_outputs = [
            doc_state,
            preview_image,
            page_info,
            md_render,
            md_raw,
            processed_view,
            json_view,
        ]
        download_outputs = [download_jsonl, download_markdown]

        template.change(on_template_change, inputs=[template], outputs=[prompt_text])
        file_input.change(
            on_file_change, inputs=[file_input], outputs=view_outputs
        ).then(export_downloads, inputs=[doc_state], outputs=download_outputs)
        prev_btn.click(
            lambda s: nav_page(s, "prev"), inputs=[doc_state], outputs=view_outputs
        )
        next_btn.click(
            lambda s: nav_page(s, "next"), inputs=[doc_state], outputs=view_outputs
        )
        jump_btn.click(
            jump_to_page, inputs=[doc_state, page_jump], outputs=view_outputs
        )
        parse_button.click(
            parse_pages,
            inputs=[doc_state, prompt_text, max_new_tokens, page_range],
            outputs=view_outputs,
        ).then(export_downloads, inputs=[doc_state], outputs=download_outputs)
        clear_button.click(clear_all, outputs=view_outputs).then(
            export_downloads, inputs=[doc_state], outputs=download_outputs
        )
    return demo