# Celine
# Added additional info
# d6cc091
import gradio as gr
import numpy as np
import cv2
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def perform_kmeans_segmentation(image, number_of_clusters):
    """Segment an image by clustering its pixel colours with k-means.

    Args:
        image: Image array; it is reshaped to rows of three colour values,
            so it is expected to hold three values per pixel.
        number_of_clusters: Number of colour clusters. It may arrive as a
            float (e.g. ``2.0`` from a numeric UI widget), so it is rounded
            and converted to ``int`` before being handed to OpenCV.

    Returns:
        A tuple ``(segmented_image, centers, labels)`` where
        ``segmented_image`` has the same shape as ``image`` with every pixel
        replaced by its cluster centre, ``centers`` holds the cluster centres
        as 8-bit values, and ``labels`` is the per-pixel cluster index array
        returned by ``cv2.kmeans``.
    """
    # cv2.kmeans needs an integer cluster count; UI inputs may deliver a float.
    k = int(round(number_of_clusters))

    # One row per pixel, three colour values per row, as float32 for OpenCV.
    pixel_vals = np.float32(image.reshape((-1, 3)))

    # Termination criteria tuple: (type flags, max iterations, epsilon).
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.85)

    # 10 attempts with random initial centres.
    _, labels, centers = cv2.kmeans(
        pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # Convert the centres back to 8-bit values and paint each pixel with
    # the centre of the cluster it was assigned to.
    centers = np.uint8(centers)
    segmented_data = centers[labels.flatten()]

    # Restore the original image dimensions.
    segmented_image = segmented_data.reshape(image.shape)
    return segmented_image, centers, labels
def graph(image, centers, labels, display_centres):
    """Plot every pixel of ``image`` as a point in 3D colour space.

    Args:
        image: Image array with three colour channels on its last axis.
        centers: Cluster centres, one row per cluster with three columns.
        labels: Per-pixel cluster indices; used to colour the points.
        display_centres: When truthy, the cluster centres are drawn on top
            of the pixel cloud as larger red markers.

    Returns:
        The matplotlib figure containing the 3D scatter plot.
    """
    # One flat coordinate array per channel. The channels are named by index
    # because their colour meaning depends on what the caller passes in.
    chan0, chan1, chan2 = (image[..., i].ravel() for i in range(3))

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Flatten the labels so there is exactly one colour value per point.
    point_labels = np.asarray(labels).ravel()
    ax.scatter(chan0, chan1, chan2, s=10, c=point_labels, marker="o", label='image')

    if display_centres:
        ax.scatter(
            centers[:, 0], centers[:, 1], centers[:, 2],
            s=50, c='r', marker="o", label='cluster centers',
        )

    # pyplot keeps a reference to every figure it creates until it is closed;
    # release it so figures do not pile up across requests. The returned
    # figure object can still be rendered by the caller.
    plt.close(fig)
    return fig
def resize_image(image, max_side=800):
    """Shrink ``image`` so that neither side exceeds ``max_side`` pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the same object, not a copy).

    Args:
        image: Image array whose first two dimensions are (height, width).
        max_side: Maximum allowed length, in pixels, of the longer side.
            Defaults to 800.

    Returns:
        The original image if it already fits, otherwise a resized copy
        whose longer side is exactly ``max_side``.
    """
    height, width = np.shape(image)[:2]
    if height <= max_side and width <= max_side:
        return image

    # cv2.resize expects the target size as (width, height). The shorter
    # side is clamped to at least 1 pixel so extreme aspect ratios do not
    # produce an invalid zero-sized dimension.
    if height >= width:
        scale = max_side / height
        new_size = (max(1, int(width * scale)), max_side)
    else:
        scale = max_side / width
        new_size = (max_side, max(1, int(height * scale)))

    return cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
def kmeans_image(number_of_clusters, input_image, display_cluster_centres):
    """Gradio callback: segment an image with k-means and plot its pixels.

    Args:
        number_of_clusters: Number of colour clusters to segment into.
        input_image: The uploaded image as a numpy array, or ``None`` when
            nothing was uploaded.
        display_cluster_centres: Whether to draw the cluster centres on the
            3D plot.

    Returns:
        A tuple of the segmented image (wrapped in ``gr.Image``) and the
        matplotlib figure showing the pixels in 3D colour space.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of an OpenCV crash.
        raise gr.Error("Please provide an image to segment.")

    # The Gradio image component (type="numpy") already delivers RGB data and
    # the image output expects RGB, so no BGR<->RGB conversion is applied;
    # converting here would swap the red and blue channels in the result.
    image = resize_image(input_image)

    segmented_image, cluster_centres, labels = perform_kmeans_segmentation(
        image, number_of_clusters
    )
    plot = graph(image, cluster_centres, labels, display_cluster_centres)
    return gr.Image(value=segmented_image), plot
# Gradio UI: cluster count, image and a checkbox in; segmented image and a
# 3D colour-space plot out.
iface = gr.Interface(
    fn=kmeans_image,
    inputs=[
        gr.Number(
            minimum=2,
            maximum=10,
            value=2,
            # precision=0 makes the component deliver an integer, which is
            # what the k-means cluster count has to be.
            precision=0,
            label="Number of Clusters",
            info="Number of clusters to segment the image into. Enter a value between 2 and 10.",
        ),
        gr.Image(type="numpy", label="Image to Segment"),
        gr.Checkbox(label="display cluster centers"),
    ],
    outputs=["image", gr.Plot(visible=True, label="Image plotted in 3D space")],
    # Each example row is [number of clusters, image path, show centres].
    examples=[
        [2, "images/simple_dots.png", False],
        [3, "images/purple_flower.jpg", False],
        [3, "images/rocky_coast.jpg", False],
        [2, "images/simple_flowers.jpg", False],
    ],
)
iface.launch()