"""Gradio demo: segment an image by k-means clustering of its pixel colours."""

import gradio as gr
import numpy as np
import cv2
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the "3d" projection on Matplotlib < 3.2


def perform_kmeans_segmentation(image, number_of_clusters):
    """Segment an image by clustering its pixel colours with k-means.

    Each pixel is treated as a point in 3-D colour space. The pixel is then
    replaced by the centre of the cluster it was assigned to.

    Args:
        image: ``(H, W, 3)`` array of 8-bit colour values.
        number_of_clusters: number of colour clusters ``k``. It may arrive as
            a float (e.g. from a numeric UI field). It is converted to ``int``
            because ``cv2.kmeans`` rejects a non-integer ``K``.

    Returns:
        ``(segmented_image, centers, labels)``:

        - ``segmented_image`` has the same shape as ``image``.
        - ``centers`` is a ``(k, 3)`` uint8 array.
        - ``labels`` is the per-pixel label array returned by ``cv2.kmeans``
          (one row per pixel).
    """
    k = int(number_of_clusters)

    # One row per pixel, one column per colour channel; kmeans needs float32.
    pixel_vals = np.float32(image.reshape((-1, 3)))

    # Stop after 100 iterations or once the centres move by less than 0.85.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.85)
    # 10 attempts with random initial centres; the best compactness is kept.
    _, labels, centers = cv2.kmeans(
        pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # Back to 8-bit so the centres can be used directly as pixel colours.
    centers = np.uint8(centers)
    segmented_image = centers[labels.flatten()].reshape(image.shape)
    return segmented_image, centers, labels


def graph(image, centers, labels, display_centres):
    """Plot every pixel of *image* in colour space, coloured by its cluster.

    A ``matplotlib.figure.Figure`` is built directly instead of using
    ``pyplot``. pyplot keeps a global reference to every figure it opens, so
    creating one per request without closing it leaks memory in a
    long-running web app.

    Args:
        image: ``(H, W, 3)`` image; its channels are used as the x/y/z axes.
        centers: ``(k, 3)`` cluster centres from ``perform_kmeans_segmentation``.
        labels: per-pixel cluster labels (any shape with ``H*W`` elements).
        display_centres: if true, overlay the cluster centres as larger red
            markers and show a legend.

    Returns:
        The Matplotlib ``Figure`` holding the 3-D scatter plot.
    """
    r, g, b = cv2.split(image)

    fig = Figure()
    ax = fig.add_subplot(111, projection="3d")
    # NOTE: this scatters every pixel (up to 800*800 after resizing), which is
    # slow to render; consider subsampling if responsiveness matters.
    ax.scatter(
        r.flatten(),
        g.flatten(),
        b.flatten(),
        s=10,
        c=np.ravel(labels),
        marker="o",
        label="image",
    )
    if display_centres:
        ax.scatter(
            centers[:, 0],
            centers[:, 1],
            centers[:, 2],
            s=50,
            c="r",
            marker="o",
            label="cluster centers",
        )
        ax.legend()
    ax.set_xlabel("Red")
    ax.set_ylabel("Green")
    ax.set_zlabel("Blue")
    return fig


def resize_image(image, max_dim=800):
    """Downscale *image* so that neither side exceeds *max_dim* pixels.

    The aspect ratio is preserved. An image that already fits is returned
    unchanged (the same object).

    Args:
        image: ``(H, W[, C])`` image array.
        max_dim: maximum allowed height and width in pixels (default 800).

    Returns:
        The original image, or a resized copy made with ``cv2.INTER_AREA``.
    """
    height, width = image.shape[:2]
    if height <= max_dim and width <= max_dim:
        return image

    # cv2.resize takes (width, height). The shorter side is clamped to at
    # least 1 px so extremely thin images do not collapse to a 0-size dim.
    if height >= width:
        scale = max_dim / height
        dim = (max(1, int(width * scale)), max_dim)
    else:
        scale = max_dim / width
        dim = (max_dim, max(1, int(height * scale)))
    return cv2.resize(image, dim, interpolation=cv2.INTER_AREA)


def kmeans_image(number_of_clusters, input_image, display_cluster_centres):
    """Gradio callback: segment *input_image* and plot its pixels in 3-D.

    Args:
        number_of_clusters: value from the "Number of Clusters" field.
        input_image: numpy image from ``gr.Image(type="numpy")``, or ``None``
            if nothing was uploaded.
        display_cluster_centres: whether to overlay the cluster centres on
            the plot.

    Returns:
        ``(segmented_image, figure)`` for the image and plot outputs.

    Raises:
        gr.Error: if no image or no cluster count was supplied.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to segment.")
    if number_of_clusters is None:
        raise gr.Error("Please enter the number of clusters.")

    # gr.Image(type="numpy") already delivers RGB pixels (image_mode defaults
    # to "RGB"). Converting with cv2.COLOR_BGR2RGB here would swap the red and
    # blue channels of the returned segmentation.
    image = resize_image(input_image)
    segmented_image, cluster_centres, labels = perform_kmeans_segmentation(
        image, number_of_clusters
    )
    plot = graph(image, cluster_centres, labels, display_cluster_centres)
    return segmented_image, plot


iface = gr.Interface(
    fn=kmeans_image,
    inputs=[
        gr.Number(
            minimum=2,
            maximum=10,
            value=2,
            precision=0,  # deliver an int, as cv2.kmeans requires
            label="Number of Clusters",
            info="Number of clusters to segment the image into. Enter a value between 2 and 10.",
        ),
        gr.Image(type="numpy", label="Image to Segment"),
        gr.Checkbox(label="display cluster centers"),
    ],
    outputs=["image", gr.Plot(visible=True, label="Image plotted in 3D space")],
    examples=[
        [2, "images/simple_dots.png", False],
        [3, "images/purple_flower.jpg", False],
        [3, "images/rocky_coast.jpg", False],
        [2, "images/simple_flowers.jpg", False],
    ],
)

if __name__ == "__main__":
    iface.launch()