import os

# Redirect Gradio's temp files to a local folder (set before importing gradio)
os.environ["GRADIO_TEMP_DIR"] = "./tmp"

import torch
import torchvision  # provides torchvision.ops.nms for the "Standard IoU" option
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import (
    DFineForObjectDetection,
    RTDetrV2ForObjectDetection,
    RTDetrImageProcessor,
)
# == select device ==
device = "cuda" if torch.cuda.is_available() else "cpu"
# Available models with their corresponding model classes
MODELS = {
    "Egret XLarge": {
        "path": "ds4sd/docling-layout-egret-xlarge",
        "model_class": DFineForObjectDetection,
    },
    "Egret Large": {
        "path": "ds4sd/docling-layout-egret-large",
        "model_class": DFineForObjectDetection,
    },
    "Egret Medium": {
        "path": "ds4sd/docling-layout-egret-medium",
        "model_class": DFineForObjectDetection,
    },
    "Heron 101": {
        "path": "ds4sd/docling-layout-heron-101",
        "model_class": RTDetrV2ForObjectDetection,
    },
    "Heron": {
        "path": "ds4sd/docling-layout-heron",
        "model_class": RTDetrV2ForObjectDetection,
    },
}
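
# The Egret checkpoints use D-FINE detection heads and the Heron checkpoints
# use RT-DETRv2 heads, but all five are consumed through the same RT-DETR
# image processor, so pre- and post-processing below are identical per model.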
# Class-id to label mapping for the docling layout models
classes_map = {
    0: "Caption",
    1: "Footnote",
    2: "Formula",
    3: "List-item",
    4: "Page-footer",
    5: "Page-header",
    6: "Picture",
    7: "Section-header",
    8: "Table",
    9: "Text",
    10: "Title",
    11: "Document Index",
    12: "Code",
    13: "Checkbox-Selected",
    14: "Checkbox-Unselected",
    15: "Form",
    16: "Key-Value Region",
}
# Global state for the currently loaded model
current_model = None
current_processor = None
current_model_name = None
def colormap(N=256, normalized=False):
    """Generate the color map."""
    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << (7 - j))
            g = g | (bitget(c, 1) << (7 - j))
            b = b | (bitget(c, 2) << (7 - j))
            c = c >> 3
        cmap[i] = np.array([r, g, b])
    if normalized:
        cmap = cmap.astype(np.float32) / 255.0
    return cmap
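
# The bit-interleaving above reproduces the PASCAL VOC palette, e.g.
# colormap()[1] == (128, 0, 0) and colormap()[2] == (0, 128, 0), so each
# class ID maps to a stable, well-separated color.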
def iomin(box1, box2):
    """Intersection over Minimum (IoMin): intersection area divided by the smaller box area."""
    x1 = torch.max(box1[:, 0], box2[:, 0])
    y1 = torch.max(box1[:, 1], box2[:, 1])
    x2 = torch.min(box1[:, 2], box2[:, 2])
    y2 = torch.min(box1[:, 3], box2[:, 3])
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    min_area = torch.min(box1_area, box2_area)
    return inter_area / min_area
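
# Why IoMin rather than IoU: when one box lies entirely inside another,
# IoMin is 1.0 regardless of the size difference, while IoU can stay small.
# For A = (0, 0, 10, 10) and B = (2, 2, 4, 4): IoU = 4/100 = 0.04, but
# IoMin = 4/4 = 1.0, so nested boxes are reliably suppressed by nms() below.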
def nms(boxes, scores, iou_threshold=0.5):
    """Custom NMS implementation using IoMin as the overlap measure."""
    keep = []
    _, order = scores.sort(descending=True)
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        box_i = boxes[i].unsqueeze(0)
        rest = order[1:]
        overlaps = iomin(box_i, boxes[rest])
        order = rest[overlaps <= iou_threshold]
    return torch.tensor(keep, dtype=torch.long)
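
# Usage sketch with made-up boxes; the lower-scoring nested box is dropped:
#   boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 9., 9.]])
#   scores = torch.tensor([0.9, 0.8])
#   nms(boxes, scores, iou_threshold=0.5)  # -> tensor([0])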
def load_model(model_name):
    """Load the selected model, replacing any previously loaded one."""
    global current_model, current_processor, current_model_name
    if current_model_name == model_name:
        return f"✅ Model {model_name} is already loaded!"
    try:
        print(f"Loading model: {model_name}")
        model_info = MODELS[model_name]
        model_path = model_info["path"]
        model_class = model_info["model_class"]
        processor = RTDetrImageProcessor.from_pretrained(model_path)
        model = model_class.from_pretrained(model_path)
        model = model.to(device)
        model.eval()
        current_processor = processor
        current_model = model
        current_model_name = model_name
        return f"✅ Successfully loaded {model_name}!"
    except Exception as e:
        return f"❌ Error loading {model_name}: {str(e)}"
def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3):
    """Visualize bounding boxes with transparent overlays using OpenCV."""
    if isinstance(image_input, Image.Image):
        image = np.array(image_input)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif isinstance(image_input, np.ndarray):
        if len(image_input.shape) == 3 and image_input.shape[2] == 3:
            image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
        else:
            image = image_input.copy()
    else:
        raise ValueError("Input must be a PIL Image or a numpy array")

    overlay = image.copy()
    cmap = colormap(N=len(id_to_names), normalized=False)

    if len(bboxes) == 0:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    for i in range(len(bboxes)):
        try:
            bbox = bboxes[i]
            if torch.is_tensor(bbox):
                bbox = bbox.cpu().numpy()
            class_id = classes[i]
            if torch.is_tensor(class_id):
                class_id = class_id.item()
            score = scores[i]
            if torch.is_tensor(score):
                score = score.item()
            x_min, y_min, x_max, y_max = map(int, bbox)
            class_id = int(class_id)
            class_name = id_to_names.get(class_id, f"unknown_{class_id}")
            text = f"{class_name}:{score:.3f}"
            color = tuple(int(c) for c in cmap[class_id % len(cmap)])
            # Filled box on the overlay (blended below), outline on the image
            cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)
            (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
            # Keep the label on-canvas for boxes touching the top edge
            label_y = max(y_min, text_height + baseline)
            cv2.rectangle(image, (x_min, label_y - text_height - baseline), (x_min + text_width, label_y), color, -1)
            cv2.putText(image, text, (x_min, label_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        except Exception as e:
            print(f"Skipping box {i} due to error: {e}")

    # Blend the filled overlay into the image for translucent fills
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def recognize_image(input_img, conf_threshold, iou_threshold, nms_method, alpha):
    """Run the loaded docling layout model on an image."""
    if input_img is None:
        return None, "Please upload an image first."
    if current_model is None or current_processor is None:
        return None, "Please load a model first."
    try:
        if isinstance(input_img, np.ndarray):
            input_img = Image.fromarray(input_img)
        if input_img.mode != "RGB":
            input_img = input_img.convert("RGB")

        inputs = current_processor(images=[input_img], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = current_model(**inputs)

        # target_sizes expects (height, width); PIL's .size is (width, height)
        results = current_processor.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([input_img.size[::-1]]),
            threshold=conf_threshold,
        )
        if not results or len(results) == 0:
            return np.array(input_img), "No detections found."

        result = results[0]
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]
        if len(boxes) == 0:
            return np.array(input_img), "No detections above confidence threshold."

        if iou_threshold < 1.0:
            if nms_method == "Custom IoMin":
                keep_indices = nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
            else:
                keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
            boxes = boxes[keep_indices]
            scores = scores[keep_indices]
            labels = labels[keep_indices]
            if len(boxes.shape) == 1:
                # A single kept index can squeeze out the batch dimension
                boxes = boxes.unsqueeze(0)
                scores = scores.unsqueeze(0)
                labels = labels.unsqueeze(0)

        output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha)
        nms_note = f"after NMS ({nms_method})" if iou_threshold < 1.0 else "(NMS skipped)"
        detection_info = f"Found {len(boxes)} detections {nms_note}"
        return output, detection_info
    except Exception as e:
        print(f"[ERROR] recognize_image failed: {e}")
        error_msg = f"Error during processing: {str(e)}"
        if input_img is not None:
            return np.array(input_img), error_msg
        return np.zeros((512, 512, 3), dtype=np.uint8), error_msg
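
# Headless usage sketch (hypothetical image path; a model must be loaded first):
#   load_model("Heron")
#   annotated, info = recognize_image(Image.open("page.png"), 0.6, 0.5, "Custom IoMin", 0.3)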
def gradio_reset():
    return gr.update(value=None), gr.update(value=None), gr.update(value="")
if __name__ == "__main__":
    print(f"Using device: {device}")

    # Custom CSS for better scrolling and layout
    custom_css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    .main-content {
        overflow-y: auto !important;
        max-height: 100vh !important;
    }
    """
    with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft(), css=custom_css) as demo:
        # Header
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>📄 Document Layout Analysis</h1>
            <p>Using Docling Layout Models for document structure detection</p>
        </div>
        """)

        with gr.Row():
            # Left column - controls
            with gr.Column(scale=1):
                # Model selection
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value="Egret XLarge",
                    label="🤖 Select Model",
                )
                load_btn = gr.Button("📥 Load Model", variant="secondary", size="sm")
                model_status = gr.Textbox(label="Model Status", interactive=False, value="No model loaded", max_lines=2)
                input_img = gr.Image(label="📄 Upload Image", type="pil", height=300)
                with gr.Row():
                    clear = gr.Button("🗑️ Clear", size="sm")
                    predict = gr.Button("🚀 Detect", variant="primary", size="sm")

                # Parameters
                conf_threshold = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Confidence Threshold")
                iou_threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="NMS IoU Threshold")
                nms_method = gr.Radio(["Custom IoMin", "Standard IoU"], value="Custom IoMin", label="NMS Method")
                alpha_slider = gr.Slider(0.0, 1.0, value=0.3, step=0.1, label="Overlay Transparency")

            # Right column - results
            with gr.Column(scale=1):
                gr.HTML("<h3>🎯 Detection Results</h3>")
                output_img = gr.Image(label="Detected Layout", interactive=False, type="numpy", height=400)
                detection_info = gr.Textbox(label="Detection Info", interactive=False, max_lines=2)

        # Legend at the bottom
        with gr.Accordion("📋 Detected Classes", open=False):
            cmap = colormap(N=len(classes_map), normalized=False)
            legend_items = []
            for class_id, class_name in classes_map.items():
                color_rgb = cmap[class_id % len(cmap)]
                color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}"
                legend_items.append(
                    f'<span style="display:inline-block;width:15px;height:15px;'
                    f'background-color:{color_hex};margin-right:5px;border:1px solid #ccc;"></span>{class_name}'
                )
            legend_html = f"""
            <div style='display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; font-size: 14px;'>
                {''.join(f'<div>{item}</div>' for item in legend_items)}
            </div>
            """
            gr.HTML(legend_html)

        # Event handlers
        load_btn.click(load_model, inputs=[model_dropdown], outputs=[model_status])
        clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img, detection_info])
        predict.click(
            recognize_image,
            inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider],
            outputs=[output_img, detection_info],
        )

    # Launch
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, share=False)