File size: 3,827 Bytes
b7f49b8
 
 
 
 
 
a03e1e7
 
 
 
b7f49b8
 
 
 
 
 
b01f517
 
b7f49b8
 
 
 
 
 
 
b01f517
b7f49b8
 
 
b01f517
b7f49b8
 
 
b01f517
b7f49b8
 
 
b01f517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7f49b8
a03e1e7
1e41e52
b01f517
 
 
 
 
ca40c0f
1e41e52
a03e1e7
 
ca40c0f
00196b8
b01f517
00196b8
1e41e52
 
567c46b
00196b8
1e41e52
00196b8
 
ba7e7bf
b01f517
b7f49b8
 
 
 
 
e1ed59b
b7f49b8
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from mmdet.apis import init_detector, inference_detector
import gradio as gr
import cv2
import sys
import torch
import numpy as np
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print('Loading model...')
# torch / mmdet name GPU devices 'cuda[:index]'; the string 'gpu' is not a
# valid torch device and would make init_detector fail on CUDA machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load the table-detection model once at import time; predict() reuses it.
table_det = init_detector('model/table-det/config.py',
                          'model/table-det/model.pth', device=device)


def get_corners(points):
    """
    Order the four corners of a quadrilateral given as (x, y) points.

    Coordinates are treated as image coordinates (y grows downward), so a
    smaller y means "higher" in the image. The two points with the smallest
    x form the left pair and the other two form the right pair. The left
    pair is ordered top-to-bottom and the right pair bottom-to-top, which
    gives the return order::

        (top-left, bottom-right, top-right, bottom-left)

    This is NOT clockwise order; callers must unpack accordingly.

    Args:
        points: sequence of exactly four (x, y) pairs (lists or tuples).

    Returns:
        tuple: (top_left, bottom_right, top_right, bottom_left). Each item
        is one of the input elements, not a copy.

    Raises:
        ValueError: if ``points`` does not contain exactly four items.
    """
    if len(points) != 4:
        raise ValueError(
            f"get_corners expects exactly 4 points, got {len(points)}")

    # Split into the two left-most and the two right-most points.
    by_x = sorted(points, key=lambda p: p[0])
    left_points = sorted(by_x[:2], key=lambda p: p[1])
    right_points = sorted(by_x[2:], key=lambda p: p[1], reverse=True)

    # left_points is (top, bottom); right_points is (bottom, top).
    top_left, bottom_left = left_points
    bottom_right, top_right = right_points

    return (top_left, bottom_right, top_right, bottom_left)


def get_bbox(mask_array):
    """
    Fit a four-corner polygon to each table region in a binary mask.

    Each outer contour of ``mask_array`` is simplified with
    ``cv2.approxPolyDP`` using a tolerance of 2% of the contour perimeter.
    If the simplification yields exactly four vertices, they are used
    directly. Otherwise the corners of the contour's minimum-area rotated
    rectangle (``cv2.minAreaRect``) are used instead.

    Args:
        mask_array (numpy.ndarray): single-channel 8-bit mask; non-zero
            pixels are treated as foreground.

    Returns:
        list[list[list]]: one entry per outer contour. Each entry is a list
        of four [x, y] points ordered top-left, top-right, bottom-right,
        bottom-left (clockwise in image coordinates). Points from the
        polygon approximation are ints; points from the rotated-rectangle
        fallback are floats.
    """
    # RETR_EXTERNAL keeps only outer boundaries: a hole inside a table mask
    # is not a separate table. Taking [-2] selects the contour list under
    # both the OpenCV 3.x (image, contours, hierarchy) and 4.x
    # (contours, hierarchy) return signatures.
    contours = cv2.findContours(
        mask_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    table_bboxes = []
    for cnt in contours:
        perimeter = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, 0.02 * perimeter, True)

        # approx has shape (k, 1, 2); reshape to (k, 2) so that the length
        # check below counts vertices even when k == 1.
        points = approx.reshape(-1, 2)

        if len(points) != 4:
            # Not a clean quadrilateral: fall back to the corners of the
            # minimum-area rotated rectangle enclosing the contour.
            points = cv2.boxPoints(cv2.minAreaRect(cnt))

        # get_corners returns (top-left, bottom-right, top-right, bottom-left).
        tl, br, tr, bl = get_corners(points.tolist())

        # Emit the corners clockwise, starting from top-left.
        table_bboxes.append([tl, tr, br, bl])

    return table_bboxes


def predict(image_input, *, score_thr=0.5):
    """
    Detect tables in an image and return their polygons and boxes.

    Args:
        image_input (numpy.ndarray | str | None): the image to analyse.
            Arrays come from the Gradio ``Image`` component, which delivers
            RGB. 3-channel arrays are converted to BGR before inference,
            the channel order of images loaded with ``cv2.imread`` /
            ``mmcv.imread``. NOTE(review): assumes the model config's
            data preprocessor expects BGR input (the MMDetection default,
            ``bgr_to_rgb=True``); confirm against model/table-det/config.py.
            Non-array inputs such as file paths are passed to
            ``inference_detector`` unchanged. ``None`` (no image submitted)
            returns empty results.
        score_thr (float): detections scoring at or below this value are
            discarded. Defaults to 0.5.

    Returns:
        dict: ``{'rect-fit': [...], 'bbox': [...]}``. ``'rect-fit'`` holds
        the four-corner polygons produced by ``get_bbox`` for every kept
        mask. ``'bbox'`` holds the detector's boxes for the kept
        detections. The two lists need not have equal length, because one
        mask can contain several contours.
    """
    if image_input is None:
        logger.info("No image input received.")
        return {'rect-fit': [], 'bbox': []}

    if isinstance(image_input, np.ndarray):
        # Log only the metadata; the pixel array itself is useless in logs.
        logger.info("Image input: shape=%s dtype=%s",
                    image_input.shape, image_input.dtype)
        if image_input.ndim == 3 and image_input.shape[2] == 3:
            # Gradio yields RGB; mmdet's in-memory path expects BGR.
            image_input = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
    else:
        logger.info("Image input: %s", image_input)

    # Run the detector on the image.
    result = inference_detector(table_det, image_input)

    pred = result.pred_instances
    mask_images = pred.masks.cpu().numpy()
    scores = pred.scores.cpu().numpy()
    bboxes = pred.bboxes.cpu().numpy()

    # The full result includes per-instance masks; keep it out of INFO logs.
    logger.info("Result: %d detections", len(scores))
    logger.debug("Result: %s", result)

    # Compute the score filter once and apply it to masks and boxes alike.
    keep = scores > score_thr

    bbox_list = []
    for mask_image in mask_images[keep]:
        bbox_list.extend(get_bbox(mask_image.astype(np.uint8)))

    return {'rect-fit': bbox_list, 'bbox': bboxes[keep].tolist()}


def run(server_name="0.0.0.0", server_port=7860):
    """
    Build the Gradio interface around ``predict`` and start serving it.

    Args:
        server_name (str): address to bind. The default "0.0.0.0" listens
            on all interfaces.
        server_port (int): TCP port to listen on. Defaults to 7860.
    """
    demo = gr.Interface(
        fn=predict,
        inputs=gr.components.Image(),
        outputs=gr.JSON(),
    )

    demo.launch(server_name=server_name, server_port=server_port)


if __name__ == "__main__":
    run()