table-det / main.py
napatswift
Update app
992ad70
raw
history blame
3.47 kB
from mmdet.apis import init_detector, inference_detector
import gradio as gr
import cv2
import sys
import torch
import numpy as np
print('Loading model...')
# torch/mmdet device strings are 'cuda[:N]' or 'cpu'; 'gpu' is not a valid
# torch device type and makes model placement fail on CUDA machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Build the detector once at import time so every request reuses it.
table_det = init_detector('model/table-det/config.py',
                          'model/table-det/model.pth', device=device)
def get_corners(points):
    """
    Order four quadrilateral vertices by corner position.

    The points are split into a left pair and a right pair by x-coordinate.
    The left pair is ordered by ascending y and the right pair by descending y.
    In image coordinates (y grows downward) the result is therefore:

        (top-left, bottom-right, top-right, bottom-left)

    This is NOT clockwise order, so callers must unpack it accordingly.

    Args:
        points: Exactly four (x, y) pairs (tuples or lists).

    Returns:
        tuple: (top-left, bottom-right, top-right, bottom-left). The items are
        the input point objects themselves, reordered.

    Raises:
        ValueError: If ``points`` does not contain exactly four points.
    """
    points = list(points)
    if len(points) != 4:
        raise ValueError(f'expected 4 points, got {len(points)}')
    # The two points with the smallest x form the left side.
    sorted_points = sorted(points, key=lambda p: p[0])
    left_points = sorted(sorted_points[:2], key=lambda p: p[1])
    # The right side is sorted by descending y, so bottom-right comes first.
    right_points = sorted(sorted_points[2:], key=lambda p: p[1], reverse=True)
    return (left_points[0], right_points[0], right_points[1], left_points[1])
def get_bbox(mask_array, epsilon_ratio=0.02):
    """
    Extract one four-corner outline per outer region of a binary mask.

    Args:
        mask_array (numpy.ndarray): Mask to trace; expected to be a
            single-channel uint8 array whose non-zero pixels are foreground.
        epsilon_ratio (float): Polygon-approximation tolerance, as a fraction
            of each contour's perimeter, passed to ``cv2.approxPolyDP``.

    Returns:
        list[list[list]]: One entry per outer contour. Each entry is a list of
        four [x, y] corners ordered top-left, top-right, bottom-right,
        bottom-left (image coordinates, y pointing down).
    """
    # Use outer contours only: RETR_TREE would also report holes inside a mask
    # as separate "tables". Indexing [-2] handles both the OpenCV 3 (3-tuple)
    # and OpenCV 4 (2-tuple) return signatures.
    contours = cv2.findContours(
        mask_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    table_bboxes = []
    for cnt in contours:
        # Simplify the contour; a clean table outline collapses to 4 vertices.
        perimeter = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, epsilon_ratio * perimeter, True)
        points = approx.reshape(-1, 2)
        # Otherwise fall back to the minimum-area enclosing rectangle.
        if len(points) != 4:
            points = cv2.boxPoints(cv2.minAreaRect(cnt))
        # get_corners returns (top-left, bottom-right, top-right, bottom-left).
        tl, br, tr, bl = get_corners(points.tolist())
        table_bboxes.append([tl, tr, br, bl])
    return table_bboxes
def predict(image_input, score_thr=0.5):
    """
    Detect tables in an image and return one corner outline per table.

    Args:
        image_input: Image delivered by the Gradio ``Image`` input component.
            It is forwarded unchanged to ``inference_detector``. ``None``
            (no image uploaded) yields an empty result.
        score_thr (float): Detections scoring below this value are discarded.

    Returns:
        list: Concatenated ``get_bbox`` outputs for every kept instance mask.
    """
    if image_input is None:
        return []
    # NOTE(review): gr.Image delivers an RGB numpy array by default, while
    # mmdet pipelines commonly assume BGR arrays. Confirm against the model
    # config.
    result = inference_detector(table_det, image_input)
    instances = result.pred_instances
    mask_images = instances.masks.cpu().numpy()
    scores = instances.scores.cpu().numpy()
    # Drop instances whose confidence is below the threshold.
    keep = scores >= score_thr
    bbox_list = []
    for mask_image in mask_images[keep]:
        bbox_list.extend(get_bbox(mask_image.astype(np.uint8)))
    return bbox_list
def run(server_name="0.0.0.0", server_port=7860):
    """
    Build the Gradio table-detection demo and launch its web server.

    Args:
        server_name (str): Address to bind. The default "0.0.0.0" listens on
            all interfaces; pass "127.0.0.1" for local-only access.
        server_port (int): TCP port to listen on.
    """
    demo = gr.Interface(
        fn=predict,
        inputs=gr.components.Image(),
        outputs=gr.JSON(),
    )
    demo.launch(server_name=server_name, server_port=server_port)


# Start the demo only when run as a script, not when imported.
if __name__ == "__main__":
    run()