import os

# Keep Gradio's temporary files in a local directory (set before gradio is imported).
os.environ["GRADIO_TEMP_DIR"] = "./tmp"

import torch
import torchvision
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import (
    DFineForObjectDetection,
    RTDetrV2ForObjectDetection,
    RTDetrImageProcessor,
)

# == Device configuration ==
device = "cuda" if torch.cuda.is_available() else "cpu"

# == Model configurations ==
MODELS = {
    "Egret XLarge": {
        "path": "ds4sd/docling-layout-egret-xlarge",
        "model_class": DFineForObjectDetection,
    },
    "Egret Large": {
        "path": "ds4sd/docling-layout-egret-large",
        "model_class": DFineForObjectDetection,
    },
    "Egret Medium": {
        "path": "ds4sd/docling-layout-egret-medium",
        "model_class": DFineForObjectDetection,
    },
    "Heron 101": {
        "path": "ds4sd/docling-layout-heron-101",
        "model_class": RTDetrV2ForObjectDetection,
    },
    "Heron": {
        "path": "ds4sd/docling-layout-heron",
        "model_class": RTDetrV2ForObjectDetection,
    },
}

# == Class mappings ==
classes_map = {
    0: "Caption",
    1: "Footnote",
    2: "Formula",
    3: "List-item",
    4: "Page-footer",
    5: "Page-header",
    6: "Picture",
    7: "Section-header",
    8: "Table",
    9: "Text",
    10: "Title",
    11: "Document Index",
    12: "Code",
    13: "Checkbox-Selected",
    14: "Checkbox-Unselected",
    15: "Form",
    16: "Key-Value Region",
}

# == Global model state ==
current_model = None
current_processor = None
current_model_name = None


def colormap(N=256, normalized=False):
    """Generate a colormap of N visually distinct colors (PASCAL VOC scheme)."""

    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << (7 - j))
            g = g | (bitget(c, 1) << (7 - j))
            b = b | (bitget(c, 2) << (7 - j))
            c = c >> 3
        cmap[i] = np.array([r, g, b])
    if normalized:
        cmap = cmap.astype(np.float32) / 255.0
    return cmap


def iomin(box1, box2):
    """Intersection over Minimum (IoMin): intersection area divided by the area
    of the smaller box. Equals 1.0 when one box fully contains the other."""
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])

    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)
    return inter_area / min_area


def nms_custom(boxes, scores, iou_threshold=0.5):
    """Custom NMS implementation using IoMin instead of IoU."""
    keep = []
    _, order = scores.sort(descending=True)
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        box_i = boxes[i].unsqueeze(0)
        rest = order[1:]
        ious = iomin(box_i, boxes[rest])
        mask = ious <= iou_threshold
        order = rest[mask]
    return torch.tensor(keep, dtype=torch.long)
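
# A quick sanity check for nms_custom (hypothetical values): the lower-scoring
# box is fully contained in the first, so its IoMin is 1.0 and it is suppressed:
#
#   boxes = torch.tensor([[0.0, 0.0, 100.0, 100.0], [10.0, 10.0, 50.0, 50.0]])
#   scores = torch.tensor([0.9, 0.8])
#   nms_custom(boxes, scores, iou_threshold=0.5)  # -> tensor([0])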


def load_model(model_name):
    """Load the selected model, skipping the reload if it is already active."""
    global current_model, current_processor, current_model_name

    if current_model_name == model_name:
        return f"✅ Model {model_name} is already loaded!"

    try:
        model_info = MODELS[model_name]
        model_path = model_info["path"]
        model_class = model_info["model_class"]

        print(f"Loading {model_name} from {model_path}")
        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = model_class.from_pretrained(model_path)
        model = model.to(device)
        model.eval()

        current_processor = processor
        current_model = model
        current_model_name = model_name
        return f"✅ Successfully loaded {model_name}!"
    except Exception as e:
        print(f"Error loading model: {e}")
        return f"❌ Error loading {model_name}: {str(e)}"


def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
    """Draw bounding boxes, translucent fills, and optional labels with OpenCV."""
    if isinstance(image_input, Image.Image):
        image = np.array(image_input)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif isinstance(image_input, np.ndarray):
        if len(image_input.shape) == 3 and image_input.shape[2] == 3:
            image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
        else:
            image = image_input.copy()
    else:
        raise ValueError("Input must be a PIL Image or a numpy array")

    if len(bboxes) == 0:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    overlay = image.copy()
    cmap = colormap(N=len(id_to_names), normalized=False)

    for i in range(len(bboxes)):
        try:
            bbox = bboxes[i]
            if torch.is_tensor(bbox):
                bbox = bbox.cpu().numpy()
            class_id = classes[i]
            if torch.is_tensor(class_id):
                class_id = class_id.item()
            score = scores[i]
            if torch.is_tensor(score):
                score = score.item()

            x_min, y_min, x_max, y_max = map(int, bbox)
            class_id = int(class_id)
            class_name = id_to_names.get(class_id, f"unknown_{class_id}")
            color = tuple(int(c) for c in cmap[class_id % len(cmap)])

            # Draw filled rectangle on the overlay
            cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
            # Draw border on the main image
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)

            # Add text label only if show_labels is True
            if show_labels:
                text = f"{class_name}: {score:.3f}"
                (text_width, text_height), baseline = cv2.getTextSize(
                    text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
                )
                cv2.rectangle(
                    image,
                    (x_min, y_min - text_height - baseline - 4),
                    (x_min + text_width + 8, y_min),
                    color,
                    -1,
                )
                cv2.putText(
                    image,
                    text,
                    (x_min + 4, y_min - 6),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.8,
                    (255, 255, 255),
                    2,
                )
        except Exception as e:
            print(f"Skipping box {i} due to error: {e}")

    # Blend the filled overlay into the main image
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def process_image(input_img, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
    """Run document layout detection on an image and return the visualization."""
    if input_img is None:
        return None, "❌ Please upload an image first."
    if current_model is None or current_processor is None:
        return None, "❌ Please load a model first."

    try:
        # Prepare image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != "RGB":
            input_img = input_img.convert("RGB")

        # Run the model
        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = current_model(**inputs)

        # Post-process results back to image coordinates
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )

        if not results:
            return np.array(input_img), "ℹ️ No detections found."

        result = results[0]
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]

        if len(boxes) == 0:
            return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}."
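
        # IoMin (see iomin above) measures overlap relative to the smaller box,
        # so "Custom IoMin" prunes regions nested inside larger ones more
        # aggressively than standard IoU-based NMS.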
        # Apply NMS
        if iou_threshold < 1.0:
            if nms_method == "Custom IoMin":
                keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
            else:
                # torchvision NMS expects xyxy boxes, which post-processing already provides
                keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
            boxes = boxes[keep_indices]
            scores = scores[keep_indices]
            labels = labels[keep_indices]

        # Visualize results
        output = visualize_bbox(
            input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels
        )

        labels_status = "with labels" if show_labels else "without labels"
        info = (
            f"✅ Found {len(boxes)} detections ({labels_status}) | "
            f"NMS: {nms_method} | Threshold: {conf_threshold:.2f}"
        )
        return output, info
    except Exception as e:
        print(f"[ERROR] process_image failed: {e}")
        error_msg = f"❌ Processing error: {str(e)}"
        if input_img is not None:
            return np.array(input_img), error_msg
        return np.zeros((512, 512, 3), dtype=np.uint8), error_msg


def reset_interface():
    """Reset the input image, output image, and status text."""
    return gr.update(value=None), gr.update(value=None), gr.update(value="")


if __name__ == "__main__":
    print("🚀 Starting Document Layout Analysis App")
    print(f"📱 Device: {device}")
    print(f"🤖 Available models: {len(MODELS)}")

    # Custom CSS for a full-width layout
    custom_css = """
    .gradio-container { max-width: 100% !important; padding: 20px !important; }
    .main-container { width: 100% !important; max-width: none !important; }
    .panel-left, .panel-right {
        min-height: 600px;
        padding: 20px;
        background: #f8f9fa;
        border-radius: 12px;
        border: 1px solid #e9ecef;
    }
    .control-section {
        margin-bottom: 20px;
        padding: 15px;
        background: white;
        border-radius: 8px;
        border: 1px solid #dee2e6;
    }
    .status-good { color: #28a745; font-weight: bold; }
    .status-error { color: #dc3545; font-weight: bold; }
    .status-info { color: #17a2b8; font-weight: bold; }
    .toggle-labels {
        background: linear-gradient(45deg, #667eea, #764ba2) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
    }
    """

    # Create the Gradio interface
    with gr.Blocks(
        title="📄 Document Layout Analysis - Full Width",
        theme=gr.themes.Soft(),
        css=custom_css,
    ) as demo:
        # Header
        gr.HTML("""
            <div style="text-align: center;">
                <h1>🔍 Document Layout Analysis</h1>
                <p>Advanced document structure detection with multiple AI models</p>
            </div>
""") # Main content in two columns with gr.Row(): # LEFT COLUMN - Controls and Input with gr.Column(scale=1, elem_classes=["panel-left"]): # Model Section with gr.Group(elem_classes=["control-section"]): gr.HTML("

🤖 Model Configuration

") model_dropdown = gr.Dropdown( choices=list(MODELS.keys()), value="Egret XLarge", label="Select Model", info="Choose the AI model for document analysis", interactive=True ) with gr.Row(): load_btn = gr.Button("📥 Load Model", variant="primary", scale=1) clear_btn = gr.Button("🗑️ Clear All", variant="secondary", scale=1) model_status = gr.Textbox( label="Model Status", value="🔄 No model loaded. Please select and load a model.", interactive=False, lines=2 ) # Image Upload Section with gr.Group(elem_classes=["control-section"]): gr.HTML("

📄 Image Input

") input_img = gr.Image( label="Upload Document Image", type="pil", height=400, interactive=True ) detect_btn = gr.Button("🔍 Analyze Document", variant="primary", size="lg") # Parameters Section with gr.Group(elem_classes=["control-section"]): gr.HTML("

⚙️ Detection Parameters

") conf_threshold = gr.Slider( minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold", info="Minimum confidence for detections" ) iou_threshold = gr.Slider( minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="NMS IoU Threshold", info="Non-maximum suppression threshold" ) nms_method = gr.Radio( choices=["Custom IoMin", "Standard IoU"], value="Custom IoMin", label="NMS Algorithm", info="Choose suppression method" ) alpha_slider = gr.Slider( minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Overlay Transparency", info="Transparency of detection overlays" ) # RIGHT COLUMN - Results and Output with gr.Column(scale=1, elem_classes=["panel-right"]): # Results Section with gr.Group(elem_classes=["control-section"]): gr.HTML("

🎯 Detection Results

") output_img = gr.Image( label="Analyzed Document", type="numpy", height=500, interactive=False ) detection_info = gr.Textbox( label="Analysis Summary", value="", interactive=False, lines=3, placeholder="Detection results will appear here..." ) # Visualization Options Section with gr.Group(elem_classes=["control-section"]): gr.HTML("

🎨 Visualization Options

") show_labels_checkbox = gr.Checkbox( value=True, label="Show Class Labels", info="Display class names and confidence scores on detections", interactive=True ) # Event Handlers load_btn.click( fn=load_model, inputs=[model_dropdown], outputs=[model_status] ) clear_btn.click( fn=reset_interface, outputs=[input_img, output_img, detection_info] ) detect_btn.click( fn=process_image, inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider, show_labels_checkbox], outputs=[output_img, detection_info] ) # Launch application demo.launch( server_name="0.0.0.0", server_port=7860, debug=True, share=False, show_error=True, inbrowser=True )