| """ |
| Core processing utilities for DocGenie document generation pipeline. |
| |
| Integrated functionality (All 19 Stages): |
| - Stage 1-2: Seed selection, LLM prompting, response processing, PDF rendering, bbox extraction |
| - Stage 3: Handwriting & visual element synthesis (WordStylist diffusion, stamps, barcodes, logos) |
| - Stage 4: Image finalization & OCR (pdf2image, Microsoft Document Intelligence) |
| - Stage 5: Dataset packaging (bbox normalization, GT verification, analysis, debug viz) |
| |
References the generation folder for core pipeline logic.
| """ |
| import asyncio |
| import base64 |
| import json |
| import pathlib |
| import tempfile |
| import time |
| import uuid |
| import re |
| from typing import List, Tuple, Optional, Dict, Any |
| from io import BytesIO |
|
|
| import requests |
| import httpx |
| from PIL import Image |
| from pdf2image import convert_from_path |
| from bs4 import BeautifulSoup |
| from playwright.async_api import async_playwright |
| import fitz |
|
|
| from docgenie.generation.constants import BS_PARSER, HANDWRITING_CLASS_NAME, VISUAL_ELEMENT_TYPE_SYNONYMS |
| from docgenie.generation.pipeline_01.claude_batching import ClaudeBatchedClient, create_message |
| from docgenie.generation.pipeline_03_process_response import ( |
| extract_html_documents_from_text, |
| extract_gt, |
| ) |
| from docgenie.generation.pipeline_03.css import ( |
| increase_handwriting_font_size, |
| unmark_visual_elements, |
| ) |
| from docgenie.generation.pipeline_04_render_pdf_and_extract_geos import ( |
| render_pdf_async, |
| preprocess_html_for_pdf, |
| ) |
| from docgenie.generation.pipeline_04.extract_bbox import extract_bboxes_from_pdf |
|
|
| |
| |
| |
| from docgenie.generation.utils.pdfjs import MEASURE_DIMENSIONS |
| from docgenie.generation.utils.stamp import create_stamp |
| from docgenie import ENV |
|
|
| |
| from .config import settings |
|
|
|
|
async def download_image_to_base64(url: str) -> str:
    """
    Download an image or PDF from *url* and return it as a base64-encoded JPEG.

    If the URL points to a PDF (by Content-Type or .pdf extension), the first
    page is rasterized at 300 DPI and used as the image.

    Args:
        url: Image or PDF URL

    Returns:
        Base64-encoded JPEG image string

    Raises:
        requests.HTTPError: If the download fails.
        ValueError: If a downloaded PDF contains no pages.
    """
    # requests is blocking; run it in a worker thread so the asyncio event
    # loop is not stalled for the duration of the download.
    response = await asyncio.to_thread(requests.get, url, timeout=30)
    response.raise_for_status()

    content_type = response.headers.get('Content-Type', '').lower()
    is_pdf = 'application/pdf' in content_type or url.lower().endswith('.pdf')

    if is_pdf:
        print(f" 📄 Detected PDF, converting first page to image: {url[:80]}...")

        pdf_document = fitz.open(stream=response.content, filetype="pdf")
        try:
            if len(pdf_document) == 0:
                raise ValueError("PDF has no pages")

            page = pdf_document[0]
            # Render at 300 DPI (PDF user space is 72 DPI).
            zoom = 300 / 72
            mat = fitz.Matrix(zoom, zoom)
            pix = page.get_pixmap(matrix=mat)

            img_data = pix.tobytes("png")
            img = Image.open(BytesIO(img_data))
        finally:
            # Ensure the document handle is released even if rendering fails.
            pdf_document.close()

        print(f" ✓ Converted PDF to image: {img.size[0]}x{img.size[1]}px")
    else:
        img = Image.open(BytesIO(response.content))

    # JPEG cannot represent alpha/palette modes.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    buffer = BytesIO()
    img.save(buffer, format='JPEG', quality=95)
    buffer.seek(0)

    img_base64 = base64.b64encode(buffer.read()).decode('utf-8')
    return img_base64
|
|
|
|
def _fetch_seed_response(url: str, max_retries: int = 3) -> requests.Response:
    """
    Download *url* with retry logic for transient failures.

    Retries HTTP 502/503/504/429 responses and network errors (timeout,
    connection reset) with exponential backoff; re-raises everything else.

    Args:
        url: Image or PDF URL
        max_retries: Maximum number of download attempts

    Returns:
        Successful HTTP response
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            return response

        except requests.exceptions.HTTPError as e:
            # Only retry transient gateway / rate-limit errors.
            if e.response.status_code in [502, 503, 504, 429] and attempt < max_retries - 1:
                wait_time = 2 * (2 ** attempt)
                print(f" ⚠️ HTTP {e.response.status_code} error downloading seed image, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})...")
                time.sleep(wait_time)
                continue
            raise

        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
            if attempt < max_retries - 1:
                wait_time = 2 * (2 ** attempt)
                print(f" ⚠️ Network error downloading seed image, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries}): {e}")
                time.sleep(wait_time)
                continue
            raise

    # Unreachable: every iteration either returns or raises.
    raise RuntimeError(f"Failed to download seed image after {max_retries} attempts")


def _seed_response_to_jpeg_base64(url: str, response: "requests.Response") -> str:
    """
    Convert a downloaded image/PDF HTTP response to a base64-encoded JPEG.

    PDFs (detected by Content-Type or .pdf extension) are rasterized at
    300 DPI using only their first page.
    """
    content_type = response.headers.get('Content-Type', '').lower()
    is_pdf = 'application/pdf' in content_type or url.lower().endswith('.pdf')

    if is_pdf:
        print(f" 📄 Detected PDF, converting first page to image: {url[:80]}...")

        pdf_document = fitz.open(stream=response.content, filetype="pdf")
        try:
            if len(pdf_document) == 0:
                raise ValueError("PDF has no pages")

            page = pdf_document[0]
            # Render at 300 DPI (PDF user space is 72 DPI).
            zoom = 300 / 72
            mat = fitz.Matrix(zoom, zoom)
            pix = page.get_pixmap(matrix=mat)

            img_data = pix.tobytes("png")
            img = Image.open(BytesIO(img_data))
        finally:
            # Release the document handle even if rendering fails.
            pdf_document.close()

        print(f" ✓ Converted PDF to image: {img.size[0]}x{img.size[1]}px")
    else:
        img = Image.open(BytesIO(response.content))

    # JPEG cannot represent alpha/palette modes.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    buffer = BytesIO()
    img.save(buffer, format='JPEG', quality=95)
    buffer.seek(0)

    return base64.b64encode(buffer.read()).decode('utf-8')


def download_seed_images(urls: List[str]) -> List[str]:
    """
    Download multiple seed images/PDFs and convert to base64 (synchronous version for worker).
    If a URL points to a PDF, converts the first page to an image.
    Implements retry logic for transient HTTP errors (503, 502, 504, 429).

    Args:
        urls: List of image or PDF URLs

    Returns:
        List of base64-encoded JPEG image strings
    """
    return [
        _seed_response_to_jpeg_base64(url, _fetch_seed_response(url))
        for url in urls
    ]
|
|
|
|
def build_prompt(
    language: str,
    doc_type: str,
    gt_type: str,
    gt_format: str,
    num_solutions: int,
    num_seed_images: int,
    prompt_template_path: pathlib.Path,
    enable_visual_elements: bool = True,
    visual_element_types: Optional[List[str]] = None
) -> str:
    """
    Build the system prompt by injecting parameters into template.

    The template contains a "Visual Placeholders" section delimited by its
    heading and a blank line; depending on the flags it is either removed
    entirely or rebuilt with only the allowed element types and up to two
    matching examples.

    Args:
        language: Language for documents
        doc_type: Type of documents
        gt_type: Ground truth type description
        gt_format: Ground truth format specification
        num_solutions: Number of documents to generate
        num_seed_images: Number of seed images provided
        prompt_template_path: Path to prompt template file
        enable_visual_elements: Whether to include visual element instructions
        visual_element_types: List of allowed visual element types

    Returns:
        Formatted prompt string
    """
    template = prompt_template_path.read_text(encoding='utf-8')

    # Matches the heading plus everything up to the next blank line.
    ve_block_pattern = r"## Visual Placeholders \(if document type requires\)\n(.*?)\n\n"

    if not enable_visual_elements or not visual_element_types:
        # Strip the whole section and the related checklist item.
        template = re.sub(ve_block_pattern, "", template, flags=re.DOTALL)
        template = template.replace("- [ ] Visual elements are semantically coherent\n", "")
    else:
        types_str = ", ".join(visual_element_types)

        EXAMPLES = {
            "stamp": '- Example: `<div data-placeholder="stamp" data-content="APPROVED 2024-03-15" style="position:absolute;top:50mm;right:20mm;width:35mm;height:35mm;z-index:10;"></div>`',
            "logo": '- Example: `<div data-placeholder="logo" data-content="ACME Corp Logo" style="width:150mm;height:100mm;"></div>`',
            "figure": '- Example: `<div data-placeholder="figure" data-content="Sales Chart 2023" style="width:120mm;height:80mm;"></div>`',
            "barcode": '- Example: `<div data-placeholder="barcode" data-content="SKU-12345678" style="width:60mm;height:25mm;"></div>`',
            "photo": '- Example: `<div data-placeholder="photo" data-content="Customer Portrait" style="width:40mm;height:50mm;"></div>`'
        }

        # Show at most two examples, preferring the requested types in order.
        selected_examples = []
        for t in visual_element_types:
            if t in EXAMPLES:
                selected_examples.append(EXAMPLES[t])
            if len(selected_examples) >= 2:
                break

        if len(selected_examples) == 0:
            selected_examples = [EXAMPLES["logo"], EXAMPLES["stamp"]]

        new_block = [
            "## Visual Placeholders (if document type requires)",
            "- Insert `<div data-placeholder=\"type\" style=\"...\">` for non-text elements at appropriate positions",
            f"- Valid types are: {types_str}",
            "- Add data-content attribute with actual content description",
            "- For stamps, use `position:absolute;z-index:10;` and specify 'top' and 'right'" if "stamp" in visual_element_types else None,
            "- Always provide appropriate dimensions",
        ]

        new_block.extend(selected_examples)

        new_block_str = "\n".join([line for line in new_block if line is not None]) + "\n\n"

        template = re.sub(ve_block_pattern, new_block_str, template, flags=re.DOTALL)

    prompt = template.format(
        language=language,
        doc_type=doc_type,
        gt_type=gt_type,
        gt_format=gt_format,
        num_solutions=num_solutions,
        num_seed_images=num_seed_images
    )

    return prompt
|
|
|
|
async def call_claude_api_direct(
    prompt: str,
    seed_images_base64: List[str],
    api_key: str,
    model: str = "claude-sonnet-4-5-20250929",
    max_tokens: int = 16384
) -> str:
    """
    Call the Claude API directly (non-batched) with a prompt and seed images.

    Used by the API endpoint when an immediate synchronous response is needed.

    Args:
        prompt: System prompt
        seed_images_base64: Base64-encoded seed images
        api_key: Anthropic API key
        model: Claude model name
        max_tokens: Maximum tokens for the response

    Returns:
        Raw LLM response text
    """
    import anthropic

    client = anthropic.Anthropic(api_key=api_key)

    # Bundle the prompt and all seed images into a single user message.
    user_message = create_message(prompt=prompt, images_base64=seed_images_base64)

    reply = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        messages=[user_message],
    )

    # Concatenate every text block of the response into one string.
    return "".join(block.text for block in reply.content if block.type == "text")
|
|
|
|
def extract_html_documents_from_response(response_text: str) -> List[str]:
    """
    Split an LLM response into its individual HTML documents.

    Delegates to the shared pipeline_03 extractor for consistency.

    Args:
        response_text: Raw LLM response

    Returns:
        List of HTML document strings
    """
    documents = extract_html_documents_from_text(text=response_text)
    return documents
|
|
|
|
def extract_ground_truth(html: str) -> Tuple[Optional[dict], str]:
    """
    Extract the ground-truth JSON embedded in *html*.

    Delegates to the shared pipeline_03 extractor. When no GT is present or
    the embedded JSON is invalid, the original HTML is returned untouched.

    Args:
        html: HTML document with embedded GT

    Returns:
        Tuple of (ground_truth_dict, html_without_gt), or (None, original_html)
        when no parseable GT was found.
    """
    raw_json, html_clean, _soup = extract_gt(html=html)

    if not raw_json:
        return None, html

    try:
        return json.loads(raw_json), html_clean
    except json.JSONDecodeError:
        return None, html
|
|
|
|
def extract_css_from_html(html: str) -> Tuple[str, str]:
    """
    Collect all CSS found in *html* and return it alongside the HTML.

    Gathers the contents of every <style> tag plus every inline style
    attribute (the latter rendered as "tagname { ... }" rules).

    Args:
        html: HTML document

    Returns:
        Tuple of (css_string, html_string); the HTML is returned unchanged.
    """
    soup = BeautifulSoup(html, BS_PARSER)
    collected = []

    # Embedded stylesheets.
    for style_node in soup.find_all("style"):
        if style_node.string:
            collected.append(style_node.string)

    # Inline style attributes, expressed as per-tag rules.
    collected.extend(
        f"{node.name} {{ {node['style']} }}" for node in soup.find_all(style=True)
    )

    return "\n".join(collected), html
|
|
|
|
| |
|
|
|
|
async def render_html_to_pdf(
    html: str,
    output_pdf_path: pathlib.Path,
    timeout_seconds: int = 60
) -> Tuple[pathlib.Path, float, float, List[dict]]:
    """
    Render HTML to PDF using Playwright with automatic size detection.
    Also extracts element geometries for handwriting and visual elements.
    Matches pipeline_04 rendering logic.

    Args:
        html: HTML content to render
        output_pdf_path: Path where PDF should be saved
        timeout_seconds: Timeout for rendering
            (NOTE(review): currently unused in the body — confirm whether it
            should be passed to page.goto / page.pdf)

    Returns:
        Tuple of (pdf_path, width_mm, height_mm, geometries)
        - geometries: List of dicts with element positions, classes, and metadata
    """
    # Prepare the HTML exactly as pipeline_04 does: preprocessing, enlarged
    # handwriting fonts, visual-element markers stripped.
    html = preprocess_html_for_pdf(html)
    soup = BeautifulSoup(html, BS_PARSER)

    soup = increase_handwriting_font_size(soup, dbg=False)
    soup = unmark_visual_elements(soup)

    prep_html = soup.prettify()

    # Playwright navigates via file://, so materialize the HTML on disk.
    with tempfile.NamedTemporaryFile(
        mode='w',
        suffix='.html',
        delete=False,
        encoding='utf-8'
    ) as tmp_html:
        tmp_html.write(prep_html)
        tmp_html_path = tmp_html.name

    try:
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            page = await browser.new_page()

            await page.goto(
                f"file://{tmp_html_path}",
                wait_until="domcontentloaded"
            )
            await page.emulate_media(media="screen")

            # Measure the document's natural pixel size in the browser.
            dimensions = await page.evaluate(MEASURE_DIMENSIONS)

            page_width_px = dimensions["width"]
            page_height_px = dimensions["height"]

            # Match the viewport to the content so getBoundingClientRect()
            # values below correspond to the final page layout.
            await page.set_viewport_size({
                "width": page_width_px,
                "height": page_height_px
            })
            await page.wait_for_timeout(30)

            # DOM selectors for the element categories tracked downstream.
            selector_map = {
                "handwriting": ".handwritten",
                "visual_element": "[data-placeholder]",
                "layout_element": r'[class*="LE-"]'
            }

            # json is already imported at module level; serialize for JS.
            selector_map_js = json.dumps(selector_map)

            geo_eval_script = f"""
            () => {{
                const data = [];
                const selectorMap = {selector_map_js};
                const processedElements = new Map();

                // First pass: collect all elements and their matching selectors
                Object.entries(selectorMap).forEach(([label, selector]) => {{
                    document.querySelectorAll(selector).forEach(el => {{
                        if (!processedElements.has(el)) {{
                            processedElements.set(el, []);
                        }}
                        processedElements.get(el).push(label);
                    }});
                }});

                // Second pass: create geometry data for each unique element
                processedElements.forEach((selectorTypes, el) => {{
                    const rect = el.getBoundingClientRect();
                    const computed = window.getComputedStyle(el);

                    // Get text content
                    let text = '';
                    if (el.tagName.toLowerCase() === 'input') {{
                        text = (el.value || '').trim();
                    }} else {{
                        text = (el.innerText || el.textContent || '').trim();
                    }}

                    data.push({{
                        id: el.id || null,
                        tag: el.tagName.toLowerCase(),
                        classes: el.className || null,
                        rect: {{
                            x: rect.x,
                            y: rect.y,
                            width: rect.width,
                            height: rect.height
                        }},
                        visibility: computed.visibility,
                        dataContent: el.getAttribute('data-content') || null,
                        dataPlaceholder: el.getAttribute('data-placeholder') || null,
                        style: el.getAttribute('style') || null,
                        text: text,
                        selectorTypes: selectorTypes
                    }});
                }});

                return data;
            }}
            """

            geometries = await page.evaluate(geo_eval_script)

            print(f" 🔍 Extracted {len(geometries)} geometries from rendered DOM")

            # Quick diagnostics on what was found.
            hw_geos = [g for g in geometries if "handwriting" in g.get("selectorTypes", [])]
            ve_geos = [g for g in geometries if "visual_element" in g.get("selectorTypes", [])]
            if hw_geos:
                print(f" - Found {len(hw_geos)} handwriting elements in DOM")
            if ve_geos:
                print(f" - Found {len(ve_geos)} visual element placeholders in DOM")
            if not hw_geos and not ve_geos:
                print(f" - ⚠️ No handwriting or visual elements found in DOM")

            # CSS pixels are 96/inch; emit one page of exactly that size.
            page_width_inches = page_width_px / 96
            page_height_inches = page_height_px / 96

            await page.pdf(
                path=str(output_pdf_path),
                width=f"{page_width_inches}in",
                height=f"{page_height_inches}in",
                margin={
                    "top": "0",
                    "bottom": "0",
                    "left": "0",
                    "right": "0"
                },
                print_background=True,
                display_header_footer=False,
                prefer_css_page_size=False,
                scale=1.0
            )

            await browser.close()

            # Convert inches to millimetres for the caller.
            width_mm = page_width_inches * 25.4
            height_mm = page_height_inches * 25.4

            return output_pdf_path, width_mm, height_mm, geometries

    finally:
        # Always remove the temporary HTML file, even on failure.
        pathlib.Path(tmp_html_path).unlink(missing_ok=True)
|
|
|
|
def extract_bboxes_from_rendered_pdf(
    pdf_path: pathlib.Path
) -> List[dict]:
    """
    Extract word-level bounding boxes from a rendered PDF.

    Args:
        pdf_path: Path to the (single-page) PDF file

    Returns:
        List of bounding box dictionaries carrying text, position, size and
        block/line/word indices; all boxes are reported for page 0.
    """
    # NOTE: the previously imported OCRBox was unused here and has been removed.
    word_bboxes = extract_bboxes_from_pdf(
        pdf_path=pdf_path,
        level="word"
    )

    return [
        {
            "text": bbox.text,
            "x": bbox.x0,
            "y": bbox.y0,
            "width": bbox.width,
            "height": bbox.height,
            "block_no": bbox.block_no,
            "line_no": bbox.line_no,
            "word_no": bbox.word_no,
            "page": 0  # pipeline renders single-page documents only
        }
        for bbox in word_bboxes
    ]
|
|
|
|
def pdf_to_base64(pdf_path: pathlib.Path) -> str:
    """
    Read a PDF file and return its contents base64-encoded.

    Args:
        pdf_path: Path to PDF file

    Returns:
        Base64-encoded PDF
    """
    raw = pdf_path.read_bytes()
    return base64.b64encode(raw).decode('utf-8')
|
|
|
|
def validate_html_structure(html: str) -> Tuple[bool, str]:
    """
    Validate HTML structure (pipeline_06 style validation).

    NOTE(review): this definition is shadowed by a later redefinition of
    `validate_html_structure` further down in this module (which returns
    None instead of "" on success and omits the body-length check), so this
    earlier version is dead code at import time. Consider removing one.

    Args:
        html: HTML content to validate

    Returns:
        Tuple of (is_valid, error_message); error_message is "" when valid
    """
    try:
        soup = BeautifulSoup(html, BS_PARSER)

        # Require the basic document skeleton.
        if not soup.find('html'):
            return False, "Missing <html> tag"
        if not soup.find('head'):
            return False, "Missing <head> tag"
        if not soup.find('body'):
            return False, "Missing <body> tag"

        # Reject near-empty documents (< 10 visible characters).
        body = soup.find('body')
        if body and len(body.get_text(strip=True)) < 10:
            return False, "Body content too short"

        return True, ""
    except Exception as e:
        return False, f"HTML parsing error: {str(e)}"
|
|
|
|
def validate_pdf(pdf_path: pathlib.Path) -> Tuple[bool, str]:
    """
    Validate PDF file (pipeline_06 style validation).

    NOTE(review): shadowed by a later redefinition of `validate_pdf` in this
    module (which returns None on success and drops the 50MB size cap), so
    this earlier version is dead code at import time. Consider removing one.

    Args:
        pdf_path: Path to PDF file

    Returns:
        Tuple of (is_valid, error_message); error_message is "" when valid
    """
    try:
        from PyPDF2 import PdfReader

        if not pdf_path.exists():
            return False, "PDF file does not exist"

        # Size sanity checks: non-empty and at most 50MB.
        file_size = pdf_path.stat().st_size
        if file_size == 0:
            return False, "PDF file is empty"
        if file_size > 50 * 1024 * 1024:
            return False, f"PDF file too large: {file_size / (1024*1024):.1f}MB"

        # Exactly one page is expected per generated document.
        with open(pdf_path, 'rb') as f:
            reader = PdfReader(f)
            num_pages = len(reader.pages)
            if num_pages == 0:
                return False, "PDF has no pages"
            if num_pages > 1:
                return False, f"PDF has {num_pages} pages (expected 1)"

        return True, ""
    except Exception as e:
        return False, f"PDF validation error: {str(e)}"
|
|
|
|
def validate_bboxes(bboxes: List[dict], min_bbox_count: int = 0) -> Tuple[bool, str]:
    """
    Validate bounding boxes (pipeline_06 style validation).

    NOTE(review): shadowed by a later redefinition of `validate_bboxes` in
    this module (default min_bbox_count=1, returns None on success), so this
    earlier version is dead code at import time. Consider removing one.

    Args:
        bboxes: List of bounding box dictionaries
        min_bbox_count: Minimum number of bboxes required

    Returns:
        Tuple of (is_valid, error_message); error_message is "" when valid
    """
    if len(bboxes) < min_bbox_count:
        return False, f"Only {len(bboxes)} bboxes found (minimum {min_bbox_count} required)"

    for i, bbox in enumerate(bboxes):
        # Every box must carry its text plus position and size.
        required_fields = ['text', 'x', 'y', 'width', 'height']
        for field in required_fields:
            if field not in bbox:
                return False, f"BBox {i} missing required field: {field}"

        # Degenerate (zero or negative area) boxes are invalid.
        if bbox['width'] <= 0 or bbox['height'] <= 0:
            return False, f"BBox {i} has invalid dimensions: {bbox['width']}x{bbox['height']}"

    return True, ""
|
|
|
|
def validate_html_structure(html: str) -> Tuple[bool, Optional[str]]:
    """
    Validate that *html* carries the basic document skeleton.

    Checks for the presence of <html>, <head> and <body> tags.

    Args:
        html: HTML content to validate

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    try:
        soup = BeautifulSoup(html, BS_PARSER)

        # Each required top-level tag, checked in reporting order.
        for required_tag in ('html', 'head', 'body'):
            if not soup.find(required_tag):
                return False, f"Missing <{required_tag}> tag"

        return True, None

    except Exception as e:
        return False, f"HTML parsing error: {str(e)}"
|
|
|
|
def validate_pdf(pdf_path: pathlib.Path) -> Tuple[bool, Optional[str]]:
    """
    Validate a generated PDF file.

    Checks that the file exists, is non-empty, and contains exactly one page.

    Args:
        pdf_path: Path to PDF file

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    try:
        from PyPDF2 import PdfReader

        if not pdf_path.exists():
            return False, "PDF file does not exist"

        if pdf_path.stat().st_size == 0:
            return False, "PDF file is empty"

        # Exactly one page is expected per generated document.
        with open(pdf_path, 'rb') as fh:
            page_count = len(PdfReader(fh).pages)

            if page_count == 0:
                return False, "PDF has no pages"

            if page_count > 1:
                return False, f"PDF has {page_count} pages (expected 1)"

        return True, None

    except Exception as e:
        return False, f"PDF validation error: {str(e)}"
|
|
|
|
def validate_bboxes(bboxes: List[dict], min_bbox_count: int = 1) -> Tuple[bool, Optional[str]]:
    """
    Validate a list of word bounding boxes.

    Each bbox must carry text/x/y/width/height and have strictly positive
    dimensions; the list must contain at least *min_bbox_count* entries.

    Args:
        bboxes: List of bounding box dictionaries
        min_bbox_count: Minimum expected number of bboxes

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    if len(bboxes) < min_bbox_count:
        return False, f"Too few bboxes: {len(bboxes)} (expected at least {min_bbox_count})"

    required_fields = ('text', 'x', 'y', 'width', 'height')

    for idx, box in enumerate(bboxes):
        # First missing mandatory key, if any (checked in declared order).
        missing = next((name for name in required_fields if name not in box), None)
        if missing is not None:
            return False, f"BBox {idx} missing required field: {missing}"

        # Degenerate (zero or negative area) boxes are invalid.
        if box['width'] <= 0 or box['height'] <= 0:
            return False, f"BBox {idx} has invalid dimensions: width={box['width']}, height={box['height']}"

    return True, None
|
|
|
|
| |
| |
| |
|
|
async def call_handwriting_service_batch(
    texts_with_metadata: List[Dict]
) -> List[Dict]:
    """
    Call RunPod handwriting service with TRUE batch processing for cost efficiency.
    Sends all texts in ONE request to activate only ONE worker, significantly reducing costs.

    Cost comparison for 10 texts:
    - OLD (parallel): 10 workers × 18s = 180 worker-seconds
    - NEW (batched): 1 worker × 190s = 190 worker-seconds BUT only 1 worker activation fee

    For RunPod pricing with activation overhead, batching is ~40-60% cheaper.

    Args:
        texts_with_metadata: List of dicts with keys: text, author_id, hw_id

    Returns:
        List of dicts with keys: hw_id, image_base64, text, author_id, width, height.
        Empty list when the batch ultimately fails (best-effort semantics).
    """
    if not texts_with_metadata:
        return []

    max_retries = settings.HANDWRITING_SERVICE_MAX_RETRIES
    timeout = settings.HANDWRITING_SERVICE_TIMEOUT

    # Scale the request timeout with batch size: ~20s per text plus slack.
    num_texts = len(texts_with_metadata)
    batch_timeout = max(timeout, num_texts * 20 + 30)

    headers = {"Content-Type": "application/json"}
    if settings.RUNPOD_API_KEY:
        headers["Authorization"] = f"Bearer {settings.RUNPOD_API_KEY}"

    print(f" Processing {num_texts} texts in ONE batch (1 worker activation)...")

    for attempt in range(max_retries):
        try:
            async with httpx.AsyncClient(timeout=batch_timeout) as client:
                runpod_request = {
                    "input": {
                        "texts": [
                            {
                                "text": item["text"],
                                "author_id": item["author_id"],
                                "hw_id": item.get("hw_id", f"hw_{i}")
                            }
                            for i, item in enumerate(texts_with_metadata)
                        ],
                        "apply_blur": settings.HANDWRITING_APPLY_BLUR
                    }
                }

                response = await client.post(
                    settings.HANDWRITING_SERVICE_URL,
                    json=runpod_request,
                    headers=headers
                )
                response.raise_for_status()

                result = response.json()

                job_status = result.get("status")

                # BUGFIX: /runsync can also hand back IN_QUEUE when no worker
                # has picked the job up yet; treat both in-flight states the
                # same way (the polling loop below already does).
                if job_status in ("IN_PROGRESS", "IN_QUEUE"):
                    job_id = result.get("id")
                    if not job_id:
                        raise Exception("RunPod job IN_PROGRESS but no job ID provided")

                    print(f" ⏳ Job {job_id} still processing, polling status...")

                    # Derive the status endpoint from the runsync URL.
                    base_url = settings.HANDWRITING_SERVICE_URL.replace("/runsync", "")
                    status_url = f"{base_url}/status/{job_id}"

                    max_polls = 30
                    poll_delay = 5

                    for poll_attempt in range(max_polls):
                        await asyncio.sleep(poll_delay)

                        status_response = await client.get(status_url, headers=headers)
                        status_response.raise_for_status()
                        result = status_response.json()

                        job_status = result.get("status")
                        print(f" ⏳ Poll {poll_attempt + 1}/{max_polls}: {job_status}")

                        if job_status == "COMPLETED":
                            print(f" ✅ Job completed after {poll_attempt + 1} polls")
                            break
                        elif job_status == "FAILED":
                            raise Exception(f"RunPod job failed: {result.get('error', 'Unknown error')}")
                        elif job_status not in ["IN_PROGRESS", "IN_QUEUE"]:
                            raise Exception(f"Unknown job status: {job_status}")

                        # Gentle backoff between polls, capped at 10s.
                        poll_delay = min(poll_delay + 1, 10)
                    else:
                        raise Exception(f"Job did not complete after {max_polls} status checks")

                if job_status != "COMPLETED":
                    raise Exception(f"RunPod job not completed: {job_status}")

                output = result.get("output", {})
                if "error" in output:
                    raise Exception(f"RunPod error: {output['error']}")

                images = output.get("images", [])
                if not images:
                    raise Exception("No images in batch response")

                all_results = [
                    {
                        "hw_id": img.get("hw_id"),
                        "text": img.get("text"),
                        "author_id": img.get("author_id"),
                        "image_base64": img.get("image_base64"),
                        "width": img.get("width"),
                        "height": img.get("height")
                    }
                    for img in images
                ]

                print(f" → Batch complete: {len(all_results)}/{num_texts} texts generated successfully")
                return all_results

        except httpx.TimeoutException as e:
            if attempt < max_retries - 1:
                wait_time = 10 * (attempt + 1)
                print(f" ⚠️ Timeout on attempt {attempt + 1}/{max_retries}, retrying in {wait_time}s...")
                await asyncio.sleep(wait_time)
                continue
            else:
                print(f" ❌ Batch failed after {max_retries} retries: {e}")
                return []

        except Exception as e:
            if attempt < max_retries - 1:
                wait_time = 5 * (attempt + 1)
                print(f" ⚠️ Error on attempt {attempt + 1}/{max_retries}: {e}, retrying in {wait_time}s...")
                await asyncio.sleep(wait_time)
                continue
            else:
                print(f" ❌ Batch failed: {e}")
                return []

    return []
|
|
|
|
async def generate_visual_element_images(
    visual_elements: list[dict],
    seed: Optional[int] = None,
    assets_dir: Optional[pathlib.Path] = None
) -> dict:
    """
    Generate visual element images (stamps, logos, barcodes, photos, figures).

    Selection order for each element:
      1. a user asset in *assets_dir* named "<type>_*",
      2. a prefab image from ENV.VISUAL_ELEMENT_PREFABS_DIR (logo/photo/figure),
      3. a synthesized fallback (stamps via create_stamp, barcodes via
         python-barcode); types with nothing available are skipped.

    Args:
        visual_elements: List of visual element definitions with type, content, rect
        seed: Random seed for reproducible selection (default: None)
        assets_dir: Optional directory containing pre-rendered asset images

    Returns:
        Dict {ve_id: base64_png} of generated images
    """
    import random
    import io

    # Use a dedicated RNG instance so a fixed seed does not mutate the
    # module-global random state shared with the rest of the pipeline.
    rng = random.Random(seed)

    visual_element_images = {}

    # Prefab image lists, loaded lazily and cached per element type.
    _prefab_cache: Dict[str, list] = {}

    def get_prefabs(kind: str) -> list:
        # Load PNG/JPG prefabs for *kind* once and reuse thereafter.
        if kind not in _prefab_cache:
            prefab_dir = ENV.VISUAL_ELEMENT_PREFABS_DIR / kind
            _prefab_cache[kind] = list(prefab_dir.glob("*.png")) + list(prefab_dir.glob("*.jpg"))
        return _prefab_cache[kind]

    def open_asset(prefix: str):
        # Pick a random "<prefix>_*" file from assets_dir, or None if absent.
        if assets_dir:
            candidates = list(assets_dir.glob(f"{prefix}_*"))
            if candidates:
                return Image.open(rng.choice(candidates)).convert("RGBA")
        return None

    def open_prefab(kind: str):
        # Pick a random prefab of *kind*, or None if none exist.
        prefabs = get_prefabs(kind)
        if prefabs:
            return Image.open(rng.choice(prefabs)).convert("RGBA")
        return None

    for ve in visual_elements:
        ve_id = ve.get('id', 'unknown')
        ve_type = ve.get('type', 'unknown')
        content = ve.get('content', '')
        rect = ve.get('rect', {})
        width = rect.get('width', 100)
        height = rect.get('height', 100)

        try:
            img = None

            if ve_type == 'stamp':
                img = open_asset("stamp")
                if not img:
                    # Synthesize a stamp when no asset is available.
                    img = create_stamp(
                        text=content if content else "STAMP",
                        width=width,
                        height=height,
                        rot_angle=None
                    )

            elif ve_type == 'logo':
                img = open_asset("logo")
                if not img:
                    img = open_prefab("logo")

            elif ve_type == 'barcode':
                try:
                    from barcode import Code128
                    from barcode.writer import ImageWriter

                    # Use the provided digits, or a random 12-digit number.
                    barcode_content = content.strip() if content and content.strip().isdigit() else str(rng.randint(100000000000, 999999999999))

                    writer = ImageWriter()
                    writer.set_options({
                        "module_width": 0.3,
                        "module_height": 15.0,
                        "quiet_zone": 6.5,
                        "font_size": 7,
                        "text_distance": 5,
                        "background": "rgba(255, 255, 255, 0)",
                        "foreground": "black",
                    })

                    code128 = Code128(barcode_content, writer=writer)
                    buffer = io.BytesIO()
                    code128.write(buffer, options={"format": "PNG"})
                    buffer.seek(0)
                    img = Image.open(buffer).convert("RGBA")

                except ImportError:
                    print(f" ⚠ 'python-barcode' not installed, skipping barcode {ve_id}")
                except Exception as e:
                    print(f" ⚠ Barcode generation failed for {ve_id}: {e}")

            elif ve_type == 'photo':
                img = open_asset("photo")
                if not img:
                    img = open_prefab("photo")

            elif ve_type in ['figure', 'chart', 'diagram']:
                # All figure-like types share the "figure" prefab pool.
                img = open_asset("figure")
                if not img:
                    img = open_prefab("figure")

            # Encode whatever was produced as base64 PNG.
            if img:
                buffer = io.BytesIO()
                img.save(buffer, format="PNG")
                buffer.seek(0)
                img_b64 = base64.b64encode(buffer.read()).decode('utf-8')
                visual_element_images[ve_id] = img_b64

        except Exception as e:
            print(f" ⚠ Failed to generate visual element {ve_id} (type: {ve_type}): {e}")
            continue

    return visual_element_images
|
|
|
|
async def process_stage3_complete(
    pdf_path: pathlib.Path,
    geometries: list[dict],
    ground_truth: dict,
    bboxes_raw: list[dict],
    page_width_mm: float,
    page_height_mm: float,
    enable_handwriting: bool = False,
    handwriting_ratio: float = 0.5,
    enable_visual_elements: bool = False,
    visual_element_types: Optional[list[str]] = None,
    seed: Optional[int] = None,
    assets_dir: Optional[pathlib.Path] = None
) -> tuple[str, list[dict], list[dict], dict, dict, pathlib.Path | None, pathlib.Path | None]:
    """
    Process complete Stage 3 pipeline (stages 07-11) using browser-extracted geometries.
    - Extract handwriting definitions from geometries (from DOM, not HTML parsing)
    - Extract visual element definitions from geometries
    - Generate handwriting images (via EC2 service if enabled)
    - Create visual element images
    - Render second-pass PDF with handwriting and visual elements
    - Convert final PDF to base64 image

    Args:
        pdf_path: First-pass PDF; modified copies are written next to it.
        geometries: List of element geometries extracted from browser DOM.
        ground_truth: Ground-truth dict (not read inside this function; kept
            for interface compatibility with callers).
        bboxes_raw: Raw word bounding boxes in PDF point coordinates, used to
            match handwriting geometries to words.
        page_width_mm: Page width in mm (not read inside this function).
        page_height_mm: Page height in mm (not read inside this function).
        enable_handwriting: Whether to extract/synthesize handwriting.
        handwriting_ratio: Per-region probability of turning a candidate
            region into handwriting.
        enable_visual_elements: Whether to extract/synthesize visual elements.
        visual_element_types: Whitelist of normalized element types; None
            means accept all types.
        seed: Optional RNG seed for reproducible per-region selection.
        assets_dir: Optional directory of pre-generated assets.

    Returns:
        tuple: (final_image_base64, handwriting_regions, visual_elements, handwriting_images, visual_element_images, pdf_with_handwriting_path, pdf_final_path)
            - final_image_base64: Base64 PNG of final document
            - handwriting_regions: List of handwriting metadata dicts
            - visual_elements: List of visual element metadata dicts
            - handwriting_images: Dict {hw_id: base64_png} for individual tokens
            - visual_element_images: Dict {ve_id: base64_png} for individual elements
            - pdf_with_handwriting_path: Path to PDF after handwriting insertion (or None)
            - pdf_final_path: Path to final PDF after all modifications (or None)
    """
    import random
    import base64
    import fitz

    handwriting_regions = []
    visual_elements = []

    print(f" 🔍 Processing {len(geometries)} geometries from DOM")

    # ---- Handwriting region extraction (match DOM geometries to word bboxes) ----
    if enable_handwriting:

        from docgenie.generation.models import OCRBox
        from docgenie.generation.constants import BBOX_TO_GEO_MATCHING_THRESHOLD
        from docgenie.generation.utils.bboxes import is_in_rect

        # Convert raw bbox dicts (x/y/width/height) into OCRBox corner form.
        word_bboxes = []
        for bbox_dict in bboxes_raw:
            word_bboxes.append(OCRBox(
                x0=bbox_dict['x'],
                y0=bbox_dict['y'],
                x2=bbox_dict['x'] + bbox_dict['width'],
                y2=bbox_dict['y'] + bbox_dict['height'],
                text=bbox_dict['text'],
                block_no=bbox_dict.get('block_no', 0),
                line_no=bbox_dict.get('line_no', 0),
                word_no=bbox_dict.get('word_no', 0)
            ))

        # Only geometries tagged as handwriting candidates in the DOM.
        hw_geometries = [g for g in geometries if "handwriting" in g.get("selectorTypes", [])]

        print(f" - Found {len(hw_geometries)} handwriting geometries")

        # Word-bbox index ranges already claimed by an earlier region, so two
        # regions with identical text never grab the same words.
        taken_bbox_indices = set()

        for i, geo in enumerate(hw_geometries):
            classes_str = geo.get('classes', '')
            classes = classes_str.split() if classes_str else []

            # The author id is encoded as an extra CSS class like "author3".
            other_classes = [c for c in classes if c != 'handwritten']
            valid_author_ids = [c for c in other_classes if c.startswith("author")]
            author_id = valid_author_ids[0] if valid_author_ids else None

            # Per-region deterministic seeding so region i's coin flip is
            # reproducible independently of how many regions precede it.
            if seed is not None:
                random.seed(seed + i)
            if random.random() > handwriting_ratio:
                continue

            text_content = geo.get('text', '').strip()
            if not text_content:
                continue

            is_signature = 'signature' in classes

            # Browser rect is in CSS pixels (96 dpi); PDF space is points
            # (72 dpi), hence the 72/96 scale.
            rect_browser = geo.get('rect', {})
            dpi_scale = 72.0 / 96.0
            rect = {
                'x': rect_browser.get('x', 0) * dpi_scale,
                'y': rect_browser.get('y', 0) * dpi_scale,
                'width': rect_browser.get('width', 0) * dpi_scale,
                'height': rect_browser.get('height', 0) * dpi_scale
            }

            # Find a contiguous run of word bboxes whose texts equal the
            # region's words AND whose first/last boxes fall inside the rect.
            words = text_content.split()
            n = len(words)
            matched_bboxes = []

            for j in range(len(word_bboxes) - n + 1):
                slice_texts = [b.text for b in word_bboxes[j : j + n]]
                if slice_texts == words:
                    start, stop = j, j + n
                    if (start, stop) not in taken_bbox_indices:

                        start_in_rect = is_in_rect(
                            rect=rect,
                            bbox=word_bboxes[start],
                            threshold=BBOX_TO_GEO_MATCHING_THRESHOLD
                        )
                        stop_in_rect = is_in_rect(
                            rect=rect,
                            bbox=word_bboxes[stop - 1],
                            threshold=BBOX_TO_GEO_MATCHING_THRESHOLD
                        )
                        if start_in_rect and stop_in_rect:
                            matched_bboxes = word_bboxes[start:stop]
                            taken_bbox_indices.add((start, stop))
                            break

            if not matched_bboxes:
                print(f" - ⚠️ No bbox match for hw{i}: '{text_content[:30]}'")
                continue

            handwriting_regions.append({
                'id': f'hw{i}',
                'text': text_content,
                'author_id': author_id,
                'is_signature': is_signature,
                'rect': rect,
                'bboxes': [b.as_string() for b in matched_bboxes],
                'classes': classes_str
            })

        print(f" - Selected {len(handwriting_regions)} handwriting regions (ratio: {handwriting_ratio})")

    # ---- Visual element extraction (stamps, barcodes, logos, ...) ----
    if enable_visual_elements:

        ve_geometries = [g for g in geometries if "visual_element" in g.get("selectorTypes", [])]

        print(f" - Found {len(ve_geometries)} visual element geometries")

        for i, geo in enumerate(ve_geometries):
            data_type = geo.get('dataPlaceholder', '')
            data_content = geo.get('dataContent', '')

            # Map synonyms (e.g. "seal" -> "stamp") onto canonical type names.
            normalized_type = VISUAL_ELEMENT_TYPE_SYNONYMS.get(data_type, data_type)

            # Drop elements the caller did not request.
            if visual_element_types and normalized_type not in visual_element_types:
                print(f" ⚠️ Filtered out visual element type '{data_type}' (normalized to '{normalized_type}', not in requested types: {visual_element_types})")
                continue

            # Browser rect is in CSS pixels; convert to mm (25.4 mm / 96 px).
            rect_px = geo.get('rect', {})
            px_to_mm = 25.4 / 96
            rect = {
                'x': rect_px.get('x', 0) * px_to_mm,
                'y': rect_px.get('y', 0) * px_to_mm,
                'width': rect_px.get('width', 0) * px_to_mm,
                'height': rect_px.get('height', 0) * px_to_mm
            }

            # Optional CSS rotate(...) on the element's inline style.
            rotation = 0
            style = geo.get('style', '')
            if style and 'rotate' in style:
                rotation = extract_rotation_from_style(style)

            visual_elements.append({
                'id': f've{i}',
                'type': normalized_type,
                'content': data_content,
                'rect': rect,
                'rotation': rotation
            })

        print(f" - Selected {len(visual_elements)} visual elements")

    # ---- Handwriting image generation (remote service) ----
    handwriting_images = {}

    print(f"\n 🔍 DEBUG - Handwriting Service Check:")
    print(f" - enable_handwriting: {enable_handwriting}")
    print(f" - handwriting_regions count: {len(handwriting_regions)}")
    print(f" - HANDWRITING_SERVICE_ENABLED: {settings.HANDWRITING_SERVICE_ENABLED}")
    print(f" - HANDWRITING_SERVICE_URL: {settings.HANDWRITING_SERVICE_URL}")

    if enable_handwriting and handwriting_regions and settings.HANDWRITING_SERVICE_ENABLED:
        print(f" ✅ Handwriting service check PASSED - preparing batch request...")

        from docgenie.generation.constants import WRITER_STYLES

        def map_author_to_style_id(author_id_str: str, seed_val: Optional[int] = None) -> int:
            """
            Map author ID string (like 'author1') to numeric style ID (0-656).
            Matches original pipeline's style selection logic.
            """
            # Missing/malformed author ids fall back to a random style.
            if not author_id_str or not author_id_str.startswith('author'):

                return random.choice(WRITER_STYLES)

            try:

                author_num = int(author_id_str.replace('author', ''))
                # Same author number always maps to the same writer style.
                style_idx = author_num % len(WRITER_STYLES)
                return WRITER_STYLES[style_idx]
            except ValueError:

                return random.choice(WRITER_STYLES)

        # Build one request item per word, keyed so the result can later be
        # placed back onto its exact word bbox.
        texts_to_generate = []
        for i, hw_region in enumerate(handwriting_regions):
            author_id_str = hw_region.get('author_id')
            text = hw_region.get('text', '')
            print(f" - Region {i+1}: author_id='{author_id_str}', text='{text[:30]}...'")

            if author_id_str is not None:

                style_id = map_author_to_style_id(author_id_str, seed)
                print(f" → Mapped to style_id={style_id}")

                bboxes_str = hw_region.get('bboxes', [])
                if not bboxes_str:
                    print(f" → ⚠️ Skipped (no bboxes)")
                    continue

                from collections import defaultdict
                from docgenie.generation.utils.bboxes import read_syn_dataset_bbox_str

                # Group words by (block, line) so hw_id encodes position.
                grouped_bboxes = defaultdict(list)
                for bbox_str in bboxes_str:
                    bbox = read_syn_dataset_bbox_str(bbox_str)
                    grouped_bboxes[(bbox.block_no, bbox.line_no)].append(bbox)

                for (block_no, line_no), bbox_group in grouped_bboxes.items():

                    for word_idx, bbox in enumerate(bbox_group):
                        word_text = bbox.text

                        # The handwriting model only supports alphabetic
                        # characters; strip everything else.
                        filtered_text = ''.join(c for c in word_text if c.isalpha())

                        if not filtered_text:
                            continue

                        texts_to_generate.append({
                            'text': filtered_text,
                            'author_id': style_id,
                            'hw_id': f"{hw_region['id']}_b{block_no}_l{line_no}_w{word_idx}"
                        })

                print(f" → {len(grouped_bboxes)} block/line groups")
            else:
                print(f" → ⚠️ Skipped (no author_id)")

        print(f" - Prepared {len(texts_to_generate)} texts for generation")

        if texts_to_generate:
            try:
                print(f" - Calling RunPod handwriting service at {settings.HANDWRITING_SERVICE_URL}...")

                results = await call_handwriting_service_batch(texts_to_generate)

                print(f" - ✅ Received {len(results)} handwriting images")

                for result in results:
                    handwriting_images[result['hw_id']] = result['image_base64']

            except Exception as e:
                print(f" - ❌ Handwriting service call failed: {e}")
                import traceback
                traceback.print_exc()

                # Handwriting was explicitly requested, so failure is fatal
                # rather than silently producing a document without it.
                raise Exception(f"Handwriting generation failed: {e}") from e
        else:
            print(f" - ⚠️ No texts to generate (all regions missing author_id)")
    else:
        print(f" ❌ Handwriting service check FAILED - skipping generation")

    # ---- Visual element image generation (local prefabs/synthesis) ----
    visual_element_images = {}
    if enable_visual_elements and visual_elements:
        try:
            visual_element_images = await generate_visual_element_images(
                visual_elements,
                seed=seed,
                assets_dir=assets_dir
            )
            print(f" ✓ Generated {len(visual_element_images)} visual element images")
        except Exception as e:
            # Best-effort: visual element failure degrades output but does
            # not abort the pipeline (unlike handwriting above).
            print(f" ⚠ Visual element generation failed: {e}")

    # ---- Second-pass PDF editing: insert handwriting, then visual elements ----
    doc = fitz.open(pdf_path)
    page = doc[0]
    pdf_with_handwriting_path = None
    pdf_final_path = None

    if handwriting_images:
        print(f" 🖊️ Inserting {len(handwriting_images)} handwriting images into PDF...")

        from docgenie.generation.constants import (
            FIXED_HANDWRITING_X_OFFSET,
            MAX_HANDWRITING_RAND_X_OFFSET_LEFT,
            MAX_HANDWRITING_RAND_X_OFFSET_RIGHT,
            MAX_HANDWRITING_RAND_Y_OFFSET_UP,
            MAX_HANDWRITING_RAND_Y_OFFSET_DOWN,
            PIPELINE_04_3_SCALE_UP_FACTOR
        )

        scale_up = PIPELINE_04_3_SCALE_UP_FACTOR

        from docgenie.generation.utils.bboxes import read_syn_dataset_bbox_str

        # First pass: paint white rectangles over the printed words that the
        # handwriting will replace.
        print(f" - Whitening out original text regions...")
        for hw_region in handwriting_regions:
            bboxes_str = hw_region.get('bboxes', [])
            if not bboxes_str:
                continue

            for bbox_str in bboxes_str:
                bbox = read_syn_dataset_bbox_str(bbox_str)

                rect = fitz.Rect(bbox.x0, bbox.y0, bbox.x2, bbox.y2)
                page.draw_rect(rect, color=(1, 1, 1), fill=(1, 1, 1))

        print(f" - Inserting handwriting images...")

        # Second pass: place each generated word image over its word bbox,
        # with small random jitter for a natural look.
        for hw_region in handwriting_regions:
            hw_id = hw_region['id']
            rect = hw_region['rect']
            bboxes_str = hw_region.get('bboxes', [])

            if not bboxes_str:
                continue

            from collections import defaultdict
            grouped_bboxes = defaultdict(list)
            for bbox_str in bboxes_str:
                bbox = read_syn_dataset_bbox_str(bbox_str)
                grouped_bboxes[(bbox.block_no, bbox.line_no)].append(bbox)

            for (block_no, line_no), bbox_group in grouped_bboxes.items():
                for word_idx, bbox in enumerate(bbox_group):
                    # Must mirror the hw_id scheme used when requesting images.
                    img_id = f"{hw_id}_b{block_no}_l{line_no}_w{word_idx}"

                    if img_id not in handwriting_images:
                        continue

                    try:

                        img_data = base64.b64decode(handwriting_images[img_id])
                        img = Image.open(BytesIO(img_data))

                        bbox_w = bbox.x2 - bbox.x0
                        bbox_h = bbox.y2 - bbox.y0

                        # Fit the image into the word bbox (preserving aspect
                        # ratio), then enlarge by the pipeline scale factor.
                        iw, ih = img.size
                        scale = min(bbox_w / iw, bbox_h / ih)
                        new_w = int(iw * scale * scale_up)
                        new_h = int(ih * scale * scale_up)

                        img_resized = img.resize((new_w, new_h), Image.Resampling.LANCZOS).convert("RGBA")

                        img_bytes_io = BytesIO()
                        img_resized.save(img_bytes_io, format="PNG")
                        img_bytes = img_bytes_io.getvalue()

                        # Random jitter so handwriting does not sit perfectly
                        # on the printed baseline.
                        y_padding = 50
                        offset_x = random.randint(
                            -MAX_HANDWRITING_RAND_X_OFFSET_LEFT,
                            MAX_HANDWRITING_RAND_X_OFFSET_RIGHT
                        ) + FIXED_HANDWRITING_X_OFFSET
                        offset_y = random.randint(
                            -MAX_HANDWRITING_RAND_Y_OFFSET_UP,
                            MAX_HANDWRITING_RAND_Y_OFFSET_DOWN
                        )

                        # NOTE(review): offset_x/offset_y are already included
                        # in x0_pos/y0_pos, and are added AGAIN to x2_pos/y2_pos
                        # below, which stretches the target rect by the offset
                        # (and by 2*y_padding vertically). Confirm this widening
                        # is intended rather than a double-application bug.
                        x0_pos = bbox.x0 + offset_x
                        y0_pos = bbox.y0 + offset_y - y_padding
                        x2_pos = min(x0_pos + img_resized.width / scale_up, bbox.x2) + offset_x
                        y2_pos = min(y0_pos + img_resized.height / scale_up, bbox.y2) + offset_y + 2 * y_padding

                        rect_fitz = fitz.Rect(x0_pos, y0_pos, x2_pos, y2_pos)
                        page.insert_image(rect_fitz, stream=img_bytes)

                        print(f" - ✓ Inserted {img_id} at ({x0_pos:.1f}, {y0_pos:.1f})")

                    except Exception as e:
                        print(f" - ⚠️ Failed to insert {img_id}: {e}")
                        import traceback
                        traceback.print_exc()

        print(f" ✓ Handwriting insertion complete")

        # Persist the handwriting pass and reopen, so visual elements are
        # applied to the saved document rather than the in-memory one.
        pdf_with_handwriting_path = pdf_path.parent / f"{pdf_path.stem}_with_handwriting.pdf"
        doc.save(pdf_with_handwriting_path)
        print(f" - Saved PDF with handwriting: {pdf_with_handwriting_path.name}")
        doc.close()

        doc = fitz.open(pdf_with_handwriting_path)
        page = doc[0]

    if visual_element_images and visual_elements:
        print(f" 🎨 Inserting {len(visual_element_images)} visual elements into PDF...")

        from docgenie.generation.constants import PIPELINE_04_3_SCALE_UP_FACTOR
        scale_up = PIPELINE_04_3_SCALE_UP_FACTOR

        for ve in visual_elements:
            ve_id = ve['id']

            if ve_id not in visual_element_images:
                print(f" - ⚠️ Skipping {ve_id}: image not generated")
                continue

            try:

                img_data = base64.b64decode(visual_element_images[ve_id])
                img = Image.open(BytesIO(img_data))

                rect = ve['rect']
                bbox_width = rect['width']
                bbox_height = rect['height']

                # Element rects were stored in mm; PDF space is points.
                mm_to_pt = 72 / 25.4
                bbox_w_pt = bbox_width * mm_to_pt
                bbox_h_pt = bbox_height * mm_to_pt
                x0_pt = rect['x'] * mm_to_pt
                y0_pt = rect['y'] * mm_to_pt

                # Fit image into the box, preserving aspect ratio.
                iw, ih = img.size
                scale = min(bbox_w_pt / iw, bbox_h_pt / ih)
                new_w = int(iw * scale * scale_up)
                new_h = int(ih * scale * scale_up)

                img_resized = img.resize((new_w, new_h), Image.Resampling.LANCZOS).convert("RGBA")

                # Compose onto a transparent canvas the size of the full box
                # so the element ends up centered within it.
                final_img = Image.new(
                    "RGBA",
                    (int(bbox_w_pt * scale_up), int(bbox_h_pt * scale_up)),
                    (255, 255, 255, 0)
                )

                offset_x = (int(bbox_w_pt * scale_up) - new_w) // 2
                offset_y = (int(bbox_h_pt * scale_up) - new_h) // 2
                final_img.paste(img_resized, (offset_x, offset_y), mask=img_resized)

                img_bytes_io = BytesIO()
                final_img.save(img_bytes_io, format="PNG")
                img_bytes = img_bytes_io.getvalue()

                rect_fitz = fitz.Rect(x0_pt, y0_pt, x0_pt + bbox_w_pt, y0_pt + bbox_h_pt)
                page.insert_image(rect_fitz, stream=img_bytes)

                print(f" - ✓ Inserted {ve_id} ({ve['type']}) at ({x0_pt:.1f}, {y0_pt:.1f})")

            except Exception as e:
                print(f" - ⚠️ Failed to insert {ve_id}: {e}")
                import traceback
                traceback.print_exc()

        print(f" ✓ Visual element insertion complete")

    # ---- Save final PDF ----
    # NOTE: when neither handwriting nor visual elements were inserted, the
    # else branch still saves an (unchanged) copy named
    # "*_with_visual_elements.pdf" and uses it as pdf_final_path.
    if pdf_with_handwriting_path:

        pdf_final_path = pdf_path.parent / f"{pdf_path.stem}_final.pdf"
        doc.save(pdf_final_path)
        print(f" - Saved final PDF (with handwriting + visual elements): {pdf_final_path.name}")
    else:

        pdf_with_ve_only = pdf_path.parent / f"{pdf_path.stem}_with_visual_elements.pdf"
        doc.save(pdf_with_ve_only)
        print(f" - Saved PDF with visual elements: {pdf_with_ve_only.name}")
        pdf_final_path = pdf_with_ve_only

    doc.close()

    # ---- Rasterize the final PDF (3x zoom ≈ 216 dpi) to base64 PNG ----
    doc = fitz.open(pdf_final_path)
    page = doc[0]

    pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
    img_bytes = pix.tobytes("png")

    final_image_b64 = base64.b64encode(img_bytes).decode('utf-8')

    doc.close()

    return final_image_b64, handwriting_regions, visual_elements, handwriting_images, visual_element_images, pdf_with_handwriting_path, pdf_final_path
|
|
|
|
def extract_rect_from_style(style: str, page_width_mm: float, page_height_mm: float) -> dict:
    """Extract position and dimensions (in mm) from an inline CSS style string.

    Parses ``left``/``top``/``width``/``height`` declarations (plus ``x``/``y``
    aliases), converting cm and px values to millimetres. Properties that are
    absent or whose value is not a well-formed number are ignored.

    Args:
        style: Inline CSS style string, e.g. ``"left: 10mm; top: 2cm; width: 96px"``.
        page_width_mm: Page width in mm (unused; kept for caller compatibility).
        page_height_mm: Page height in mm (unused; kept for caller compatibility).

    Returns:
        dict with keys ``'x'``, ``'y'``, ``'width'``, ``'height'`` in mm
        (0 for properties not found in the style).
    """
    import re

    rect = {'x': 0, 'y': 0, 'width': 0, 'height': 0}

    # CSS property name -> rect key; anything else is skipped.
    key_map = {
        'left': 'x', 'x': 'x',
        'top': 'y', 'y': 'y',
        'width': 'width',
        'height': 'height',
    }

    for prop in style.split(';'):
        if ':' not in prop:
            continue
        key, value = prop.split(':', 1)
        key = key.strip().lower()
        value = value.strip()

        if key not in key_map:
            continue

        # BUG FIX: the previous pattern `[-\d.]+` also matched malformed
        # tokens such as "-", "." or "1.2.3", which then crashed in float().
        # Require a proper decimal number with an optional sign.
        match = re.match(r'([-+]?(?:\d+(?:\.\d*)?|\.\d+))(mm|cm|px)?', value)
        if not match:
            continue

        num_val = float(match.group(1))
        unit = match.group(2) or 'mm'

        # Normalize to millimetres (px assumes CSS 96 dpi: 25.4/96 mm per px).
        if unit == 'cm':
            num_val *= 10
        elif unit == 'px':
            num_val *= 0.2645833333

        rect[key_map[key]] = num_val

    return rect
|
|
|
|
def extract_rotation_from_style(style: str) -> float:
    """Return the 2D rotation angle, in degrees, from a CSS ``transform``
    property (``rotate(...deg)``), or 0.0 when no rotation is declared."""
    import re

    found = re.search(r'rotate\(\s*([-+]?\d*\.?\d+)\s*deg\s*\)', style)
    return float(found.group(1)) if found else 0.0
|
|
| |
|
|
def run_local_tesseract_ocr(image: Image.Image) -> dict:
    """
    Run Tesseract OCR locally on an image.

    Args:
        image: PIL Image to OCR

    Returns:
        dict: OCR results in Microsoft OCR format
              (keys: 'angle', 'imageWidth', 'imageHeight', 'words')

    Raises:
        RuntimeError: when pytesseract is not installed.
    """
    try:
        import pytesseract

        data = pytesseract.image_to_data(
            image,
            lang=settings.OCR_TESSERACT_LANG,
            config=settings.OCR_TESSERACT_CONFIG,
            output_type=pytesseract.Output.DICT
        )

        # Tesseract returns parallel column arrays; zip them together and
        # keep only rows whose text is non-empty after stripping.
        words = [
            {
                'text': token,
                # Tesseract reports -1 confidence for non-word rows.
                'confidence': float(conf) / 100.0 if conf != -1 else 0.0,
                'geo': [int(left), int(top), int(width), int(height)],
            }
            for raw, conf, left, top, width, height in zip(
                data['text'], data['conf'], data['left'],
                data['top'], data['width'], data['height']
            )
            if (token := raw.strip())
        ]

        return {
            'angle': 0,
            'imageWidth': image.width,
            'imageHeight': image.height,
            'words': words
        }

    except ImportError:
        raise RuntimeError(
            "pytesseract not installed. Install with: uv pip install pytesseract\n"
            "Also ensure Tesseract OCR is installed on your system:\n"
            " Ubuntu/Debian: sudo apt-get install tesseract-ocr\n"
            " macOS: brew install tesseract\n"
            " Windows: Download from https://github.com/UB-Mannheim/tesseract/wiki"
        )
    except Exception as e:
        print(f"Error running local Tesseract OCR: {e}")
        raise
|
|
|
|
async def call_ocr_service(
    image: Image.Image,
    ocr_url: Optional[str] = None,
    engine: str = "microsoft_di",
    timeout: int = 30,
    use_local: Optional[bool] = None
) -> dict:
    """
    Call OCR service on image (Stage 15: Perform OCR).

    Supports both local Tesseract OCR and remote OCR services.

    Args:
        image: PIL Image to OCR
        ocr_url: OCR service URL (defaults to settings.OCR_SERVICE_URL)
        engine: OCR engine to use
        timeout: Request timeout in seconds
        use_local: Force local/remote mode (None = use settings.OCR_USE_LOCAL)

    Returns:
        dict: OCR results in Microsoft OCR format (a single page dict)

    Raises:
        ValueError: if the remote service response lacks the expected
            ``ocr.pages`` structure.
    """
    # FIX: annotations previously claimed `ocr_url: str` and `use_local: bool`
    # while defaulting to None; corrected to Optional[...] to match behavior.
    if use_local is None:
        use_local = settings.OCR_USE_LOCAL

    if use_local:
        print(" Using local Tesseract OCR...")
        return run_local_tesseract_ocr(image)

    if ocr_url is None:
        ocr_url = settings.OCR_SERVICE_URL

    try:
        # Serialize the image to PNG bytes for upload.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        buffer.seek(0)
        image_bytes = buffer.getvalue()

        endpoint = f"{ocr_url}/v1/sync/ocr/{engine}"

        async with httpx.AsyncClient(timeout=timeout) as client:
            # NOTE(review): 'type' is sent as a separate multipart field
            # rather than as the content-type of the 'image' part — confirm
            # this matches the OCR service's expected request shape.
            files = {'image': image_bytes, 'type': 'image/png'}
            headers = {'accept': 'application/json'}

            response = await client.post(endpoint, headers=headers, files=files)
            response.raise_for_status()

            data = response.json()

        # The service wraps page results under data['ocr']['pages'];
        # return only the first page.
        if 'ocr' in data and 'pages' in data['ocr'] and len(data['ocr']['pages']) > 0:
            return data['ocr']['pages'][0]
        else:
            raise ValueError("Invalid OCR response format")

    except Exception as e:
        print(f"Error calling OCR service: {e}")
        raise
|
|
|
|
async def render_pdf_to_image(
    pdf_path: pathlib.Path,
    dpi: int = 300
) -> tuple[Image.Image, str]:
    """
    Convert a PDF into a high-quality page image (Stage 14: Render Image).

    Uses pdf2image (poppler) for the conversion; only the first page is used
    for multi-page documents.

    Args:
        pdf_path: Path to the PDF file.
        dpi: Render resolution (default 300, matching the pipeline constant).

    Returns:
        tuple: (PIL Image of page 1, base64-encoded PNG string)

    Raises:
        ValueError: if the conversion produced no pages.
    """
    try:
        pages = convert_from_path(pdf_path, dpi=dpi)

        if not pages:
            raise ValueError("PDF conversion resulted in no images")

        if len(pages) > 1:
            print(f"Warning: PDF has {len(pages)} pages, using first page only")

        first_page = pages[0]

        # Encode the page as PNG -> base64 for transport.
        png_buffer = BytesIO()
        first_page.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

        return first_page, encoded

    except Exception as e:
        print(f"Error converting PDF to image: {e}")
        raise
|
|
|
|
def convert_ocr_to_api_format(ocr_page: dict) -> dict:
    """
    Convert a Microsoft-OCR-format page into the API OCRResult schema.

    Args:
        ocr_page: OCR page dict whose 'words'/'lines' entries carry a
            'geo' list of [x, y, width, height].

    Returns:
        dict: {'image_width', 'image_height', 'angle', 'words', 'lines'}
    """
    def _flatten(entry: dict) -> dict:
        # Unpack the geo quadruple into explicit coordinate fields.
        x, y, width, height = entry['geo']
        return {
            'text': entry['text'],
            'confidence': entry['confidence'],
            'x': x,
            'y': y,
            'width': width,
            'height': height,
        }

    words = [_flatten(w) for w in ocr_page.get('words', [])]

    lines = []
    for line_entry in ocr_page.get('lines', []):
        converted = _flatten(line_entry)
        # Word-level association within lines is not reconstructed here.
        converted['words'] = []
        lines.append(converted)

    return {
        'image_width': ocr_page['imageWidth'],
        'image_height': ocr_page['imageHeight'],
        'angle': ocr_page.get('angle', 0.0),
        'words': words,
        'lines': lines,
    }
|
|
|
|
async def process_stage4_ocr(
    pdf_path: pathlib.Path,
    enable_ocr: bool = False,
    dpi: int = 300
) -> tuple[Optional[str], Optional[dict]]:
    """
    Stage 4: Image Finalization & OCR.

    Corresponds to:
    - pipeline_14: render the PDF to a high-quality image
    - pipeline_15: OCR the final image (when requested and enabled)

    Args:
        pdf_path: Path to the final PDF (after Stage 3 if enabled).
        enable_ocr: Whether OCR should be attempted.
        dpi: Render resolution for the image.

    Returns:
        tuple: (base64 PNG of the rendered page or None,
                OCR results in API format or None)
    """
    try:
        # Stage 14: rasterize the document.
        rendered, encoded_image = await render_pdf_to_image(pdf_path, dpi=dpi)
        print(f" ✓ Stage 14: Rendered image {rendered.size[0]}x{rendered.size[1]} @ {dpi} DPI")

        # Stage 15: OCR is best-effort — a failure leaves results as None
        # without discarding the rendered image.
        ocr_payload = None
        if enable_ocr and settings.OCR_SERVICE_ENABLED:
            try:
                page_result = await call_ocr_service(
                    rendered,
                    timeout=settings.OCR_SERVICE_TIMEOUT
                )
                ocr_payload = convert_ocr_to_api_format(page_result)
                print(f" ✓ Stage 15: OCR complete - {len(ocr_payload['words'])} words, {len(ocr_payload['lines'])} lines")
            except Exception as e:
                print(f" ⚠ Stage 15: OCR failed - {str(e)}")
        elif enable_ocr:
            print(f" ⚠ Stage 15: OCR requested but service not enabled (OCR_SERVICE_ENABLED=false)")

        return encoded_image, ocr_payload

    except Exception as e:
        print(f" ⚠ Stage 4 processing failed: {str(e)}")
        return None, None
|
|
|
|
| |
|
|
async def normalize_bboxes_stage16(
    document_id: str,
    pdf_path: str,
    ocr_results: Optional[Dict[str, Any]],
    scale: str = "0-1"
) -> Tuple[Optional[List[Dict]], Optional[List[Dict]]]:
    """
    Stage 16: Normalize bounding boxes to [0,1] (or [0,1000]) scale.
    Reuses logic from pipeline_16_normalize_bboxes.py

    Args:
        document_id: Unique document identifier
        pdf_path: Path to PDF file
        ocr_results: OCR results from Stage 15 (API format: pixel-space
            'words' and 'lines' plus 'image_width'/'image_height')
        scale: Normalization scale ("0-1" or "0-1000")

    Returns:
        Tuple of (word_level_bboxes, segment_level_bboxes), or (None, None)
        when OCR results are missing/invalid or normalization fails.
    """
    try:
        # BUG FIX: the f-string previously contained "\\n", which printed a
        # literal backslash-n instead of starting a new line.
        print(f"\n Stage 16: Normalizing bounding boxes...")

        if not ocr_results or not ocr_results.get('words'):
            print(f" ⚠ Stage 16: No OCR results to normalize")
            return None, None

        img_w_px = ocr_results.get('image_width', 0)
        img_h_px = ocr_results.get('image_height', 0)

        if img_w_px == 0 or img_h_px == 0:
            print(f" ⚠ Stage 16: Invalid image dimensions")
            return None, None

        # Same scaling applies to both words and segments.
        factor = 1000.0 if scale == "0-1000" else 1.0

        def _normalize(entry: Dict) -> Dict:
            # entry carries absolute pixel x/y/width/height; convert to
            # corner form on the requested scale.
            return {
                'text': entry['text'],
                'x0': entry['x'] / img_w_px * factor,
                'y0': entry['y'] / img_h_px * factor,
                'x2': (entry['x'] + entry['width']) / img_w_px * factor,
                'y2': (entry['y'] + entry['height']) / img_h_px * factor,
                # Block/line/word indices are unknown for OCR-derived boxes.
                'block_no': None,
                'line_no': None,
                'word_no': None
            }

        normalized_words = [_normalize(w) for w in ocr_results.get('words', [])]
        normalized_segments = [_normalize(l) for l in ocr_results.get('lines', [])]

        print(f" ✓ Stage 16: Normalized {len(normalized_words)} words, {len(normalized_segments)} segments")
        return normalized_words, normalized_segments

    except Exception as e:
        print(f" ⚠ Stage 16: BBox normalization failed - {str(e)}")
        return None, None
|
|
|
|
async def verify_ground_truth_stage17(
    document_id: str,
    ground_truth: Optional[Dict],
    layout_elements: Optional[List[Dict]],
    similarity_cutoff: float = 0.8
) -> Optional[Dict]:
    """
    Stage 17: Verify and prepare ground truth annotations.
    Simplified version of pipeline_17_gt_preparation_verification.py

    Args:
        document_id: Unique document identifier
        ground_truth: Ground truth data from Stage 2
        layout_elements: Layout/visual elements
        similarity_cutoff: Similarity threshold for fuzzy matching
            (currently unused by this simplified implementation)

    Returns:
        GT verification result dict ('passed', 'skipped', 'confirmed_keys',
        'similarities', plus counts when GT is present).
    """
    try:
        # BUG FIX: the f-string previously contained "\\n", which printed a
        # literal backslash-n instead of starting a new line.
        print(f"\n Stage 17: Verifying ground truth...")

        if not ground_truth:
            print(f" ⚠ Stage 17: No ground truth to verify")
            return {
                'passed': False,
                'skipped': True,
                'confirmed_keys': [],
                'similarities': []
            }

        confirmed_keys = list(ground_truth.keys()) if isinstance(ground_truth, dict) else []

        # Count only non-empty string question/answer pairs as valid.
        valid_pairs = 0
        similarities = []

        if isinstance(ground_truth, dict):
            for question, answer in ground_truth.items():
                if question and answer and isinstance(question, str) and isinstance(answer, str):
                    valid_pairs += 1
                    # Placeholder score: fuzzy matching is not performed in
                    # this simplified stage, so every valid pair scores 1.0.
                    similarities.append(1.0)

        passed = valid_pairs > 0

        result = {
            'passed': passed,
            'skipped': False,
            'confirmed_keys': confirmed_keys,
            'similarities': similarities,
            'num_layout_elements': len(layout_elements) if layout_elements else 0,
            'valid_labels': True
        }

        print(f" ✓ Stage 17: GT verification {'passed' if passed else 'failed'} - {valid_pairs} valid pairs")
        return result

    except Exception as e:
        print(f" ⚠ Stage 17: GT verification failed - {str(e)}")
        return {
            'passed': False,
            'skipped': False,
            'confirmed_keys': [],
            'similarities': []
        }
|
|
|
|
async def analyze_document_stage18(
    document_id: str,
    has_handwriting: bool,
    has_visual_elements: bool,
    has_ocr: bool,
    gt_verification: Optional[Dict],
    page_count: int = 1
) -> Dict:
    """
    Stage 18: Generate document analysis and statistics.
    Simplified version of pipeline_18_analyze.py

    Args:
        document_id: Unique document identifier
        has_handwriting: Whether document has handwriting
        has_visual_elements: Whether document has visual elements
        has_ocr: Whether OCR was performed
        gt_verification: GT verification results (expects a 'passed' key)
        page_count: Number of pages

    Returns:
        Single-document analysis statistics dict.
    """
    try:
        # BUG FIX: the f-string previously contained "\\n", which printed a
        # literal backslash-n instead of starting a new line.
        print(f"\n Stage 18: Analyzing document...")

        # Collect validity errors; an empty list means the document is valid.
        errors = []
        if page_count != 1:
            errors.append("is_multipage")
        if not gt_verification or not gt_verification.get('passed'):
            errors.append("gt_verification_failed")
        if not has_ocr:
            errors.append("missing_ocr")

        is_valid = len(errors) == 0

        stats = {
            'total_documents': 1,
            'valid_documents': 1 if is_valid else 0,
            'error_counts': {error: 1 for error in errors},
            'has_handwriting': 1 if has_handwriting else 0,
            'has_visual_elements': 1 if has_visual_elements else 0,
            'has_ocr': 1 if has_ocr else 0,
            'multipage_count': 1 if page_count != 1 else 0,
            'token_usage': None
        }

        print(f" ✓ Stage 18: Analysis complete - {'valid' if is_valid else 'has errors'}")
        return stats

    except Exception as e:
        print(f" ⚠ Stage 18: Analysis failed - {str(e)}")
        return {
            'total_documents': 1,
            'valid_documents': 0,
            'error_counts': {'analysis_error': 1},
            'has_handwriting': 0,
            'has_visual_elements': 0,
            'has_ocr': 0,
            'multipage_count': 0,
            # CONSISTENCY FIX: the error-path dict previously omitted
            # 'token_usage', giving consumers a different schema on failure.
            'token_usage': None
        }
|
|
|
|
async def create_debug_visualization_stage19(
    document_id: str,
    image_base64: Optional[str],
    normalized_bboxes: Optional[List[Dict]],
    show_text: bool = True,
    bbox_color: Tuple[int, int, int] = (255, 0, 0)
) -> Optional[Dict]:
    """
    Stage 19: Create debug visualization with bbox overlays.
    Simplified version of pipeline_19_create_debug_data.py

    Draws up to 100 normalized boxes (and optionally their text) onto the
    rendered page image and returns the overlay as base64 PNG.

    Args:
        document_id: Unique document identifier
        image_base64: Base64-encoded image
        normalized_bboxes: Normalized ([0,1]) bounding boxes
        show_text: Whether to show text labels
        bbox_color: RGB color for bboxes

    Returns:
        Debug visualization dict with overlay image, or None on
        missing input / failure.
    """
    try:
        # BUG FIX: the f-string previously contained "\\n", which printed a
        # literal backslash-n instead of starting a new line.
        print(f"\n Stage 19: Creating debug visualization...")

        if not image_base64 or not normalized_bboxes:
            print(f" ⚠ Stage 19: Missing image or bboxes")
            return None

        img_data = base64.b64decode(image_base64)
        img = Image.open(BytesIO(img_data))

        from PIL import ImageDraw, ImageFont

        draw = ImageDraw.Draw(img)
        img_w, img_h = img.size

        # PERF/BUG FIX: load the label font once (it was reloaded per box)
        # and catch OSError instead of a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit.
        font = None
        if show_text:
            try:
                font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10)
            except OSError:
                font = ImageFont.load_default()

        num_drawn = 0
        # Cap at 100 boxes to keep the overlay readable and generation fast.
        for bbox in normalized_bboxes[:100]:
            # Scale normalized [0,1] coordinates back to pixel space.
            x0 = bbox['x0'] * img_w
            y0 = bbox['y0'] * img_h
            x2 = bbox['x2'] * img_w
            y2 = bbox['y2'] * img_h

            draw.rectangle([x0, y0, x2, y2], outline=bbox_color, width=2)

            # Truncated text label just above the box.
            if show_text and bbox.get('text'):
                draw.text((x0, y0 - 12), bbox['text'][:20], fill=bbox_color, font=font)

            num_drawn += 1

        buffer = BytesIO()
        img.save(buffer, format="PNG")
        overlay_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

        result = {
            'bbox_overlay_base64': overlay_base64,
            'visual_elements_overlay_base64': None,
            'handwriting_overlay_base64': None
        }

        print(f" ✓ Stage 19: Debug visualization created - {num_drawn} boxes drawn")
        return result

    except Exception as e:
        print(f" ⚠ Stage 19: Debug visualization failed - {str(e)}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
async def process_stage5_complete(
    document_id: str,
    pdf_path: str,
    image_base64: Optional[str],
    ocr_results: Optional[Dict],
    ground_truth: Optional[Dict],
    has_handwriting: bool,
    has_visual_elements: bool,
    layout_elements: Optional[List[Dict]],
    enable_bbox_normalization: bool = False,
    enable_gt_verification: bool = False,
    enable_analysis: bool = False,
    enable_debug_visualization: bool = False,
) -> Dict[str, Any]:
    """
    Process Stage 5: Dataset Packaging (Stages 16-19).

    Args:
        document_id: Unique document identifier
        pdf_path: Path to PDF file
        image_base64: Base64-encoded final image
        ocr_results: OCR results from Stage 15
        ground_truth: Ground truth from Stage 2
        has_handwriting: Whether handwriting was generated
        has_visual_elements: Whether visual elements were generated
        layout_elements: Layout/visual element metadata
        enable_*: Feature flags for each sub-stage

    Returns:
        Dict with all Stage 5 results. On failure, whatever sub-stage results
        were computed before the exception are returned (remaining keys None).
    """
    results: Dict[str, Any] = {
        'normalized_bboxes_word': None,
        'normalized_bboxes_segment': None,
        'gt_verification': None,
        'analysis_stats': None,
        'debug_visualization': None
    }

    try:
        print(f"\\n========== Stage 5: Dataset Packaging ==========")

        # Stage 16: bbox normalization (word + segment level).
        if enable_bbox_normalization:
            norm_words, norm_segments = await normalize_bboxes_stage16(
                document_id=document_id,
                pdf_path=pdf_path,
                ocr_results=ocr_results,
                scale=settings.BBOX_NORMALIZATION_SCALE
            )
            results['normalized_bboxes_word'] = norm_words
            results['normalized_bboxes_segment'] = norm_segments

        # Stage 17: ground-truth verification against layout elements.
        if enable_gt_verification:
            gt_verification = await verify_ground_truth_stage17(
                document_id=document_id,
                ground_truth=ground_truth,
                layout_elements=layout_elements,
                similarity_cutoff=settings.GT_VERIFICATION_SIMILARITY_CUTOFF
            )
            results['gt_verification'] = gt_verification

        # Stage 18: document-level analysis statistics.
        if enable_analysis:
            analysis_stats = await analyze_document_stage18(
                document_id=document_id,
                has_handwriting=has_handwriting,
                has_visual_elements=has_visual_elements,
                has_ocr=ocr_results is not None,
                gt_verification=results.get('gt_verification'),
                page_count=1
            )
            results['analysis_stats'] = analysis_stats

        # Stage 19: debug visualization (requires an image and some bboxes).
        if enable_debug_visualization and image_base64:
            # Prefer word-level boxes; fall back to segment-level.
            bboxes_for_viz = results.get('normalized_bboxes_word') or results.get('normalized_bboxes_segment')

            if bboxes_for_viz:
                # Parse "R,G,B" from settings; fall back to red on any
                # malformed value (was a bare `except:` which also swallowed
                # KeyboardInterrupt/SystemExit).
                color_str = settings.DEBUG_BBOX_COLOR_RGB
                try:
                    r, g, b = map(int, color_str.split(','))
                    bbox_color = (r, g, b)
                except (ValueError, AttributeError):
                    bbox_color = (255, 0, 0)

                debug_viz = await create_debug_visualization_stage19(
                    document_id=document_id,
                    image_base64=image_base64,
                    normalized_bboxes=bboxes_for_viz,
                    show_text=settings.DEBUG_SHOW_TEXT_IN_BBOX,
                    bbox_color=bbox_color
                )
                results['debug_visualization'] = debug_viz

        # Message previously said "Stages 16-18" although Stage 19 runs here too.
        print(f" ✓ Stages 16-19: Dataset packaging complete\\n")
        return results

    except Exception as e:
        print(f" ⚠ Stages 16-19 processing failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return results
|
|
|
|
| |
|
|
async def export_to_msgpack(
    document_id: str,
    image_path: Optional[str],
    image_base64: Optional[str],
    words: List[str],
    word_bboxes: List[List[float]],
    segment_bboxes: Optional[List[List[float]]],
    ground_truth: Optional[Dict],
    output_path: pathlib.Path,
    image_width: Optional[int] = None,
    image_height: Optional[int] = None
) -> Optional[pathlib.Path]:
    """
    Export document data to msgpack format.

    This creates a simple msgpack file containing the document data in a format
    compatible with DocGenie's dataset infrastructure.

    Args:
        document_id: Unique document identifier
        image_path: Path to document image (if available)
        image_base64: Base64-encoded image (if no image_path)
        words: List of word strings
        word_bboxes: Word-level bounding boxes (normalized [0,1])
        segment_bboxes: Segment-level bounding boxes (normalized [0,1])
        ground_truth: Ground truth annotations
        output_path: Output msgpack file path
        image_width: Image width in pixels
        image_height: Image height in pixels

    Returns:
        Path to created msgpack file, or None if export failed (e.g. the
        optional 'datadings' dependency is not installed, or writing raised).
    """
    try:
        # Lazy import: 'datadings' is optional; a missing install degrades to
        # a warning (ImportError handler below) instead of crashing callers.
        from datadings.writer import FileWriter

        print(f"\\n========== Msgpack Export ==========")
        print(f" Exporting document {document_id} to msgpack format...")

        doc_data = {
            "key": document_id,
            "sample_id": document_id,
            "words": words,
            "word_bboxes": word_bboxes,
        }

        # Fall back to word-level boxes when no segment boxes were produced.
        if segment_bboxes:
            doc_data["segment_level_bboxes"] = segment_bboxes
        else:
            doc_data["segment_level_bboxes"] = word_bboxes

        if image_width and image_height:
            doc_data["image_width"] = image_width
            doc_data["image_height"] = image_height

        if image_path:
            doc_data["image_file_path"] = str(image_path)

        if ground_truth:
            # Document-level classification label, if present.
            if "label" in ground_truth:
                doc_data["label"] = ground_truth["label"]

            # Token-level BIO labels derived from entity annotations.
            if "entities" in ground_truth:
                entities = ground_truth["entities"]
                if entities:
                    word_labels = ["O"] * len(words)

                    # Tag contiguous runs of words that exactly match an
                    # entity's token sequence: B- on the first token of the
                    # run, I- on the rest. The previous heuristic tagged any
                    # word that appeared anywhere in the entity string and
                    # decided B-/I- from the previous word's label, which
                    # mislabeled repeated and out-of-order words.
                    for entity in entities:
                        entity_words = entity.get("text", "").split()
                        entity_label = entity.get("label", "ENTITY")
                        if not entity_words:
                            continue
                        span = len(entity_words)
                        for i in range(len(words) - span + 1):
                            if (words[i:i + span] == entity_words
                                    and all(lbl == "O" for lbl in word_labels[i:i + span])):
                                word_labels[i] = f"B-{entity_label}"
                                for j in range(i + 1, i + span):
                                    word_labels[j] = f"I-{entity_label}"

                    doc_data["word_labels"] = word_labels

            # Question-answering pairs.
            if "questions" in ground_truth:
                qa_pairs = []
                for qa in ground_truth["questions"]:
                    qa_pair = {
                        "question": qa.get("question", ""),
                        "answers": qa.get("answers", []),
                        "question_id": qa.get("id", "")
                    }
                    qa_pairs.append(qa_pair)
                doc_data["qa_pairs"] = qa_pairs

            # Layout elements → object-detection style annotations.
            if "layout_elements" in ground_truth:
                layout_elements = ground_truth["layout_elements"]
                annotated_objects = []
                for elem in layout_elements:
                    obj = {
                        "label": elem.get("label", "text"),
                        "bbox": elem.get("bbox", [0, 0, 1, 1]),
                        "score": elem.get("score", 1.0)
                    }
                    annotated_objects.append(obj)
                doc_data["annotated_objects"] = annotated_objects

        output_path.parent.mkdir(parents=True, exist_ok=True)

        with FileWriter(output_path, overwrite=True) as writer:
            writer.write(doc_data)

        print(f" ✓ Msgpack exported: {output_path}")
        print(f" - Words: {len(words)}")
        print(f" - Word BBoxes: {len(word_bboxes)}")
        print(f" - Segment BBoxes: {len(doc_data['segment_level_bboxes'])}")
        if "word_labels" in doc_data:
            print(f" - Labels: {len(doc_data['word_labels'])}")
        if "qa_pairs" in doc_data:
            print(f" - QA Pairs: {len(doc_data['qa_pairs'])}")

        return output_path

    except ImportError:
        print(f" ⚠ Warning: 'datadings' package not available. Msgpack export skipped.")
        print(f" Install with: pip install datadings")
        return None
    except Exception as e:
        print(f" ⚠ Msgpack export failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
def save_individual_tokens_to_disk(
    handwriting_images: dict,
    visual_element_images: dict,
    output_dir: pathlib.Path,
    doc_id: str
) -> dict:
    """
    Save individual handwriting tokens and visual elements to disk.
    Used for 'dataset' and 'complete' output detail levels.

    Args:
        handwriting_images: Dict {hw_id: base64_png}
        visual_element_images: Dict {ve_id: base64_png}
        output_dir: Base output directory
        doc_id: Document ID for folder naming

    Returns:
        dict with lists of saved file paths (relative to ``output_dir``)
        under keys 'handwriting_tokens' and 'visual_elements'.
    """
    # Uses the module-level `import base64`; the previous function-local
    # re-import was redundant and has been removed.
    saved_files = {
        'handwriting_tokens': [],
        'visual_elements': []
    }

    # Handwriting token PNGs go under <output_dir>/<doc_id>/handwriting_tokens/.
    if handwriting_images:
        hw_dir = output_dir / doc_id / "handwriting_tokens"
        hw_dir.mkdir(parents=True, exist_ok=True)

        for hw_id, img_b64 in handwriting_images.items():
            img_path = hw_dir / f"{hw_id}.png"
            img_path.write_bytes(base64.b64decode(img_b64))
            saved_files['handwriting_tokens'].append(str(img_path.relative_to(output_dir)))

    # Visual element PNGs go under <output_dir>/<doc_id>/visual_elements/.
    if visual_element_images:
        ve_dir = output_dir / doc_id / "visual_elements"
        ve_dir.mkdir(parents=True, exist_ok=True)

        for ve_id, img_b64 in visual_element_images.items():
            img_path = ve_dir / f"{ve_id}.png"
            img_path.write_bytes(base64.b64decode(img_b64))
            saved_files['visual_elements'].append(str(img_path.relative_to(output_dir)))

    return saved_files
|
|
|
|
def create_token_mapping_json(
    handwriting_regions: list[dict],
    handwriting_images: dict,
    visual_elements: list[dict],
    visual_element_images: dict
) -> dict:
    """
    Build the token-mapping JSON used for ML dataset creation.

    Each handwriting region and visual element is summarized with its style
    metadata, position, and a reference to its rendered image (when one was
    produced).

    Args:
        handwriting_regions: List of handwriting metadata
        handwriting_images: Dict of handwriting images
        visual_elements: List of visual element metadata
        visual_element_images: Dict of visual element images

    Returns:
        dict with complete token mapping
    """

    def _handwriting_entry(region: dict) -> dict:
        # A region "has an image" when its id appears in the rendered set.
        token_id = region.get('id', 'unknown')
        rendered = token_id in handwriting_images
        return {
            'id': token_id,
            'text': region.get('text', ''),
            'author_id': region.get('author_id'),
            'is_signature': region.get('is_signature', False),
            'rect': region.get('rect', {}),
            'has_image': rendered,
            'image_filename': f"{token_id}.png" if rendered else None
        }

    def _visual_entry(element: dict) -> dict:
        element_id = element.get('id', 'unknown')
        rendered = element_id in visual_element_images
        return {
            'id': element_id,
            'type': element.get('type', 'unknown'),
            'content': element.get('content'),
            'rect': element.get('rect', {}),
            'has_image': rendered,
            'image_filename': f"{element_id}.png" if rendered else None
        }

    return {
        'handwriting': {
            'tokens': [_handwriting_entry(r) for r in handwriting_regions],
            'total_count': len(handwriting_regions)
        },
        'visual_elements': {
            'items': [_visual_entry(e) for e in visual_elements],
            'total_count': len(visual_elements)
        }
    }
|
|
|
|
def _bbox_record(bbox) -> dict:
    """Convert one extracted bbox object into a plain serializable dict."""
    return {
        "text": bbox.text,
        "x": bbox.x0,
        "y": bbox.y0,
        "width": bbox.width,
        "height": bbox.height,
        "bbox": [bbox.x0, bbox.y0, bbox.x2, bbox.y2],
        "block_no": bbox.block_no,
        "line_no": bbox.line_no,
        "word_no": bbox.word_no,
        "page": 0,  # single-page pipeline: page index is fixed at 0
    }


def extract_all_bboxes_from_pdf(pdf_path: pathlib.Path) -> Dict[str, List[dict]]:
    """
    Extract both word-level and character-level bounding boxes from PDF.

    This is a high-priority feature for ML datasets as it provides:
    - Word-level bboxes: Ground truth text positions from PDF
    - Character-level bboxes: Fine-grained localization for character recognition

    Args:
        pdf_path: Path to PDF file

    Returns:
        Dictionary with 'word' and 'char' keys containing bbox lists
    """
    from docgenie.generation.pipeline_04.extract_bbox import extract_bboxes_from_pdf

    # Both levels share the same dict shape; the previously duplicated
    # conversion loops are factored into _bbox_record.
    word_bboxes_raw = extract_bboxes_from_pdf(
        pdf_path=pdf_path,
        level="word"
    )
    char_bboxes_raw = extract_bboxes_from_pdf(
        pdf_path=pdf_path,
        level="char"
    )

    return {
        "word": [_bbox_record(b) for b in word_bboxes_raw],
        "char": [_bbox_record(b) for b in char_bboxes_raw]
    }
|
|
|
|
def extract_raw_annotations_from_geometries(geometries: List[dict]) -> List[dict]:
    """
    Extract raw layout annotations (bounding boxes) from geometries.

    This is a high-priority feature for ML datasets as it provides:
    - Layout bounding boxes before any normalization
    - Shows original coordinate space from HTML rendering
    - Useful for debugging annotation processing pipeline

    Args:
        geometries: List of geometry dictionaries from HTML rendering

    Returns:
        List of layout annotation dictionaries with bbox coordinates
    """
    annotations: List[dict] = []

    for geometry in geometries:
        # Only geometries whose class is prefixed "LE-" are layout elements.
        label = geometry.get('class', '')
        if not label.startswith('LE-'):
            continue

        # Skip geometries with no rectangle information at all.
        rect = geometry.get('rect', {})
        if not rect:
            continue

        x = rect.get('x', 0)
        y = rect.get('y', 0)
        width = rect.get('width', 0)
        height = rect.get('height', 0)

        annotations.append({
            'class': label,
            'type': 'layout_element',
            'bbox': {
                'x': x,
                'y': y,
                'width': width,
                'height': height,
                # Bottom-right corner derived from origin + extent.
                'x2': x + width,
                'y2': y + height,
            },
            'text': geometry.get('text', ''),
            'attributes': geometry.get('attributes', {}),
        })

    return annotations
|
|