Spaces:
Runtime error
Runtime error
File size: 3,827 Bytes
b7f49b8 a03e1e7 b7f49b8 b01f517 b7f49b8 b01f517 b7f49b8 b01f517 b7f49b8 b01f517 b7f49b8 b01f517 b7f49b8 a03e1e7 1e41e52 b01f517 ca40c0f 1e41e52 a03e1e7 ca40c0f 00196b8 b01f517 00196b8 1e41e52 567c46b 00196b8 1e41e52 00196b8 ba7e7bf b01f517 b7f49b8 e1ed59b b7f49b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from mmdet.apis import init_detector, inference_detector
import gradio as gr
import cv2
import sys
import torch
import numpy as np
import logging
# Emit INFO-level records so the request/response logging is visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print('Loading model...')
# PyTorch's device type for NVIDIA GPUs is 'cuda'; 'gpu' is not a valid
# device string and makes model loading fail on CUDA-capable hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The detector is built once at import time and reused by every prediction.
table_det = init_detector('model/table-det/config.py',
                          'model/table-det/model.pth', device=device)
def get_corners(points):
    """
    Order the four corner points of a quadrilateral.

    The points are split into the two with the smallest x (left pair) and
    the remaining ones (right pair). The left pair is ordered by ascending
    y and the right pair by descending y.

    Args:
        points: Sequence of four ``(x, y)`` points (tuples or lists).

    Returns:
        tuple: ``(left pair min-y, right pair max-y, right pair min-y,
        left pair max-y)``. In image coordinates (y grows downward) that is
        top-left, bottom-right, top-right, bottom-left.
    """
    # Order everything left-to-right first, then cut into the two sides.
    by_x = sorted(points, key=lambda pt: pt[0])
    left_pair, right_pair = by_x[:2], by_x[2:]

    # Left side: smallest y first. Right side: largest y first.
    left_pair.sort(key=lambda pt: pt[1])
    right_pair.sort(key=lambda pt: pt[1], reverse=True)

    left_low_y, left_high_y = left_pair[0], left_pair[1]
    right_high_y, right_low_y = right_pair[0], right_pair[1]
    return (left_low_y, right_high_y, right_low_y, left_high_y)
def get_bbox(mask_array):
    """
    Fit a four-corner quadrilateral to every contour found in a mask.

    Each contour is simplified with ``cv2.approxPolyDP`` (tolerance: 2% of
    the contour perimeter). If the simplified polygon has exactly four
    vertices those are used; otherwise the corners of the contour's
    minimum-area (rotated) rectangle are used instead.

    Args:
        mask_array (numpy.ndarray): The mask to process. The caller in this
            file passes a single detection mask cast to ``uint8``.

    Returns:
        list[list[list]]: One entry per contour. Each entry is a list of
        four ``[x, y]`` points ordered top-left, top-right, bottom-right,
        bottom-left (image coordinates, y grows downward). Coordinates are
        ints when they come from the polygon approximation and floats when
        they come from the rotated rectangle.
    """
    # Indexing with [-2] picks the contour list on both OpenCV 3.x, which
    # returns (image, contours, hierarchy), and 4.x, which returns
    # (contours, hierarchy). The hierarchy itself is not needed here.
    # NOTE(review): RETR_TREE also returns contours of holes inside the
    # mask, each of which produces its own quad - confirm this is wanted.
    contours = cv2.findContours(
        mask_array, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    table_bboxes = []
    for cnt in contours:
        # Simplify the outline; the tolerance scales with the perimeter.
        perimeter = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, 0.02 * perimeter, True)
        # approxPolyDP yields an (N, 1, 2) array; flatten it to (N, 2).
        points = approx.reshape(-1, 2)

        # Anything that is not a clean quadrilateral falls back to the
        # corners of the minimum-area rotated rectangle.
        if len(points) != 4:
            points = cv2.boxPoints(cv2.minAreaRect(cnt))

        # get_corners returns the left pair by ascending y and the right
        # pair by descending y, i.e. (top-left, bottom-right, top-right,
        # bottom-left). Unpack in exactly that order so the names are true.
        tl, br, tr, bl = get_corners(points.tolist())
        table_bboxes.append([tl, tr, br, bl])

    return table_bboxes
def predict(image_input):
    """
    Detect tables in an image and return their outlines.

    Args:
        image_input: The image supplied by the Gradio ``Image`` input. It
            is forwarded unchanged to ``inference_detector``. ``None``
            (nothing submitted) produces an empty result.

    Returns:
        dict: ``{'rect-fit': [...], 'bbox': [...]}``. ``'rect-fit'`` holds
        the four-corner quads computed by ``get_bbox`` from every mask whose
        score is above 0.5. ``'bbox'`` holds the detector's own boxes for
        the same detections, converted to nested lists.
    """
    # Gradio hands over None when the form is submitted without an image;
    # answer with an empty result instead of failing inside the detector.
    if image_input is None:
        return {'rect-fit': [], 'bbox': []}

    logger.info("Image input: %s", image_input)

    # NOTE(review): Gradio image arrays are RGB while OpenCV-loaded images
    # are BGR - confirm the model config expects this channel order.
    result = inference_detector(table_det, image_input)

    # Move the predictions to host memory as NumPy arrays.
    instances = result.pred_instances
    mask_images = instances.masks.cpu().numpy()
    scores = instances.scores.cpu().numpy()
    bboxes = instances.bboxes.cpu().numpy()
    logger.info("Result: %s", result)

    # Keep only confident detections; compute the selection once so masks
    # and boxes are guaranteed to stay aligned.
    score_threshold = 0.5
    keep = scores > score_threshold

    bbox_list = []
    for mask_image in mask_images[keep]:
        # get_bbox runs OpenCV contour detection, which needs an 8-bit mask.
        bbox_list.extend(get_bbox(mask_image.astype(np.uint8)))

    return {'rect-fit': bbox_list, 'bbox': bboxes[keep].tolist()}
def run():
    """
    Build the Gradio interface around ``predict`` and start serving it.

    The interface takes one image input and shows the prediction as JSON.
    The server listens on all network interfaces (0.0.0.0) on port 7860.
    """
    image_in = gr.components.Image()
    json_out = gr.JSON()
    app = gr.Interface(fn=predict, inputs=image_in, outputs=json_out)
    app.launch(server_name="0.0.0.0", server_port=7860)
# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    run()
|