"""PreviewSpace VLM playground: a Gradio UI around the dots.ocr model.

Upload a PDF or image, preview its pages, run a layout-extraction prompt on
selected pages, and inspect the result as rendered Markdown, raw text, JSON,
or an annotated image. The model is downloaded and loaded lazily on the first
parse request.
"""

import gc
import hashlib
import json
import math
import os
import re
from collections import OrderedDict
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import fitz  # PyMuPDF
import gradio as gr
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor

from .utils.constants import IMAGE_FACTOR, MAX_PIXELS, MIN_PIXELS
from .utils.prompts import dict_promptmode_to_prompt

# ============================
# Constants and configuration
# ============================

APP_TITLE = "PreviewSpace — VLM Playground"
TMP_DIR = "/tmp/previewspace"
MODELS_DIR = os.path.join(TMP_DIR, "models")
DOTS_REPO_ID = "rednote-hilab/dots.ocr"
DOTS_LOCAL_DIR = os.path.join(MODELS_DIR, "dots.ocr")

# Maximum number of parsed pages kept in the in-memory result cache. Each entry
# holds PIL images, so the cache is bounded to keep memory use in check.
RESULT_CACHE_MAX_ENTRIES = 32

# Prompt from the project's prompt table, with an inline fallback used only if
# the "prompt_layout_all_en" key is missing.
DEFAULT_PROMPT = dict_promptmode_to_prompt.get(
    "prompt_layout_all_en",
    (
        "Please output the layout information from the PDF page image. For each element, return: "
        'bbox: [x1, y1, x2, y2], category from {"title","header","paragraph","table","figure","footnote"}, and text. '
        'Return JSON: {"elements": [{"bbox": [..], "category": "..", "text": ".."}], "page": <int>}'
    ),
)

# Static header rendered at the top of the app (styled by .header-text).
HEADER_HTML = """
<div class="header-text">
    <h1>VLM Playground — dots.ocr</h1>
    <p>Upload a PDF or image, preview pages, and parse with a layout-extraction prompt.</p>
</div>
"""

# Outline/label colour per layout category; unknown categories use black.
LAYOUT_COLORS: Dict[str, str] = {
    "Caption": "#FF6B6B",
    "Footnote": "#4ECDC4",
    "Formula": "#45B7D1",
    "List-item": "#96CEB4",
    "Page-footer": "#FFEAA7",
    "Page-header": "#DDA0DD",
    "Picture": "#FFD93D",
    "Section-header": "#6C5CE7",
    "Table": "#FD79A8",
    "Text": "#74B9FF",
    "Title": "#E17055",
}

# Bold fonts tried in order for box labels (a macOS path, then a common Linux
# DejaVu path); Pillow's built-in font is the final fallback.
LABEL_FONT_CANDIDATES: Tuple[str, ...] = (
    "/System/Library/Fonts/Supplemental/Arial Bold.ttf",
    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
)

os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)


# ===========
# Utilities
# ===========


def round_by_factor(number: float, factor: int) -> int:
    """Return the multiple of ``factor`` closest to ``number``."""
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Tuple[int, int]:
    """Compute a target ``(height, width)`` close to the input size.

    Both sides are multiples of ``factor`` (and at least ``factor``). The area
    is brought inside ``[min_pixels, max_pixels]`` while approximately keeping
    the aspect ratio. When scaling down, sides are floored so rounding cannot
    push the area back above ``max_pixels``. When scaling up, sides are ceiled
    so the area cannot stay below ``min_pixels``.

    Raises:
        ValueError: if a side is not positive or the aspect ratio exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"image dimensions must be positive, got {height}x{width}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return int(h_bar), int(w_bar)


def fetch_image(
    image_input: Any,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> Image.Image:
    """Load an image from an http(s) URL, a local path, or a PIL image, as RGB.

    If either pixel bound is given, the image is resized with ``smart_resize``.
    A missing (or falsy) bound falls back to ``MIN_PIXELS`` / ``MAX_PIXELS``.

    Raises:
        ValueError: for unsupported input types.
        requests.HTTPError: if a URL answers with an error status.
    """
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=60)
            # Fail loudly instead of handing an HTML error page to PIL.
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as src:
                image = src.convert("RGB")
        else:
            # The context manager closes the file handle opened by PIL.
            with Image.open(image_input) as src:
                image = src.convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input.convert("RGB")
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")

    if min_pixels is not None or max_pixels is not None:
        new_h, new_w = smart_resize(
            image.height,
            image.width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels or MIN_PIXELS,
            max_pixels=max_pixels or MAX_PIXELS,
        )
        image = image.resize((new_w, new_h), Image.LANCZOS)
    return image


def load_images_from_pdf(pdf_path: str, zoom: float = 2.0) -> List[Image.Image]:
    """Rasterize every page of a PDF to an RGB image.

    Args:
        pdf_path: Path to the PDF file.
        zoom: Scale applied to both axes when rendering (2.0 keeps the
            original behaviour).
    """
    images: List[Image.Image] = []
    pdf_document = fitz.open(pdf_path)
    try:
        for page_idx in range(len(pdf_document)):
            page = pdf_document.load_page(page_idx)
            pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
            images.append(Image.open(BytesIO(pix.tobytes("ppm"))).convert("RGB"))
    finally:
        pdf_document.close()
    return images


def file_checksum(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 of the file at ``path``, read in chunks."""
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            hasher.update(chunk)
    return hasher.hexdigest()


def _load_label_font(size: int = 12) -> Any:
    """Return the first available bold TrueType font, else Pillow's default."""
    for font_path in LABEL_FONT_CANDIDATES:
        try:
            return ImageFont.truetype(font_path, size)
        except (OSError, ImportError):
            continue
    return ImageFont.load_default()


def _normalize_bbox(bbox: Any) -> Optional[Tuple[int, int, int, int]]:
    """Return ``bbox`` as integer ``(left, top, right, bottom)``, or None if malformed.

    Coordinates are re-ordered so left <= right and top <= bottom. Pillow
    rejects reversed rectangles.
    """
    if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
        return None
    try:
        x1, y1, x2, y2 = (int(float(v)) for v in bbox[:4])
    except (TypeError, ValueError, OverflowError):
        return None
    return min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)


def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    """Return a copy of ``image`` with each layout element's bbox and category drawn.

    Elements that are not dicts, lack a category, or have a malformed bbox are
    skipped individually. They no longer abort drawing of the remaining boxes.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    font = _load_label_font(12)
    for item in layout_data:
        if not isinstance(item, dict):
            continue
        category = item.get("category")
        box = _normalize_bbox(item.get("bbox"))
        if box is None or not category:
            continue
        label = str(category)
        color = LAYOUT_COLORS.get(label, "#000000")
        left, top = box[0], box[1]
        draw.rectangle(box, outline=color, width=2)
        try:
            lb_left, lb_top, lb_right, lb_bottom = draw.textbbox((0, 0), label, font=font)
        except ValueError:
            # Some Pillow versions cannot measure bitmap fonts; keep the box, skip the label.
            continue
        label_w = lb_right - lb_left
        label_h = lb_bottom - lb_top
        # Place the label just above the box, clamped to the image top edge.
        ly = max(0, top - label_h - 2)
        draw.rectangle([left, ly, left + label_w + 4, ly + label_h + 2], fill=color)
        draw.text((left + 2, ly + 1), label, fill="white", font=font)
    return img


# Markdown heading line; group 1 is the heading text.
_MD_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$")
# Lines that are not headings, images, code fences, tables or list items.
_MD_PARAGRAPH_RE = re.compile(r"^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$")


def _is_arabic_char(ch: str) -> bool:
    """Return True if ``ch`` lies in the Arabic, Arabic Supplement or Arabic Extended-A blocks."""
    return (
        ("\u0600" <= ch <= "\u06ff")
        or ("\u0750" <= ch <= "\u077f")
        or ("\u08a0" <= ch <= "\u08ff")
    )


def is_arabic_text(text: str) -> bool:
    """Return True if more than half the letters in the prose lines of ``text`` are Arabic.

    Only heading text and plain paragraph lines are considered. Code fences,
    tables, images and list items are ignored.
    """
    if not text:
        return False
    content_lines: List[str] = []
    for line in text.split("\n"):
        s = line.strip()
        if not s:
            continue
        m = _MD_HEADER_RE.match(s)
        if m:
            content_lines.append(m.group(1))
            continue
        if _MD_PARAGRAPH_RE.match(s):
            content_lines.append(s)
    if not content_lines:
        return False
    letters = [ch for ch in " ".join(content_lines) if ch.isalpha()]
    if not letters:
        return False
    arabic = sum(1 for ch in letters if _is_arabic_char(ch))
    return (arabic / len(letters)) > 0.5


# Fenced code block, optionally tagged "json"; group 1 is the block body.
_FENCED_JSON_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)


def extract_json(text: str) -> Optional[Any]:
    """Best-effort parse of a JSON object or array embedded in model output.

    The candidates are tried in this order:

    1. the whole text;
    2. fenced code blocks (```json or bare ```);
    3. the span from the first ``{`` to the last ``}``, and the span from the
       first ``[`` to the last ``]``, whichever opener appears first going
       first.

    Returns the first successful parse (a dict, a list, or another JSON
    value), or None.
    """
    if not text:
        return None
    candidates: List[str] = [text]
    candidates.extend(_FENCED_JSON_RE.findall(text))
    spans: List[Tuple[int, str]] = []
    for opener, closer in (("{", "}"), ("[", "]")):
        start = text.find(opener)
        end = text.rfind(closer)
        if 0 <= start < end:
            spans.append((start, text[start : end + 1]))
    candidates.extend(snippet for _, snippet in sorted(spans))
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
    return None


def _reading_order_key(item: Dict[str, Any]) -> Tuple[Any, Any]:
    """Sort key (top, left) from an element's bbox; elements without one sort first."""
    bbox = item.get("bbox")
    if isinstance(bbox, (list, tuple)) and len(bbox) >= 2:
        return bbox[1], bbox[0]
    return 0, 0


def _format_md_element(category: str, text: str) -> str:
    """Render one non-empty layout element's text as a Markdown fragment."""
    if category == "Title":
        return f"# {text}\n"
    if category == "Section-header":
        return f"## {text}\n"
    if category == "List-item":
        return f"- {text}\n"
    if category == "Table":
        # Text that already looks like HTML markup is passed through verbatim.
        return text + "\n" if text.strip().startswith("<") else f"**Table:** {text}\n"
    if category == "Formula":
        if text.strip().startswith("$") or "\\" in text:
            return f"$$\n{text}\n$$\n"
        return f"**Formula:** {text}\n"
    if category == "Caption":
        return f"*{text}*\n"
    return f"{text}\n"


# Categories that never contribute to the Markdown output.
_MD_SKIPPED_CATEGORIES = frozenset({"Page-header", "Page-footer", "Picture"})


def layoutjson2md(
    image: Image.Image, layout_data: List[Dict], text_key: str = "text"
) -> str:
    """Convert layout elements to Markdown, ordered top-to-bottom then left-to-right.

    ``image`` is accepted for interface compatibility and is not used.
    Page headers, page footers and pictures are omitted, as are non-dict
    entries. If the elements cannot be ordered (for example, non-numeric bbox
    values), the raw layout is returned as JSON instead.
    """
    elements = [item for item in layout_data if isinstance(item, dict)]
    try:
        elements.sort(key=_reading_order_key)
    except TypeError:
        return json.dumps(layout_data, ensure_ascii=False)
    lines: List[str] = []
    for item in elements:
        category = item.get("category", "")
        if category in _MD_SKIPPED_CATEGORIES:
            continue
        text = item.get(text_key) or ""
        if not isinstance(text, str):
            text = str(text)
        if text:
            lines.append(_format_md_element(category, text))
        lines.append("")
    return "\n".join(lines)


# =====================
# Model initialization
# =====================

model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None

device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)


def get_torch_dtype() -> torch.dtype:
    """Pick the model dtype for ``device``: bf16 on CUDA, fp16 on MPS, fp32 on CPU."""
    if device == "cuda":
        return torch.bfloat16
    if device == "mps":
        return torch.float16
    return torch.float32


def ensure_model_loaded() -> Tuple[AutoModelForCausalLM, AutoProcessor]:
    """Download (if needed) and load the dots.ocr model and processor once.

    The pair is cached in the module globals ``model`` / ``processor``. Both
    are assigned only after both loads succeed, so a failed load is retried
    cleanly on the next call.
    """
    global model, processor
    if model is not None and processor is not None:
        return model, processor

    os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
    snapshot_download(
        repo_id=DOTS_REPO_ID,
        local_dir=DOTS_LOCAL_DIR,
        local_dir_use_symlinks=False,
    )
    # NOTE: trust_remote_code executes Python shipped in the downloaded repo.
    loaded_model = AutoModelForCausalLM.from_pretrained(
        DOTS_LOCAL_DIR,
        torch_dtype=get_torch_dtype(),
        device_map="auto",
        trust_remote_code=True,
    )
    loaded_processor = AutoProcessor.from_pretrained(
        DOTS_LOCAL_DIR, trust_remote_code=True
    )
    model, processor = loaded_model, loaded_processor
    return model, processor


def run_inference(
    image: Image.Image, prompt_text: str, max_new_tokens: int = 24000
) -> str:
    """Run greedy generation for one image + prompt and return the decoded reply."""
    mdl, proc = ensure_model_loaded()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text = proc.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = proc(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # The model was placed with device_map="auto"; send inputs where it lives.
    target_device = getattr(mdl, "device", device)
    inputs = {
        k: v.to(target_device) if hasattr(v, "to") else v for k, v in inputs.items()
    }
    with torch.no_grad():
        # Greedy decoding; sampling parameters such as temperature are unused here.
        generated_ids = mdl.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )
    # Strip the prompt tokens so only the newly generated continuation is decoded.
    trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_text = proc.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0] if output_text else ""


def process_single_image(
    image: Image.Image,
    prompt_text: str,
    min_pixels: Optional[int],
    max_pixels: Optional[int],
    max_new_tokens: int,
) -> Dict[str, Any]:
    """Resize, run the model on, and post-process a single page image.

    Returns a dict with these keys:

    - ``original_image``: the resized input.
    - ``processed_image``: the input with layout boxes drawn, if any were parsed.
    - ``raw_output``: the model text.
    - ``layout_result``: the parsed JSON (a dict or a list), or None.
    - ``markdown``: Markdown built from the layout, falling back to the raw text.

    The layout may be a top-level JSON list of elements, or a dict holding
    them under ``elements`` / ``elements_list`` / ``content``.
    """
    img = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
    raw = run_inference(img, prompt_text, max_new_tokens=max_new_tokens)
    result: Dict[str, Any] = {
        "original_image": img,
        "processed_image": img,
        "raw_output": raw,
        "layout_result": None,
        "markdown": None,
    }
    data = extract_json(raw)
    items: Any = None
    if isinstance(data, dict):
        result["layout_result"] = data
        items = data.get("elements", data.get("elements_list", data.get("content", [])))
    elif isinstance(data, list):
        result["layout_result"] = data
        items = data
    if isinstance(items, list):
        result["processed_image"] = draw_layout_on_image(img, items)
        result["markdown"] = layoutjson2md(img, items)
    if result["markdown"] is None:
        result["markdown"] = raw
    return result


def parse_page_selection(
    selection: Optional[str], current_page: int, total_pages: int
) -> List[int]:
    """Translate a page-selection string into sorted, unique 0-based page indices.

    A blank selection means ``current_page``, and ``"all"`` (any case) means
    every page. Otherwise the selection is a comma-separated list of 1-based
    pages and inclusive ranges, e.g. ``"1-3,5"``. Values are clamped to
    ``[1, total_pages]`` and unparseable parts are ignored.
    """
    text = (selection or "").strip()
    if not text:
        candidates = [current_page]
    elif text.lower() == "all":
        candidates = list(range(total_pages))
    else:
        candidates = []
        for part in (p.strip() for p in text.split(",")):
            if not part:
                continue
            try:
                if "-" in part:
                    start_s, end_s = part.split("-", 1)
                    first = max(1, int(start_s))
                    last = min(total_pages, int(end_s))
                    candidates.extend(range(first - 1, last))
                else:
                    candidates.append(max(1, min(total_pages, int(part))) - 1)
            except ValueError:
                continue
    return sorted({i for i in candidates if 0 <= i < total_pages})


# =================
# Gradio Interface
# =================


def _empty_doc_state() -> Dict[str, Any]:
    """Return a fresh per-session document state with no file loaded."""
    return {
        "images": [],
        "current_page": 0,
        "total_pages": 0,
        "file_type": None,
        "checksum": None,
        "results": [],
        "parsed": False,
    }


def _page_info_html(text: str) -> str:
    """Wrap ``text`` in the page-info badge markup."""
    return f'<div class="page-info">{text}</div>'


def _markdown_update(md: Optional[str]) -> Any:
    """Build a Markdown update that always sets ``rtl`` explicitly.

    Setting ``rtl`` on every update lets it reset to left-to-right after
    showing an Arabic page.
    """
    return gr.update(value=md, rtl=is_arabic_text(md or ""))


def _no_file_outputs(state: Dict[str, Any], message: str) -> Tuple[Any, ...]:
    """Outputs for the page widgets when no document is loaded."""
    return state, None, _page_info_html("No file"), _markdown_update(message), "", None, None


def _current_page_outputs(state: Dict[str, Any], placeholder: str) -> Tuple[Any, ...]:
    """Outputs for the page widgets showing ``state['current_page']``.

    ``placeholder`` is shown as the markdown when the page has no parse result.
    """
    idx = state["current_page"]
    results = state.get("results") or []
    result = results[idx] if state.get("parsed") and idx < len(results) else None
    markdown = result.get("markdown") if result else placeholder
    return (
        state,
        state["images"][idx],
        _page_info_html(f"Page {idx + 1} / {state['total_pages']}"),
        _markdown_update(markdown),
        markdown,
        result.get("processed_image") if result else None,
        result.get("layout_result") if result else None,
    )


def _export_path(state: Dict[str, Any], extension: str) -> str:
    """Per-document export path, so different documents do not overwrite each other."""
    tag = (state.get("checksum") or "document")[:16]
    return os.path.join(TMP_DIR, f"results_{tag}{extension}")


def create_blocks_app():
    """Build and return the Gradio Blocks UI (the caller launches it)."""
    css = """
    .main-container { max-width: 1500px; margin: 0 auto; }
    .header-text { text-align: center; color: #1f2937; margin-bottom: 12px; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: 600; }
    .process-button { border: none !important; color: white !important; font-weight: 700 !important; }
    """

    # Parse results keyed by (checksum, page, prompt_hash, min_px, max_px, max_tokens).
    # Shared by all sessions of this app; evicts least-recently-used entries.
    result_cache: "OrderedDict[Tuple[Any, ...], Dict[str, Any]]" = OrderedDict()

    with gr.Blocks(theme=gr.themes.Soft(), css=css, title=APP_TITLE) as demo:
        doc_state = gr.State(_empty_doc_state())

        gr.HTML(HEADER_HTML)

        with gr.Row(elem_classes=["main-container"]):
            # Left: upload + controls
            with gr.Column(scale=4):
                file_input = gr.File(
                    label="Upload PDF or Image",
                    file_types=[".pdf", ".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"],
                    type="filepath",
                )
                with gr.Group():
                    template = gr.Dropdown(
                        label="Prompt Template",
                        choices=["Layout Extraction"],
                        value="Layout Extraction",
                    )
                    prompt_text = gr.Textbox(
                        label="Current Prompt",
                        value=DEFAULT_PROMPT,
                        lines=6,
                    )
                with gr.Row():
                    parse_button = gr.Button(
                        "Parse", variant="primary", elem_classes=["process-button"]
                    )
                    clear_button = gr.Button("Clear")
                with gr.Accordion("Advanced", open=False):
                    max_new_tokens = gr.Slider(
                        minimum=512,
                        maximum=32000,
                        value=24000,
                        step=256,
                        label="Max new tokens",
                    )
                    min_pixels_in = gr.Number(value=MIN_PIXELS, label="Min pixels")
                    max_pixels_in = gr.Number(value=MAX_PIXELS, label="Max pixels")
                    page_range = gr.Textbox(
                        label="Page selection",
                        placeholder="e.g., 1-3,5 (blank = current page, 'all' = all pages)",
                    )

            # Center: page preview + nav
            with gr.Column(scale=5):
                preview_image = gr.Image(label="Page Preview", type="pil", height=520)
                with gr.Row():
                    prev_btn = gr.Button("◀ Prev")
                    page_info = gr.HTML(_page_info_html("No file"))
                    next_btn = gr.Button("Next ▶")
                with gr.Row():
                    page_jump = gr.Number(value=1, label="Page #", precision=0)
                    jump_btn = gr.Button("Go")

            # Right: results
            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.Tab("Markdown Render"):
                        md_render = gr.Markdown(
                            value="Upload and parse to view results", height=520
                        )
                    with gr.Tab("Raw Markdown"):
                        md_raw = gr.Textbox(value="", lines=20)
                    with gr.Tab("Current Page JSON"):
                        json_view = gr.JSON(value=None)
                    with gr.Tab("Processed Image"):
                        processed_view = gr.Image(type="pil", height=520)
                with gr.Row():
                    download_jsonl = gr.DownloadButton(label="Download JSONL")
                    download_markdown = gr.DownloadButton(label="Download Markdown")

        # Every page-changing handler returns values for these, in this order.
        page_outputs = [
            doc_state,
            preview_image,
            page_info,
            md_render,
            md_raw,
            processed_view,
            json_view,
        ]
        download_outputs = [download_jsonl, download_markdown]

        # ===== Handlers =====
        def on_template_change(choice: str) -> str:
            """Return the prompt for the selected template (there is only one)."""
            return DEFAULT_PROMPT

        def on_file_change(path: Optional[str]):
            """Load the uploaded file into a fresh document state and show page 1."""
            if not path or not os.path.exists(path):
                return _no_file_outputs(_empty_doc_state(), "Upload and parse to view results")
            checksum = file_checksum(path)
            if os.path.splitext(path)[1].lower() == ".pdf":
                images = load_images_from_pdf(path)
                file_type = "pdf"
            else:
                with Image.open(path) as src:
                    images = [src.convert("RGB")]
                file_type = "image"
            if not images:
                return _no_file_outputs(_empty_doc_state(), "Upload and parse to view results")
            state = _empty_doc_state()
            state.update(
                images=images,
                total_pages=len(images),
                file_type=file_type,
                checksum=checksum,
                results=[None] * len(images),
            )
            return _current_page_outputs(state, "Page not processed yet")

        def nav_page(state: Dict[str, Any], direction: str):
            """Move one page back/forward ("prev"/"next"), or re-render for any other value."""
            if not state.get("images"):
                return _no_file_outputs(state, "No results")
            if direction == "prev":
                state["current_page"] = max(0, state["current_page"] - 1)
            elif direction == "next":
                state["current_page"] = min(
                    state["total_pages"] - 1, state["current_page"] + 1
                )
            return _current_page_outputs(state, "Page not processed yet")

        def jump_to_page(state: Dict[str, Any], page_num: Any):
            """Jump to a 1-based page number, clamped to the document; bad input means page 1."""
            if not state.get("images"):
                return _no_file_outputs(state, "No results")
            try:
                n = int(page_num)
            except (TypeError, ValueError, OverflowError):
                n = 1
            state["current_page"] = max(1, min(state["total_pages"], n)) - 1
            return nav_page(state, direction="stay")

        def parse_pages(
            state: Dict[str, Any],
            prompt: str,
            max_tokens: int,
            min_pix: Optional[float],
            max_pix: Optional[float],
            selection: Optional[str],
        ):
            """Run the model on the selected pages (using the cache) and show the current page."""
            if not state.get("images"):
                return _no_file_outputs(state, "No content")

            total = state["total_pages"]
            indices = parse_page_selection(selection, state["current_page"], total)
            prompt = prompt or ""
            prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
            min_pixels = int(min_pix) if min_pix else None
            max_pixels = int(max_pix) if max_pix else None

            # Pages are processed sequentially, one model call at a time.
            results = state.get("results") or [None] * total
            for i in indices:
                cache_key = (
                    state["checksum"],
                    i,
                    prompt_hash,
                    min_pixels or 0,
                    max_pixels or 0,
                    int(max_tokens),
                )
                cached = result_cache.get(cache_key)
                if cached is not None:
                    result_cache.move_to_end(cache_key)
                    results[i] = cached
                    continue
                res = process_single_image(
                    state["images"][i],
                    prompt_text=prompt,
                    min_pixels=min_pixels,
                    max_pixels=max_pixels,
                    max_new_tokens=int(max_tokens),
                )
                results[i] = res
                result_cache[cache_key] = res
                while len(result_cache) > RESULT_CACHE_MAX_ENTRIES:
                    result_cache.popitem(last=False)

            state["results"] = results
            state["parsed"] = True
            return _current_page_outputs(state, "No content")

        def clear_all():
            """Reset the session to the empty state."""
            gc.collect()
            return _no_file_outputs(_empty_doc_state(), "Upload and parse to view results")

        def download_current_jsonl(state: Dict[str, Any]) -> Optional[str]:
            """Write every parsed layout as JSON Lines; return the path, or None if unparsed."""
            if not state.get("parsed"):
                return None
            lines = [
                json.dumps({"page": i + 1, "layout": res["layout_result"]}, ensure_ascii=False)
                for i, res in enumerate(state.get("results", []))
                if res and res.get("layout_result") is not None
            ]
            out_path = _export_path(state, ".jsonl")
            with open(out_path, "w", encoding="utf-8") as f:
                f.write("\n".join(lines))
            return out_path

        def download_current_markdown(state: Dict[str, Any]) -> Optional[str]:
            """Write every parsed page's Markdown to one file; return the path, or None if unparsed."""
            if not state.get("parsed"):
                return None
            chunks = [
                f"## Page {i + 1}\n\n{res['markdown']}"
                for i, res in enumerate(state.get("results", []))
                if res and res.get("markdown")
            ]
            out_path = _export_path(state, ".md")
            with open(out_path, "w", encoding="utf-8") as f:
                f.write("\n\n---\n\n".join(chunks))
            return out_path

        def refresh_downloads(state: Dict[str, Any]):
            """Point both download buttons at freshly written export files."""
            return download_current_jsonl(state), download_current_markdown(state)

        def reset_downloads():
            """Clear both download buttons."""
            return None, None

        # Wire events
        template.change(on_template_change, inputs=[template], outputs=[prompt_text])
        file_input.change(
            on_file_change, inputs=[file_input], outputs=page_outputs
        ).then(reset_downloads, outputs=download_outputs)
        prev_btn.click(
            lambda s: nav_page(s, "prev"), inputs=[doc_state], outputs=page_outputs
        )
        next_btn.click(
            lambda s: nav_page(s, "next"), inputs=[doc_state], outputs=page_outputs
        )
        jump_btn.click(
            jump_to_page, inputs=[doc_state, page_jump], outputs=page_outputs
        )
        # DownloadButton serves its current value on click, so the export files
        # are prepared right after parsing rather than inside a click handler.
        parse_button.click(
            parse_pages,
            inputs=[
                doc_state,
                prompt_text,
                max_new_tokens,
                min_pixels_in,
                max_pixels_in,
                page_range,
            ],
            outputs=page_outputs,
        ).then(refresh_downloads, inputs=[doc_state], outputs=download_outputs)
        clear_button.click(clear_all, outputs=page_outputs).then(
            reset_downloads, outputs=download_outputs
        )

    return demo