# Hugging Face Space status at capture time: Runtime error
import logging
import sys

import cv2
import gradio as gr
import numpy as np
import torch
from mmdet.apis import inference_detector, init_detector

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print('Loading model...')
# PyTorch has no 'gpu' device type; CUDA devices are named 'cuda' / 'cuda:N'.
# Passing 'gpu' makes model loading fail whenever a GPU is actually present.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Table-detection model loaded once at import time and shared by predict().
table_det = init_detector('model/table-det/config.py',
                          'model/table-det/model.pth', device=device)
def get_corners(points):
    """
    Order the four corner points of a quadrilateral.

    The points are split into a left pair and a right pair by x-coordinate,
    then each pair is ordered by y-coordinate. In image coordinates, a smaller
    y is higher up.

    NOTE: for shapes rotated close to 45 degrees, the x-based split can pair
    corners ambiguously.

    Args:
        points (Sequence[Sequence[float]]): Exactly four (x, y) points, in any
            order.

    Returns:
        tuple: (top-left, top-right, bottom-right, bottom-left). This is
        clockwise order on screen. Each element is the point object that was
        passed in.

    Raises:
        ValueError: If ``points`` does not contain exactly four points.
    """
    if len(points) != 4:
        raise ValueError(f"expected 4 points, got {len(points)}")
    # Split into the two left-most and the two right-most points.
    by_x = sorted(points, key=lambda p: p[0])
    # Sort both halves by ascending y so that index 0 is the top corner.
    top_left, bottom_left = sorted(by_x[:2], key=lambda p: p[1])
    top_right, bottom_right = sorted(by_x[2:], key=lambda p: p[1])
    return (top_left, top_right, bottom_right, bottom_left)


def get_bbox(mask_array, epsilon_ratio=0.02):
    """
    Fit a four-corner polygon to every table region in a binary mask.

    Args:
        mask_array (numpy.ndarray): Single-channel 8-bit mask. Non-zero pixels
            belong to a table.
        epsilon_ratio (float): Polygon-approximation tolerance, as a fraction
            of each contour's perimeter. It is passed to ``cv2.approxPolyDP``.

    Returns:
        list[list[list]]: One entry per outer contour. Each entry is
        ``[top_left, top_right, bottom_right, bottom_left]``, with every corner
        given as an ``[x, y]`` list. The corners come from the polygon
        approximation when it has exactly four vertices. Otherwise they come
        from the contour's minimum-area rotated rectangle.
    """
    # Only outer boundaries are wanted. With RETR_TREE, holes inside a table
    # mask were returned as additional contours and reported as extra tables.
    contours, _ = cv2.findContours(
        mask_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    table_bboxes = []
    for cnt in contours:
        # Approximate the contour with a tolerance proportional to its size.
        perimeter = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, epsilon_ratio * perimeter, True)
        # reshape (not squeeze) keeps an (N, 2) shape even when N == 1.
        points = approx.reshape(-1, 2)
        if len(points) != 4:
            # Not a quadrilateral: fall back to the minimum-area rectangle.
            points = cv2.boxPoints(cv2.minAreaRect(cnt))
        tl, tr, br, bl = get_corners(points.tolist())
        table_bboxes.append([tl, tr, br, bl])
    return table_bboxes
def predict(image_input, score_thr=0.5):
    """
    Detect tables in an image and return their outlines.

    Args:
        image_input (numpy.ndarray | None): Image from the Gradio ``Image``
            component. Gradio's numpy output is RGB. ``inference_detector``
            expects arrays in BGR order (as loaded by OpenCV/mmcv), so
            3-channel input is converted before inference.
        score_thr (float): Detections scoring at or below this value are
            discarded.

    Returns:
        dict: ``{'rect-fit': [...], 'bbox': [...]}``.
            - ``'rect-fit'`` holds the four-corner polygons that ``get_bbox``
              fits to each kept mask.
            - ``'bbox'`` holds the detector's boxes (``pred_instances.bboxes``)
              for the kept detections.

        Both lists are empty when no image is given.
    """
    if image_input is None:
        # Gradio passes None when the Image input is empty.
        logger.info("No image provided")
        return {'rect-fit': [], 'bbox': []}
    # Log the shape only; logging the whole pixel array floods the log.
    logger.info("Image input shape: %s", image_input.shape)
    if image_input.ndim == 3 and image_input.shape[2] == 3:
        image_input = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)

    # Inference the tables in the image.
    result = inference_detector(table_det, image_input)
    instances = result.pred_instances
    mask_images = instances.masks.cpu().numpy()
    scores = instances.scores.cpu().numpy()
    bboxes = instances.bboxes.cpu().numpy()
    logger.info("Detected %d instance(s)", len(scores))

    # Keep only sufficiently confident detections.
    keep = scores > score_thr
    bbox_list = []
    for mask_image in mask_images[keep]:
        bbox_list.extend(get_bbox(mask_image.astype(np.uint8)))
    return {'rect-fit': bbox_list, 'bbox': bboxes[keep].tolist()}
def run(server_name="0.0.0.0", server_port=7860):
    """
    Build the Gradio interface around ``predict`` and start serving it.

    Args:
        server_name (str): Address to bind. "0.0.0.0" listens on all
            interfaces.
        server_port (int): TCP port to listen on.
    """
    demo = gr.Interface(
        fn=predict,
        # predict() expects a numpy array (Gradio's default), stated explicitly.
        inputs=gr.components.Image(type="numpy"),
        outputs=gr.JSON(),
    )
    demo.launch(server_name=server_name, server_port=server_port)


if __name__ == "__main__":
    run()