from mmdet.apis import init_detector, inference_detector
import gradio as gr
import cv2
import sys
import torch
import numpy as np
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Detections whose confidence score is not strictly above this value are
# discarded before masks/boxes are returned.
SCORE_THRESHOLD = 0.5

logger.info('Loading model...')
# torch/mmdet accept 'cuda[:N]' or 'cpu'; 'gpu' is not a valid device type.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
table_det = init_detector('model/table-det/config.py',
                          'model/table-det/model.pth', device=device)


def get_corners(points):
    """Order the corners of a quadrilateral.

    Points are interpreted in image coordinates, as produced by OpenCV:
    x grows to the right and y grows downward.

    Args:
        points: Sequence of four ``(x, y)`` points in any order.

    Returns:
        tuple: ``(top_left, top_right, bottom_right, bottom_left)``.

    NOTE(review): the left/right split is done on x alone, so for a shape
    rotated close to 45 degrees the assignment can be ambiguous.
    """
    # The two smallest-x points form the left side, the rest the right side.
    by_x = sorted(points, key=lambda p: p[0])
    left_side = sorted(by_x[:2], key=lambda p: p[1])
    right_side = sorted(by_x[2:], key=lambda p: p[1])
    # On each side the smaller y is the upper corner (y grows downward).
    top_left, bottom_left = left_side[0], left_side[1]
    top_right, bottom_right = right_side[0], right_side[1]
    return top_left, top_right, bottom_right, bottom_left


def get_bbox(mask_array):
    """Fit a four-corner polygon to every contour found in a binary mask.

    Args:
        mask_array (numpy.ndarray): Single-channel 8-bit mask; non-zero
            pixels are treated as foreground.

    Returns:
        list[list]: One entry per contour (``cv2.RETR_TREE``, so nested
        contours are included too). Each entry is a list of four ``[x, y]``
        points ordered top-left, top-right, bottom-right, bottom-left.
    """
    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); indexing with [-2] works for both.
    contours = cv2.findContours(
        mask_array, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    table_bboxes = []
    for cnt in contours:
        # Fallback shape: corners of the minimum-area (possibly rotated)
        # rectangle enclosing the contour.
        box = cv2.boxPoints(cv2.minAreaRect(cnt))
        # Simplify the contour with a tolerance of 2% of its perimeter.
        perimeter = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, 0.02 * perimeter, True)
        points = np.squeeze(approx)
        # Keep the simplified polygon only if it is a quadrilateral;
        # otherwise fall back to the minimum-area rectangle.
        if len(points) != 4:
            points = box
        tl, tr, br, bl = get_corners(points.tolist())
        table_bboxes.append([tl, tr, br, bl])
    return table_bboxes


def predict(image_input):
    """Detect tables in an image with the module-level ``table_det`` model.

    Args:
        image_input (numpy.ndarray | None): Image as delivered by the
            ``gr.Image`` component (RGB), or ``None`` if nothing was
            submitted.

    Returns:
        dict: ``{'rect-fit': [...], 'bbox': [...]}``. ``'rect-fit'`` holds
        the four-corner polygons from ``get_bbox`` for every kept mask.
        ``'bbox'`` holds the detector's own boxes for the kept instances.
        An instance is kept when its score exceeds ``SCORE_THRESHOLD``.
    """
    if image_input is None:
        return {'rect-fit': [], 'bbox': []}
    logger.info("Image input shape: %s", image_input.shape)

    # gr.Image yields RGB arrays, whereas mmdet's ndarray inference path
    # follows the OpenCV/mmcv BGR convention used when loading training
    # images; swap channels for 3-channel input so the model sees BGR.
    if image_input.ndim == 3 and image_input.shape[2] == 3:
        image_input = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)

    result = inference_detector(table_det, image_input)
    instances = result.pred_instances
    mask_images = instances.masks.cpu().numpy()
    scores = instances.scores.cpu().numpy()
    bboxes = instances.bboxes.cpu().numpy()
    logger.info("Detected %d instance(s)", len(scores))

    keep = scores > SCORE_THRESHOLD
    bbox_list = []
    for mask_image in mask_images[keep]:
        # findContours needs a single-channel 8-bit image.
        bbox_list.extend(get_bbox(mask_image.astype(np.uint8)))
    return {'rect-fit': bbox_list, 'bbox': bboxes[keep].tolist()}


def run(server_name="0.0.0.0", server_port=7860):
    """Serve ``predict`` as a Gradio web UI.

    Args:
        server_name (str): Interface to bind; defaults to all interfaces.
        server_port (int): TCP port to listen on.
    """
    demo = gr.Interface(
        fn=predict,
        inputs=gr.components.Image(),
        outputs=gr.JSON(),
    )
    demo.launch(server_name=server_name, server_port=server_port)


if __name__ == "__main__":
    run()