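"""Gradio demo: warp a user-marked quadrilateral to a rectangle with Kornia.

The user sketches four marks on the input image; the marks are clustered into
four corner centroids, ordered clockwise, and Kornia's perspective transform
maps the enclosed region to an output rectangle of the requested size.
"""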
import gradio as gr
import torch
import kornia as K
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.vq import kmeans

def get_coordinates_from_mask(mask_in):
    """Cluster the sketched mask pixels into 4 corner centroids via k-means."""
    # Any pixel differing from the blank background [0, 0, 0, 255] was drawn on.
    x_y = np.where(mask_in != [0, 0, 0, 255])[:2]
    x_y = np.column_stack((x_y[1], x_y[0]))  # (row, col) indices -> (x, y) points
    x_y = np.float32(x_y)
    centroids, _ = kmeans(x_y, 4)
    centroids = np.int64(centroids)
    return centroids

def get_top_bottom_coordinates(coords):
    """Split a pair of points into (top, bottom) by their y coordinate."""
    top_coord = min(coords, key=lambda p: p[1])
    bottom_coord = max(coords, key=lambda p: p[1])
    return top_coord, bottom_coord

def sort_centroids_clockwise(centroids: np.ndarray):
    """Order 4 centroids clockwise: top-left, top-right, bottom-right, bottom-left.

    e.g. [[90, 10], [5, 5], [3, 72], [88, 70]] -> [5, 5], [90, 10], [88, 70], [3, 72]
    """
    c_list = centroids.tolist()
    c_list.sort(key=lambda p: p[0])  # sort by x: two leftmost, then two rightmost

    left_coords = c_list[:2]
    right_coords = c_list[-2:]

    top_left, bottom_left = get_top_bottom_coordinates(left_coords)
    top_right, bottom_right = get_top_bottom_coordinates(right_coords)

    return top_left, top_right, bottom_right, bottom_left

def infer(image_input, dst_height: str, dst_width: str):
    """Warp the quadrilateral sketched on the mask to an (h, w) rectangle."""
    image_in = image_input["image"]
    mask_in = image_input["mask"]
    torch_img = K.utils.image_to_tensor(image_in).float() / 255.0  # CxHxW in [0, 1]
    
    centroids = get_coordinates_from_mask(mask_in)
    ordered_src_coords = sort_centroids_clockwise(centroids)
    # the source points are the region to crop corners
    points_src = torch.tensor([list(ordered_src_coords)], dtype=torch.float32)
    # the destination points are the image vertexes
    h, w = int(dst_height), int(dst_width)  # destination size
    points_dst = torch.tensor([[
        [0., 0.], [w - 1., 0.], [w - 1., h - 1.], [0., h - 1.],
    ]], dtype=torch.float32)
    # compute perspective transform
    M: torch.Tensor = K.geometry.transform.get_perspective_transform(points_src, points_dst)
    # warp the original image by the found transform
    torch_img = torch_img[None]  # add a batch dimension: 1xCxHxW
    img_warp: torch.Tensor = K.geometry.transform.warp_perspective(torch_img, M, dsize=(h, w))
    
    # convert back to numpy
    img_np = K.utils.tensor_to_image(torch_img[0])
    img_warp_np: np.ndarray = K.utils.tensor_to_image(img_warp[0])
    # draw the selected corner points onto the source image
    for i in range(4):
        center = tuple(int(v) for v in points_src[0, i])
        img_np = cv2.circle(img_np.copy(), center, 5, (0, 1.0, 0), -1)  # full green: image is float in [0, 1]
    # create the plot
    fig, axs = plt.subplots(1, 2, figsize=(16, 10))
    axs = axs.ravel()
    axs[0].axis('off')
    axs[0].set_title('image source')
    axs[0].imshow(img_np)
    axs[1].axis('off')
    axs[1].set_title('image destination')
    axs[1].imshow(img_warp_np)
    return fig

description = """In this space you can warp an image using perspective transform with the Kornia library as seen in [this tutorial](https://kornia.github.io/tutorials/#category=Homography).
1. Upload an image or use the example provided
2. Set 4 points into the image with your cursor, which define the area to warp
3. Set a desired output size (or go with the default)
4. Click Submit to run the demo
"""

# Blank mask plus four placeholder corner marks (arbitrary coordinates) so the cached example has points to cluster
example_mask = np.zeros((327, 600, 4), dtype=np.uint8)
example_mask[:, :, 3] = 255
for x, y in [(150, 80), (450, 80), (450, 250), (150, 250)]:
    example_mask[y - 3:y + 4, x - 3:x + 4] = (255, 255, 255, 255)
example_image_dict = {"image": "bruce.png", "mask": example_mask}

with gr.Blocks() as demo:
    gr.Markdown("# Homography Warping")
    gr.Markdown(description)
    
    with gr.Row():
        image_input = gr.Image(tool="sketch", type="numpy", label="Input Image")
        output_plot = gr.Plot(label="Output")
    
    with gr.Row():
        dst_height = gr.Textbox(label="Destination Height", value="64")
        dst_width = gr.Textbox(label="Destination Width", value="128")
    
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=infer,
        inputs=[image_input, dst_height, dst_width],
        outputs=output_plot
    )
    
    # cache_examples=True runs infer on the example at startup and caches the figure
    gr.Examples(
        examples=[[example_image_dict, "64", "128"]],
        inputs=[image_input, dst_height, dst_width],
        outputs=output_plot,
        fn=infer,
        cache_examples=True
    )

if __name__ == "__main__":
    demo.launch()
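
# A minimal sketch of the same Kornia calls outside the UI, assuming a local
# file "card.png" (hypothetical) and hand-picked corner coordinates:
#
#   img = K.utils.image_to_tensor(cv2.imread("card.png")).float()[None] / 255.0
#   src = torch.tensor([[[125., 150.], [562., 40.], [562., 282.], [54., 328.]]])
#   dst = torch.tensor([[[0., 0.], [127., 0.], [127., 63.], [0., 63.]]])
#   M = K.geometry.transform.get_perspective_transform(src, dst)
#   warped = K.geometry.transform.warp_perspective(img, M, dsize=(64, 128))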