File size: 2,692 Bytes
3f7cf74 c3b2eaa 9106c85 3f7cf74 9106c85 c3b2eaa 7387ddd 9106c85 7387ddd 9106c85 7387ddd 9106c85 197375b 7387ddd 7bb1a95 67a2260 9106c85 7793bb7 2360e62 7793bb7 7bb1a95 9106c85 7793bb7 9106c85 7387ddd 6f9cba2 50c0336 c3b2eaa 3f7cf74 d6cc091 7bb1a95 de7343c c3b2eaa 7bb1a95 c3b2eaa 3f7cf74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
import numpy as np
import cv2
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def perform_kmeans_segmentation(image, number_of_clusters):
    """Segment an image by clustering its pixel values with k-means.

    Every pixel is replaced by the centre of the cluster it was assigned
    to. The centres are converted to uint8, with fractional parts truncated.

    Args:
        image: NumPy array of shape (H, W, C) or (H, W). Each pixel's
            channel values form one k-means sample.
        number_of_clusters: Number of clusters K. This may be an int or an
            integral float; ``gr.Number`` delivers floats unless
            ``precision=0`` is set.

    Returns:
        Tuple ``(segmented_image, centers, labels)``:
        - ``segmented_image`` has the same shape as ``image``.
        - ``centers`` holds the uint8 cluster centres, shape (K, C).
        - ``labels`` is the per-pixel cluster index array returned by
          ``cv2.kmeans``.

    Raises:
        ValueError: If ``number_of_clusters`` is missing, not a whole
            number, smaller than 1, or larger than the number of pixels.
    """
    # One row per pixel and one column per channel. A 2-D (grayscale)
    # image has a single channel.
    channels = 1 if image.ndim == 2 else image.shape[2]
    pixel_vals = np.float32(image.reshape((-1, channels)))

    if number_of_clusters is None:
        raise ValueError("number_of_clusters is required")
    # cv2.kmeans requires an integer K. Reject fractional values instead of
    # silently truncating them.
    k = int(number_of_clusters)
    if k != number_of_clusters:
        raise ValueError(
            f"number_of_clusters must be a whole number, got {number_of_clusters!r}"
        )
    if not 1 <= k <= pixel_vals.shape[0]:
        raise ValueError(
            f"number_of_clusters must be between 1 and the number of pixels "
            f"({pixel_vals.shape[0]}), got {k}"
        )

    # Stop after 100 iterations or once the centres move by less than 0.85.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.85)
    # 10 attempts, each starting from random initial centres.
    _, labels, centers = cv2.kmeans(
        pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # Convert the centres back to 8-bit values and paint every pixel with
    # its cluster's centre.
    centers = np.uint8(centers)
    segmented_image = centers[labels.flatten()].reshape(image.shape)
    return segmented_image, centers, labels
def graph(image, centers, labels, display_centres):
    """Plot every pixel of ``image`` as a point in 3-D channel space.

    The image's first, second and third channels are plotted on the x, y
    and z axes. Points are coloured by their k-means cluster label.

    Args:
        image: Three-channel image array.
        centers: Cluster centres, one row per cluster. Columns 0-2 are
            plotted as x, y, z.
        labels: Per-pixel cluster labels, one per pixel of ``image``.
        display_centres: If true, overlay the cluster centres as larger
            red markers.

    Returns:
        A ``matplotlib.figure.Figure`` holding the 3-D scatter plot.
    """
    # Create the figure directly rather than through pyplot. pyplot keeps
    # every figure it creates alive until it is explicitly closed, so one
    # plt.figure() per request would accumulate figures in a long-running
    # app.
    from matplotlib.figure import Figure

    r, g, b = cv2.split(image)

    fig = Figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(
        r.flatten(), g.flatten(), b.flatten(),
        s=10, c=labels.flatten(), marker="o", label='image',
    )
    if display_centres:
        ax.scatter(
            centers[:, 0], centers[:, 1], centers[:, 2],
            s=50, c='r', marker="o", label='cluster centers',
        )
    return fig
def resize_image(image, max_size=800):
    """Downscale ``image`` so that neither side exceeds ``max_size`` pixels.

    The aspect ratio is preserved up to integer rounding. An image that
    already fits is returned as-is (the same object, not a copy).

    Args:
        image: Array whose first two dimensions are (height, width).
        max_size: Maximum allowed height and width in pixels. Defaults to 800.

    Returns:
        ``image`` itself if it already fits. Otherwise, a resized copy whose
        longer side equals ``max_size``, produced with ``cv2.INTER_AREA``
        interpolation.
    """
    height, width = np.shape(image)[:2]
    if height <= max_size and width <= max_size:
        return image

    # Scale by the longer side so that it becomes exactly max_size. Clamp
    # the shorter side to at least 1 px, because cv2.resize rejects a
    # zero-sized dimension.
    scale = max_size / max(height, width)
    if height >= width:
        dim = (max(1, int(width * scale)), max_size)
    else:
        dim = (max_size, max(1, int(height * scale)))

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
def kmeans_image(number_of_clusters, input_image, display_cluster_centres):
    """Gradio callback: segment the uploaded image and plot its colours.

    Args:
        number_of_clusters: K from the "Number of Clusters" input.
        input_image: Uploaded image as a NumPy array, or ``None`` when
            nothing was uploaded. ``gr.Image(type="numpy")`` delivers RGB
            channel order.
        display_cluster_centres: Whether to overlay the cluster centres on
            the 3-D plot.

    Returns:
        ``(gr.Image, figure)`` for the segmented-image and plot outputs.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to
            the user.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to segment.")

    # The input is already RGB. Do not apply COLOR_BGR2RGB here: that swaps
    # the red and blue channels, and the segmented output would be shown
    # with the wrong colours. Only normalise the channel count, because
    # graph() plots exactly three channels.
    image = input_image
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    image = resize_image(image)
    segmented_image, cluster_centres, labels = perform_kmeans_segmentation(image, number_of_clusters)
    plot = graph(image, cluster_centres, labels, display_cluster_centres)
    return gr.Image(value=segmented_image), plot
# Gradio UI. Inputs: K, the image to segment, and a toggle for showing
# cluster centres. Outputs: the segmented image and a 3-D scatter plot of
# the image's pixels.
iface = gr.Interface(
    fn=kmeans_image,
    inputs=[
        # precision=0 makes gr.Number round and deliver an int; cv2.kmeans
        # needs an integer K.
        gr.Number(minimum=2, maximum=10, value=2, precision=0, label="Number of Clusters", info="Number of clusters to segment the image into. Enter a value between 2 and 10."),
        gr.Image(type="numpy", label="Image to Segment"),
        gr.Checkbox(label="display cluster centers"),
    ],
    outputs=["image", gr.Plot(visible=True, label="Image plotted in 3D space")],
    examples=[
        [2, "images/simple_dots.png", False],
        [3, "images/purple_flower.jpg", False],
        [3, "images/rocky_coast.jpg", False],
        [2, "images/simple_flowers.jpg", False],
    ],
)
iface.launch()