File size: 2,692 Bytes
3f7cf74
 
c3b2eaa
9106c85
 
3f7cf74
9106c85
c3b2eaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7387ddd
9106c85
 
7387ddd
9106c85
 
 
7387ddd
 
9106c85
197375b
7387ddd
7bb1a95
 
67a2260
9106c85
 
7793bb7
 
 
 
 
 
2360e62
 
 
 
 
 
 
 
7793bb7
 
 
7bb1a95
9106c85
7793bb7
9106c85
7387ddd
 
6f9cba2
50c0336
c3b2eaa
3f7cf74
 
 
d6cc091
 
7bb1a95
de7343c
c3b2eaa
7bb1a95
 
 
 
c3b2eaa
3f7cf74
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import numpy as np
import cv2
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def perform_kmeans_segmentation(image, number_of_clusters):
    """Segment an image by clustering its pixel colours with k-means.

    Every pixel is treated as a point in colour space (one coordinate per
    channel). OpenCV's k-means groups these points into ``number_of_clusters``
    clusters, and each pixel is then repainted with its cluster's centre colour.

    Args:
        image: array of shape ``(H, W, C)``, or ``(H, W)`` for a single-channel
            image. Presumably ``uint8`` (the centres are cast back to ``uint8``).
        number_of_clusters: number of clusters ``k``. It may arrive as a float
            (e.g. from a ``gr.Number`` input), so it is converted to ``int``
            because ``cv2.kmeans`` requires an integer ``K``.

    Returns:
        ``(segmented_image, centers, labels)``:
        - ``segmented_image`` has the same shape as ``image`` (dtype ``uint8``).
        - ``centers`` is a ``(k, C)`` ``uint8`` array of cluster colours.
        - ``labels`` is the per-pixel cluster-index array from ``cv2.kmeans``
          (one row per pixel).

    Raises:
        ValueError: if ``k`` is below 1 or larger than the number of pixels.
    """
    k = int(number_of_clusters)

    # One row per pixel, one column per channel; cv2.kmeans needs float32 data.
    channels = image.shape[2] if image.ndim == 3 else 1
    pixel_vals = np.float32(image.reshape((-1, channels)))
    if not 1 <= k <= len(pixel_vals):
        raise ValueError(
            f"number_of_clusters must be between 1 and the pixel count "
            f"({len(pixel_vals)}), got {k}"
        )

    # Stop after 100 iterations, or once the centres move less than 0.85
    # (in colour-value units), whichever comes first.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.85)
    # Run 10 attempts from random initial centres; OpenCV returns the labelling
    # with the best compactness.
    _compactness, labels, centers = cv2.kmeans(
        pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # Round (rather than truncate) the float centres back to 8-bit colours.
    centers = np.uint8(np.clip(np.rint(centers), 0, 255))
    segmented_data = centers[labels.flatten()]

    # Restore the original image dimensions.
    segmented_image = segmented_data.reshape(image.shape)
    return segmented_image, centers, labels


def graph(image, centers, labels, display_centres):
    """Build a 3D scatter plot of an image's pixels in colour space.

    Each pixel is plotted at its (R, G, B) coordinates and coloured by its
    cluster label. Optionally, the cluster centres are overlaid as larger red
    markers.

    A standalone ``matplotlib.figure.Figure`` is used instead of
    ``plt.figure()``. A pyplot figure stays registered in pyplot's global
    figure manager until explicitly closed. In a long-running web server, that
    means every request leaks a figure, and pyplot's global state is not
    thread-safe.

    Args:
        image: ``(H, W, 3)`` array, expected in RGB channel order.
        centers: ``(k, 3)`` array of cluster centres in the same colour space.
        labels: per-pixel cluster indices, one per pixel (any shape with
            ``H*W`` elements, e.g. the ``(H*W, 1)`` array from ``cv2.kmeans``).
        display_centres: if true, also plot the cluster centres.

    Returns:
        The matplotlib ``Figure`` containing the 3D scatter plot.
    """
    from matplotlib.figure import Figure

    pixels = image.reshape(-1, 3)
    r, g, b = pixels[:, 0], pixels[:, 1], pixels[:, 2]

    fig = Figure()
    ax = fig.add_subplot(111, projection='3d')
    # Flatten the labels so matplotlib maps them through its colormap as a
    # 1-D value array.
    ax.scatter(r, g, b, s=10, c=np.ravel(labels), marker="o", label='image')
    if display_centres:
        ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], s=50, c='r', marker="o", label='cluster centers')

    ax.set_xlabel("R")
    ax.set_ylabel("G")
    ax.set_zlabel("B")
    # The scatter calls above set labels; without a legend they were never shown.
    ax.legend()
    return fig


def resize_image(image, max_size=800):
    """Downscale an image so that neither side exceeds ``max_size`` pixels.

    The aspect ratio is preserved, and the longer side becomes exactly
    ``max_size``. Images already within the limit are returned unchanged (the
    same object, not a copy).

    Args:
        image: array whose first two dimensions are ``(height, width)``.
        max_size: maximum allowed length, in pixels, of either side.
            Defaults to 800.

    Returns:
        The original image, or a downscaled copy produced with
        ``cv2.INTER_AREA`` (the interpolation recommended for shrinking).
    """
    height, width = np.shape(image)[:2]
    if height <= max_size and width <= max_size:
        return image

    # cv2.resize expects dsize as (width, height). Clamp the shorter side to
    # at least 1 pixel so extremely thin images don't produce a 0-sized target.
    if height >= width:
        scale = max_size / height
        dim = (max(1, int(width * scale)), max_size)
    else:
        scale = max_size / width
        dim = (max_size, max(1, int(height * scale)))

    return cv2.resize(image, dim, interpolation=cv2.INTER_AREA)


def kmeans_image(number_of_clusters, input_image, display_cluster_centres):
    """Gradio handler: segment an uploaded image and plot its pixel colours.

    Args:
        number_of_clusters: number of k-means clusters (may arrive as a float
            from ``gr.Number``; it is converted to ``int``).
        input_image: numpy array from ``gr.Image(type="numpy")``. Gradio
            already delivers this in RGB order, so no BGR conversion is
            applied. Grayscale and RGBA inputs are normalised to 3-channel RGB.
        display_cluster_centres: whether to overlay cluster centres on the plot.

    Returns:
        ``(gr.Image, Figure)``: the segmented image and the 3D colour scatter plot.

    Raises:
        gr.Error: if no image or no cluster count was supplied.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to segment.")
    if number_of_clusters is None:
        raise gr.Error("Please enter the number of clusters.")

    # gr.Image(type="numpy") is already RGB; a BGR->RGB conversion here would
    # swap the red and blue channels in both the output and the plot.
    image = np.asarray(input_image)
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    image = resize_image(image)

    segmented_image, cluster_centres, labels = perform_kmeans_segmentation(image, int(number_of_clusters))
    plot = graph(image, cluster_centres, labels, display_cluster_centres)

    return gr.Image(value=segmented_image), plot


# Web UI: cluster count, input image and a centre-overlay toggle in; the
# segmented image and a 3D colour-space plot out.
iface = gr.Interface(
    fn=kmeans_image,
    inputs=[
        # precision=0 makes Gradio round to and return an int, which
        # cv2.kmeans requires for its K argument.
        gr.Number(minimum=2, maximum=10, value=2, precision=0, label="Number of Clusters", info="Number of clusters to segment the image into. Enter a value between 2 and 10."),
        gr.Image(type="numpy", label="Image to Segment"),
        gr.Checkbox(label="display cluster centers"),
    ],
    outputs=["image", gr.Plot(visible=True, label="Image plotted in 3D space")],
    examples=[
        [2, "images/simple_dots.png", False],
        [3, "images/purple_flower.jpg", False],
        [3, "images/rocky_coast.jpg", False],
        [2, "images/simple_flowers.jpg", False],
    ],
)

# Only start the server when run as a script, so importing this module
# (e.g. for tests) has no side effects.
if __name__ == "__main__":
    iface.launch()