import gradio as gr
import torch
import kornia as K
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.vq import kmeans
from PIL import Image
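
# Demo flow: the user paints four dots on the image; the dots are recovered
# from the brush mask via k-means and used as source corners for a Kornia
# perspective warp onto a rectangle of user-chosen size.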
def get_coordinates_from_mask(mask_in):
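    """Cluster painted mask pixels into four centroids (the corner points).

    Pixels whose RGBA value differs from opaque black [0, 0, 0, 255] are
    treated as painted; their (x, y) coordinates are clustered with
    k-means (k=4).
    """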
    x_y = np.where(mask_in != [0, 0, 0, 255])[:2]
    x_y = np.column_stack((x_y[1], x_y[0]))
    x_y = np.float32(x_y)
    centroids, _ = kmeans(x_y, 4)
    centroids = np.int64(centroids)
    return centroids

def get_top_bottom_coordinates(coords):
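    """Split a pair of points into (top, bottom) by their y coordinate."""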
    top_coord = min(coords, key=lambda x: x[1])
    bottom_coord = max(coords, key=lambda x: x[1])
    return top_coord, bottom_coord

def sort_centroids_clockwise(centroids: np.ndarray):
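    """Order the four centroids as top-left, top-right, bottom-right, bottom-left.

    Sorting by x splits the points into a left and a right pair; each pair
    is then split by y into its top and bottom point.
    """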
    c_list = centroids.tolist()
    c_list.sort(key=lambda y: y[0])
    left_coords = c_list[:2]
    right_coords = c_list[-2:]
    top_left, bottom_left = get_top_bottom_coordinates(left_coords)
    top_right, bottom_right = get_top_bottom_coordinates(right_coords)
    return top_left, top_right, bottom_right, bottom_left

def infer(image_input, dst_height: str, dst_width: str):
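    """Warp the brush-marked quadrilateral to a dst_height x dst_width image."""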
    # gr.ImageEditor returns a dict holding the composited image and the brush layers
    if isinstance(image_input, dict):
        image_in = np.array(image_input['composite'])
        mask_in = np.array(image_input['layers'][0]) if image_input['layers'] else np.zeros_like(image_in)
    else:
        image_in = image_input
        mask_in = np.zeros_like(image_in)
    torch_img = K.utils.image_to_tensor(image_in).float() / 255.0
    centroids = get_coordinates_from_mask(mask_in)
    ordered_src_coords = sort_centroids_clockwise(centroids)
    # the source points are the corners of the region to warp
    points_src = torch.tensor([list(ordered_src_coords)], dtype=torch.float32)
    # the destination points are the corners of the output image
    h, w = int(dst_height), int(dst_width)  # destination size
    points_dst = torch.tensor([[
        [0.0, 0.0], [w - 1.0, 0.0], [w - 1.0, h - 1.0], [0.0, h - 1.0],
    ]], dtype=torch.float32)
    # compute the perspective transform from the point correspondences
    M: torch.Tensor = K.geometry.transform.get_perspective_transform(points_src, points_dst)
    # warp the original image by the found transform (add a batch dimension first)
    torch_img = torch_img[None]
    img_warp: torch.Tensor = K.geometry.transform.warp_perspective(torch_img, M, dsize=(h, w))
    # convert back to numpy
    img_np = K.utils.tensor_to_image(torch_img[0])
    img_warp_np: np.ndarray = K.utils.tensor_to_image(img_warp[0])
    # draw the source points into the original image; the image is float in
    # [0, 1], so the circle color must be in that range as well
    for i in range(4):
        center = tuple(int(c) for c in points_src[0, i].long())  # plain ints for cv2
        img_np = cv2.circle(img_np.copy(), center, 5, (0.0, 1.0, 0.0), -1)
    # create the plot
    fig, axs = plt.subplots(1, 2, figsize=(16, 10))
    axs = axs.ravel()
    axs[0].axis('off')
    axs[0].set_title('image source')
    axs[0].imshow(img_np)
    axs[1].axis('off')
    axs[1].set_title('image destination')
    axs[1].imshow(img_warp_np)
    return fig

description = """In this space you can warp an image using a perspective transform with the Kornia library, as shown in [this tutorial](https://kornia.github.io/tutorials/#category=Homography).
1. Upload an image or use the provided example
2. Mark 4 points on the image with the brush tool; these define the region to warp
3. Set the desired output size (or keep the defaults)
4. Click Submit to run the demo
"""

# Load the example image
example_image = Image.open("bruce.png")
example_image_np = np.array(example_image)

with gr.Blocks() as demo:
    gr.Markdown("# Homography Warping")
    gr.Markdown(description)
    with gr.Row():
        image_input = gr.ImageEditor(
            type="numpy",
            label="Input Image",
            brush=gr.Brush(colors=["#ff0000"], default_size=5),
            height=400,
            width=600,
        )
        output_plot = gr.Plot(label="Output")
    with gr.Row():
        dst_height = gr.Textbox(label="Destination Height", value="64")
        dst_width = gr.Textbox(label="Destination Width", value="128")
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=infer,
        inputs=[image_input, dst_height, dst_width],
        outputs=output_plot,
    )
    gr.Examples(
        examples=[[example_image_np, "64", "128"]],
        inputs=[image_input, dst_height, dst_width],
        outputs=output_plot,
        fn=infer,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()