Spaces:
Running
Running
| import os | |
| os.environ["GRADIO_TEMP_DIR"] = "./tmp" | |
| import sys | |
| import torch | |
| import torchvision | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from transformers import ( | |
| DFineForObjectDetection, | |
| RTDetrV2ForObjectDetection, | |
| RTDetrImageProcessor, | |
| ) | |
# == Device configuration ==
# Inference runs on CUDA when PyTorch can see a GPU, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# == Model configurations ==
# UI display name -> Hugging Face checkpoint path and the transformers class
# used to load it. Egret checkpoints load as D-FINE detectors, Heron
# checkpoints as RT-DETRv2 detectors; insertion order is the dropdown order.
MODELS = {
    display_name: {"path": repo_id, "model_class": model_cls}
    for display_name, repo_id, model_cls in (
        ("Egret XLarge", "ds4sd/docling-layout-egret-xlarge", DFineForObjectDetection),
        ("Egret Large", "ds4sd/docling-layout-egret-large", DFineForObjectDetection),
        ("Egret Medium", "ds4sd/docling-layout-egret-medium", DFineForObjectDetection),
        ("Heron 101", "ds4sd/docling-layout-heron-101", RTDetrV2ForObjectDetection),
        ("Heron", "ds4sd/docling-layout-heron", RTDetrV2ForObjectDetection),
    )
}

# == Class mappings ==
# Predicted label id -> human-readable layout class; the id is the position
# in this tuple (0 = "Caption" ... 16 = "Key-Value Region").
classes_map = dict(enumerate((
    "Caption", "Footnote", "Formula", "List-item",
    "Page-footer", "Page-header", "Picture", "Section-header",
    "Table", "Text", "Title", "Document Index",
    "Code", "Checkbox-Selected", "Checkbox-Unselected",
    "Form", "Key-Value Region",
)))

# == Global model variables ==
# Populated by load_model(); all three stay None until a model loads successfully.
current_model = current_processor = current_model_name = None
def colormap(N=256, normalized=False):
    """Build an ``N``-entry RGB palette by bit-interleaving each index.

    Bits 0, 1 and 2 of the index feed the most significant bit of red, green
    and blue. The next three bits feed the next bit down, and so on for 8
    rounds (the PASCAL VOC label-colormap scheme). Entry 0 is black.

    Args:
        N: Number of palette entries.
        normalized: When True, return float32 values in [0, 1] instead of
            uint8 values in [0, 255].

    Returns:
        np.ndarray of shape ``(N, 3)``.
    """
    palette = np.zeros((N, 3), dtype=np.uint8)
    for index in range(N):
        red = green = blue = 0
        code = index
        # Fill each channel from its most significant bit downwards.
        for shift in range(7, -1, -1):
            red |= (code & 1) << shift
            green |= ((code >> 1) & 1) << shift
            blue |= ((code >> 2) & 1) << shift
            code >>= 3
        palette[index] = (red, green, blue)
    if normalized:
        palette = palette.astype(np.float32) / 255.0
    return palette
def iomin(box1, box2):
    """Intersection over Minimum (IoMin) of corresponding boxes.

    Each row is ``(x_min, y_min, x_max, y_max)``. Rows of ``box1`` and
    ``box2`` are paired elementwise, so a single-row ``box1`` broadcasts
    against every row of ``box2`` (this is how ``nms_custom`` calls it).

    IoMin is the intersection area divided by the smaller of the two box
    areas. A small box fully inside a large one scores 1.0, which lets NMS
    drop nested duplicates that plain IoU would keep.

    Returns:
        1-D float tensor of IoMin values. A pair involving a zero-area box
        scores 0.0 instead of NaN (0/0).
    """
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)
    # A zero-area box has zero intersection, giving 0/0 = NaN. NaN then
    # fails the `<= threshold` test in nms_custom and silently suppresses the
    # box, so report "no overlap" (0.0) instead.
    return torch.nan_to_num(inter_area / min_area, nan=0.0)
def nms_custom(boxes, scores, iou_threshold=0.5):
    """Greedy, class-agnostic NMS that measures overlap with IoMin.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose IoMin with it exceeds ``iou_threshold``.

    Args:
        boxes: Tensor of ``(x_min, y_min, x_max, y_max)`` rows.
        scores: 1-D tensor with one score per box.
        iou_threshold: Boxes with IoMin strictly above this are suppressed.

    Returns:
        1-D ``torch.long`` tensor of kept indices into ``boxes``, in
        descending-score order (created on the default device).
    """
    survivors = []
    candidates = scores.sort(descending=True)[1]
    while candidates.numel() > 0:
        best = candidates[0]
        survivors.append(best.item())
        remaining = candidates[1:]
        if remaining.numel() == 0:
            break
        overlap = iomin(boxes[best].unsqueeze(0), boxes[remaining])
        candidates = remaining[overlap <= iou_threshold]
    return torch.tensor(survivors, dtype=torch.long)
def load_model(model_name):
    """Load the checkpoint registered under ``model_name`` and make it current.

    Looks the name up in ``MODELS`` and loads an ``RTDetrImageProcessor`` and
    the configured model class from the checkpoint path. The model is moved
    to ``device`` and put in eval mode. Only after every step succeeds are
    the processor, model and name published to the module globals, so a
    failed load leaves the previously loaded model in place.

    Args:
        model_name: A key of ``MODELS`` (the dropdown value).

    Returns:
        str: Status message for the UI. Errors, including an unknown name,
        are reported in the message rather than raised.
    """
    global current_model, current_processor, current_model_name
    # Compare names only when a model is actually loaded. At startup
    # current_model_name is None, so a None (empty) selection would
    # otherwise be reported as "already loaded".
    if current_model is not None and current_model_name == model_name:
        return f"β Model {model_name} is already loaded!"
    try:
        model_info = MODELS[model_name]
        model_path = model_info["path"]
        model_class = model_info["model_class"]
        print(f"Loading {model_name} from {model_path}")
        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = model_class.from_pretrained(model_path)
        model = model.to(device)
        model.eval()
        # Publish only after everything above succeeded.
        current_processor = processor
        current_model = model
        current_model_name = model_name
        return f"β Successfully loaded {model_name}!"
    except Exception as e:
        # UI boundary: surface the error as a status message instead of raising.
        print(f"Error loading model: {e}")
        return f"β Error loading {model_name}: {str(e)}"
def _to_bgr(image_input):
    """Return a 3-channel BGR copy of ``image_input`` for OpenCV drawing.

    PIL images of any mode are converted to RGB first. Numpy arrays are
    treated as RGB (3 channels), RGBA (4 channels) or grayscale (2-D or a
    single trailing channel). Any other shape is copied unchanged.

    Raises:
        ValueError: if ``image_input`` is neither a PIL image nor a numpy array.
    """
    if isinstance(image_input, Image.Image):
        return cv2.cvtColor(np.array(image_input.convert('RGB')), cv2.COLOR_RGB2BGR)
    if not isinstance(image_input, np.ndarray):
        raise ValueError("Input must be PIL Image or numpy array")
    channels = image_input.shape[2] if image_input.ndim == 3 else None
    if image_input.ndim == 2 or channels == 1:
        return cv2.cvtColor(image_input, cv2.COLOR_GRAY2BGR)
    if channels == 3:
        return cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
    if channels == 4:
        return cv2.cvtColor(image_input, cv2.COLOR_RGBA2BGR)
    return image_input.copy()


def _to_host(values):
    """Return a tensor as a CPU numpy array; pass anything else through unchanged."""
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return values


def _draw_label(image, text, x_min, y_min, color):
    """Draw a filled caption box with white ``text`` at a detection's top-left corner."""
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
    (text_width, text_height), baseline = cv2.getTextSize(text, font, scale, thickness)
    label_height = text_height + baseline + 4
    # Prefer placing the caption above the box. If the box touches the top of
    # the image, the caption would land at negative y and be clipped away,
    # so draw it just inside the box's top edge instead.
    top = y_min - label_height
    if top < 0:
        top = y_min
    cv2.rectangle(image, (x_min, top), (x_min + text_width + 8, top + label_height), color, -1)
    cv2.putText(image, text, (x_min + 4, top + label_height - 6), font, scale, (255, 255, 255), thickness)


def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
    """Draw detections onto a copy of an image and return it as an RGB array.

    Each detection gets a translucent fill (opacity ``alpha``), an opaque
    3-px border and, if ``show_labels``, an opaque "class: score" caption.
    Colors come from ``colormap`` indexed by class id.

    Args:
        image_input: PIL image, or numpy array (RGB, RGBA or grayscale).
        bboxes: Sequence/tensor of ``(x_min, y_min, x_max, y_max)`` boxes.
        classes: Class id per box.
        scores: Confidence per box.
        id_to_names: Mapping of class id -> display name.
        alpha: Opacity of the filled box overlays, in [0, 1].
        show_labels: Whether to draw the text captions.

    Returns:
        np.ndarray: 3-channel RGB image. The input is not modified.

    Raises:
        ValueError: if ``image_input`` is neither a PIL image nor a numpy array.
    """
    image = _to_bgr(image_input)
    if len(bboxes) == 0:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Move tensors to the host once instead of syncing per box.
    bboxes, classes, scores = _to_host(bboxes), _to_host(classes), _to_host(scores)
    # max(..., 1) keeps the modulo below defined when id_to_names is empty.
    cmap = colormap(N=max(len(id_to_names), 1), normalized=False)

    detections = []
    for i in range(len(bboxes)):
        try:
            x_min, y_min, x_max, y_max = map(int, bboxes[i])
            class_id = int(classes[i])
            score = float(scores[i])
        except Exception as e:
            # Best effort: one malformed box must not abort the whole render.
            print(f"Skipping box {i} due to error: {e}")
            continue
        class_name = id_to_names.get(class_id, f"unknown_{class_id}")
        color = tuple(int(c) for c in cmap[class_id % len(cmap)])
        detections.append(((x_min, y_min, x_max, y_max), color, f"{class_name}: {score:.3f}"))

    # Translucent fills: paint them on a copy and blend it back in.
    overlay = image.copy()
    for (x_min, y_min, x_max, y_max), color, _ in detections:
        cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)

    # Borders and captions go on after blending so alpha only affects the
    # fills. Blending them too fades them, and at alpha=1.0 erases them.
    for (x_min, y_min, x_max, y_max), color, text in detections:
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)
        if show_labels:
            _draw_label(image, text, x_min, y_min, color)

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def _apply_nms(boxes, scores, labels, iou_threshold, nms_method):
    """Suppress overlapping detections and return the filtered (boxes, scores, labels)."""
    if nms_method == "Custom IoMin":
        keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
    else:
        # Standard IoU NMS; boxes are (x_min, y_min, x_max, y_max) as torchvision expects.
        keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
    return boxes[keep_indices], scores[keep_indices], labels[keep_indices]


def process_image(input_img, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
    """Run document-layout detection on one image and render the results.

    Args:
        input_img: PIL image or numpy array from the Gradio input (or None).
        conf_threshold: Minimum detection score passed to post-processing.
        iou_threshold: NMS overlap threshold. Values >= 1.0 disable NMS.
        nms_method: "Custom IoMin" for ``nms_custom``; anything else uses
            ``torchvision.ops.nms``.
        alpha: Fill opacity forwarded to ``visualize_bbox``.
        show_labels: Whether captions are drawn.

    Returns:
        tuple: ``(image, status_message)``. ``image`` is None when there is
        no input or no model; otherwise it is a numpy RGB array (the
        unannotated input on error or when nothing is detected).
    """
    if input_img is None:
        return None, "β Please upload an image first."
    # Read the globals once so that a model swapped in by another UI event
    # (Load Model) cannot mix one model's processor with another model's
    # weights partway through this request.
    model, processor = current_model, current_processor
    if model is None or processor is None:
        return None, "β Please load a model first."
    try:
        # Prepare image
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != 'RGB':
            input_img = input_img.convert('RGB')

        # Process with model
        inputs = processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process results. PIL size is (width, height), so it is
        # reversed to (height, width) for target_sizes.
        results = processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )
        if not results:
            return np.array(input_img), "βΉοΈ No detections found."
        result = results[0]
        boxes, scores, labels = result["boxes"], result["scores"], result["labels"]
        if len(boxes) == 0:
            return np.array(input_img), f"βΉοΈ No detections above threshold {conf_threshold:.2f}."

        # Overlap (IoU or IoMin) never exceeds 1.0, so a threshold of 1.0 or
        # more cannot suppress anything and NMS is skipped.
        nms_applied = iou_threshold < 1.0
        if nms_applied:
            boxes, scores, labels = _apply_nms(boxes, scores, labels, iou_threshold, nms_method)

        # Visualize results
        output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
        labels_status = "with labels" if show_labels else "without labels"
        # Report "off" rather than the selected method when NMS was skipped.
        nms_status = nms_method if nms_applied else "off"
        info = f"β Found {len(boxes)} detections ({labels_status}) | NMS: {nms_status} | Threshold: {conf_threshold:.2f}"
        return output, info
    except Exception as e:
        # UI boundary: show the error and fall back to the unannotated input
        # (input_img is never None here; that case returned above).
        print(f"[ERROR] process_image failed: {e}")
        error_msg = f"β Processing error: {str(e)}"
        return np.array(input_img), error_msg
def reset_interface():
    """Clear the UI: returns updates for (input image, output image, summary text).

    Both images are set to None and the summary text to an empty string.
    """
    input_reset, output_reset = (gr.update(value=None) for _ in range(2))
    summary_reset = gr.update(value="")
    return input_reset, output_reset, summary_reset
if __name__ == "__main__":
    # Startup banner printed to the server console.
    print(f"π Starting Document Layout Analysis App")
    print(f"π± Device: {device}")
    print(f"π€ Available models: {len(MODELS)}")
    # Custom CSS for full-width layout.
    # panel-left, panel-right and control-section are attached below via
    # elem_classes.
    # NOTE(review): .status-good/.status-error/.status-info/.toggle-labels are
    # not referenced by any component in this block.
    custom_css = """
    .gradio-container {
        max-width: 100% !important;
        padding: 20px !important;
    }
    .main-container {
        width: 100% !important;
        max-width: none !important;
    }
    .panel-left, .panel-right {
        min-height: 600px;
        padding: 20px;
        background: #f8f9fa;
        border-radius: 12px;
        border: 1px solid #e9ecef;
    }
    .control-section {
        margin-bottom: 20px;
        padding: 15px;
        background: white;
        border-radius: 8px;
        border: 1px solid #dee2e6;
    }
    .status-good { color: #28a745; font-weight: bold; }
    .status-error { color: #dc3545; font-weight: bold; }
    .status-info { color: #17a2b8; font-weight: bold; }
    .toggle-labels {
        background: linear-gradient(45deg, #667eea, #764ba2) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
    }
    """
    # Create Gradio interface.
    # Gradio lays components out in creation order inside the nested `with`
    # contexts, so statement order here defines the page layout.
    with gr.Blocks(
        title="π Document Layout Analysis - Full Width",
        theme=gr.themes.Soft(),
        css=custom_css
    ) as demo:
        # Header
        gr.HTML("""
        <div style='text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 30px;'>
            <h1 style='margin: 0; font-size: 3em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);'>π Document Layout Analysis</h1>
            <p style='margin: 10px 0 0 0; font-size: 1.3em; opacity: 0.9;'>Advanced document structure detection with multiple AI models</p>
        </div>
        """)
        # Main content in two columns
        with gr.Row():
            # LEFT COLUMN - Controls and Input
            with gr.Column(scale=1, elem_classes=["panel-left"]):
                # Model Section. Nothing is loaded at startup; the user must
                # click "Load Model" before analysis (process_image checks).
                with gr.Group(elem_classes=["control-section"]):
                    gr.HTML("<h3>π€ Model Configuration</h3>")
                    model_dropdown = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="Egret XLarge",
                        label="Select Model",
                        info="Choose the AI model for document analysis",
                        interactive=True
                    )
                    with gr.Row():
                        load_btn = gr.Button("π₯ Load Model", variant="primary", scale=1)
                        clear_btn = gr.Button("ποΈ Clear All", variant="secondary", scale=1)
                    model_status = gr.Textbox(
                        label="Model Status",
                        value="π No model loaded. Please select and load a model.",
                        interactive=False,
                        lines=2
                    )
                # Image Upload Section (type="pil": handlers receive a PIL image)
                with gr.Group(elem_classes=["control-section"]):
                    gr.HTML("<h3>π Image Input</h3>")
                    input_img = gr.Image(
                        label="Upload Document Image",
                        type="pil",
                        height=400,
                        interactive=True
                    )
                    detect_btn = gr.Button("π Analyze Document", variant="primary", size="lg")
                # Parameters Section (values are passed straight to process_image)
                with gr.Group(elem_classes=["control-section"]):
                    gr.HTML("<h3>βοΈ Detection Parameters</h3>")
                    conf_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.6,
                        step=0.05,
                        label="Confidence Threshold",
                        info="Minimum confidence for detections"
                    )
                    # A value of 1.0 makes process_image skip NMS entirely.
                    iou_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.5,
                        step=0.05,
                        label="NMS IoU Threshold",
                        info="Non-maximum suppression threshold"
                    )
                    # "Custom IoMin" selects nms_custom; the other choice selects
                    # torchvision.ops.nms.
                    nms_method = gr.Radio(
                        choices=["Custom IoMin", "Standard IoU"],
                        value="Custom IoMin",
                        label="NMS Algorithm",
                        info="Choose suppression method"
                    )
                    # Forwarded as visualize_bbox's `alpha` (opacity of the fills).
                    alpha_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.3,
                        step=0.1,
                        label="Overlay Transparency",
                        info="Transparency of detection overlays"
                    )
            # RIGHT COLUMN - Results and Output
            with gr.Column(scale=1, elem_classes=["panel-right"]):
                # Results Section (process_image returns a numpy RGB image)
                with gr.Group(elem_classes=["control-section"]):
                    gr.HTML("<h3>π― Detection Results</h3>")
                    output_img = gr.Image(
                        label="Analyzed Document",
                        type="numpy",
                        height=500,
                        interactive=False
                    )
                    detection_info = gr.Textbox(
                        label="Analysis Summary",
                        value="",
                        interactive=False,
                        lines=3,
                        placeholder="Detection results will appear here..."
                    )
                # Visualization Options Section.
                # No .change handler is registered for this checkbox, so
                # toggling it takes effect on the next "Analyze Document" click.
                with gr.Group(elem_classes=["control-section"]):
                    gr.HTML("<h3>π¨ Visualization Options</h3>")
                    show_labels_checkbox = gr.Checkbox(
                        value=True,
                        label="Show Class Labels",
                        info="Display class names and confidence scores on detections",
                        interactive=True
                    )
        # Event Handlers
        load_btn.click(
            fn=load_model,
            inputs=[model_dropdown],
            outputs=[model_status]
        )
        # Clears input/output/summary only; the loaded model and its status stay.
        clear_btn.click(
            fn=reset_interface,
            outputs=[input_img, output_img, detection_info]
        )
        # Input order must match process_image's positional parameters.
        detect_btn.click(
            fn=process_image,
            inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider, show_labels_checkbox],
            outputs=[output_img, detection_info]
        )
    # Launch application: listen on all interfaces on port 7860, with no
    # public share link; show errors in the UI and try to open a local browser.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        share=False,
        show_error=True,
        inbrowser=True
    )