|
import gradio as gr |
|
import numpy as np |
|
import cv2 |
|
from matplotlib import pyplot as plt |
|
from mpl_toolkits.mplot3d import Axes3D |
|
|
|
def perform_kmeans_segmentation(image, number_of_clusters):
    """Segment an image by k-means clustering of its pixel values.

    Every pixel is treated as a point in colour space and clustered with
    ``cv2.kmeans``. Each pixel is then repainted with the (uint8-truncated)
    value of its cluster centre.

    Args:
        image: uint8 array of shape ``(H, W, C)`` or ``(H, W)``.
        number_of_clusters: Number of clusters ``k``. It may arrive as a
            float (e.g. from ``gr.Number``). It is truncated to ``int``
            because ``cv2.kmeans`` only accepts an integer ``K``.

    Returns:
        A tuple ``(segmented_image, centers, labels)``:
        - ``segmented_image`` has the same shape as ``image``.
        - ``centers`` is a uint8 array with one row per cluster.
        - ``labels`` is the ``(N, 1)`` cluster index per pixel, as returned
          by ``cv2.kmeans``.

    Raises:
        ValueError: If ``k`` is not between 1 and the number of pixels.
            ``cv2.kmeans`` requires ``1 <= K <= N``.
    """
    k = int(number_of_clusters)

    # One row per pixel, one column per channel; a 2-D image has 1 channel.
    channels = 1 if image.ndim == 2 else image.shape[2]
    pixel_vals = np.float32(image.reshape(-1, channels))

    if not 1 <= k <= len(pixel_vals):
        raise ValueError(
            f"number_of_clusters must be between 1 and {len(pixel_vals)} "
            f"(the number of pixels), got {number_of_clusters!r}"
        )

    # Stop after 100 iterations or once centres move less than 0.85.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.85)
    # 10 attempts with random initial centres; the most compact result wins.
    _compactness, labels, centers = cv2.kmeans(
        pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    centers = np.uint8(centers)
    # Replace every pixel by its cluster centre, then restore the image shape.
    segmented_image = centers[labels.flatten()].reshape(image.shape)

    return segmented_image, centers, labels
|
|
|
|
|
def graph(image, centers, labels, display_centres):
    """Scatter every pixel of *image* in a 3-D colour-space plot.

    Args:
        image: 3-channel image. It is assumed to be RGB, which is what the
            callback passes in. Each pixel becomes one point at
            (channel 0, channel 1, channel 2).
        centers: Array with one row per cluster; the first three columns
            are used as point coordinates.
        labels: Cluster index per pixel (e.g. the ``(N, 1)`` array from
            ``cv2.kmeans``). It is used to colour the points.
        display_centres: If true, overlay the cluster centres as larger
            red markers.

    Returns:
        A ``matplotlib.figure.Figure`` containing the 3-D scatter plot.
    """
    # Build the Figure directly instead of via pyplot. plt.figure() keeps
    # every figure alive in pyplot's global registry until plt.close(), so
    # a long-running server would leak one figure per request.
    from matplotlib.figure import Figure

    pixels = image.reshape(-1, 3)

    fig = Figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(
        pixels[:, 0],
        pixels[:, 1],
        pixels[:, 2],
        s=10,
        c=np.ravel(labels),
        marker="o",
        label="image",
    )

    if display_centres:
        ax.scatter(
            centers[:, 0],
            centers[:, 1],
            centers[:, 2],
            s=50,
            c="r",
            marker="o",
            label="cluster centers",
        )
        # Explain what the red markers are.
        ax.legend()

    ax.set_xlabel("Red")
    ax.set_ylabel("Green")
    ax.set_zlabel("Blue")

    return fig
|
|
|
|
|
def resize_image(image, max_size=800):
    """Shrink *image* so neither side exceeds *max_size*, keeping aspect ratio.

    Images that already fit are returned unchanged (the same object).
    Larger images are downscaled with ``cv2.INTER_AREA`` so that the longer
    side becomes exactly ``max_size``.

    Args:
        image: Array whose first two dimensions are (rows, columns).
        max_size: Upper bound for both height and width, in pixels.

    Returns:
        The original image, or a resized copy.
    """
    height, width = np.shape(image)[:2]
    if height <= max_size and width <= max_size:
        return image

    if height >= width:
        scale = max_size / height
        # max(1, ...) stops very thin images collapsing to a 0-pixel side,
        # which cv2.resize rejects.
        new_size = (max(1, round(width * scale)), max_size)
    else:
        scale = max_size / width
        new_size = (max_size, max(1, round(height * scale)))

    # cv2.resize takes dsize as (width, height), the reverse of numpy's order.
    return cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
|
|
|
|
|
def kmeans_image(number_of_clusters, input_image, display_cluster_centres):
    """Gradio callback: segment the uploaded image and plot its pixels.

    Args:
        number_of_clusters: Cluster count from the ``gr.Number`` input.
        input_image: Array from ``gr.Image(type="numpy")``. Gradio already
            delivers it in RGB order, so no BGR/RGB conversion is applied.
            Converting here would swap the red and blue channels in both the
            output image and the plot.
        display_cluster_centres: Whether to overlay the cluster centres on
            the 3-D plot.

    Returns:
        ``(segmented_image, figure)`` for the image and plot outputs.

    Raises:
        gr.Error: If no image or no cluster count was provided. Gradio
            shows this message in the UI.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to segment.")
    if number_of_clusters is None:
        raise gr.Error("Please enter the number of clusters.")

    image = np.asarray(input_image)
    # The pipeline below assumes 3 channels: expand grayscale, drop alpha.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        image = np.ascontiguousarray(image[..., :3])

    image = resize_image(image)

    segmented_image, cluster_centres, labels = perform_kmeans_segmentation(
        image, number_of_clusters
    )
    plot = graph(image, cluster_centres, labels, display_cluster_centres)

    # A plain array is accepted by the "image" output in every Gradio version.
    return segmented_image, plot
|
|
|
|
|
# Gradio UI: a cluster count, an image and a checkbox go in; the segmented
# image and a 3-D scatter plot of the pixels come out.
iface = gr.Interface(
    fn=kmeans_image,
    inputs=[
        gr.Number(
            minimum=2,
            maximum=10,
            value=2,
            # Deliver an int rather than a float: cv2.kmeans requires an
            # integer cluster count.
            precision=0,
            label="Number of Clusters",
            info="Number of clusters to segment the image into. Enter a value between 2 and 10.",
        ),
        gr.Image(type="numpy", label="Image to Segment"),
        gr.Checkbox(label="display cluster centers"),
    ],
    outputs=["image", gr.Plot(visible=True, label="Image plotted in 3D space")],
    # Paths are relative to the working directory the app is launched from.
    examples=[
        [2, "images/simple_dots.png", False],
        [3, "images/purple_flower.jpg", False],
        [3, "images/rocky_coast.jpg", False],
        [2, "images/simple_flowers.jpg", False],
    ],
)

iface.launch()