File size: 5,224 Bytes
b4b75f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8716e2b
b4b75f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64dd9de
b4b75f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import matplotlib
matplotlib.use('Agg')
import gradio as gr
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import plotly.express as px
from plotly import subplots
import pandas as pd
import random

# CIFAR-10 ships as separate train/test splits. The clustering model below is
# run over x_data, so both splits are stacked into one image array and one
# label array.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_data = np.concatenate((x_train, x_test))
y_data = np.concatenate((y_train, y_test))

# Human-readable class names, looked up with the integer labels in y_data.
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
num_classes = 10

# Pretrained clustering model published on the Hugging Face Hub.
clustering_model = from_pretrained_keras("keras-io/semantic-image-clustering")

# Per-image probability distribution over the clusters.
clustering_probs = clustering_model.predict(x_data, batch_size=500, verbose=1)
# Each image is assigned to its most probable cluster ...
cluster_assignments = tf.math.argmax(clustering_probs, axis=-1).numpy()
# ... and that top probability is kept as the clustering confidence. The
# images with the highest confidence are treated as the cluster 'prototypes'.
cluster_confidence = tf.math.reduce_max(clustering_probs, axis=-1).numpy()

# Map: cluster id -> list of (image index into x_data, confidence) pairs.
clusters = defaultdict(list)
for idx, (c, confidence) in enumerate(zip(cluster_assignments, cluster_confidence)):
    clusters[c].append((idx, confidence))

def get_cluster_size(cluster_number: int):
    """Return a sentence stating how many instances a cluster contains.

    Args:
        cluster_number: 1-based cluster number; it is shifted by one to
            obtain the key used in the module-level ``clusters`` mapping.

    Returns:
        A string of the form ``"Cluster #<n> consists of <size> objects"``.
    """
    members = clusters[cluster_number - 1]
    member_count = len(members)
    return f"Cluster #{cluster_number} consists of {member_count} objects"

def get_images_from_cluster(cluster_number: int, num_images: int, image_mode: str):
    """Plot a single row of example images taken from one cluster.

    Args:
        cluster_number: 1-based cluster number; it is shifted by one to
            obtain the key used in the module-level ``clusters`` mapping.
        num_images: Maximum number of images to draw. Fewer are drawn when
            the cluster contains fewer instances.
        image_mode: ``"Random Images from Cluster"`` draws a random sample of
            the cluster; any other value selects the instances with the
            highest clustering confidence.

    Returns:
        A matplotlib figure with one subplot per selected image, each titled
        with the class name of the image's label in ``y_data``.
    """
    members = clusters[cluster_number - 1]
    # int() tolerates a float-valued count (e.g. 8.0); min() keeps the loop
    # from indexing past the end of a cluster smaller than the request.
    count = min(int(num_images), len(members))

    if image_mode == "Random Images from Cluster":
        # random.sample returns a new list, so the shared cluster list is not
        # reordered in place as random.shuffle would do.
        chosen = random.sample(members, count)
    else:
        # Highest-confidence instances first.
        chosen = sorted(members, key=lambda kv: kv[1], reverse=True)[:count]

    fig = plt.figure()
    for position, (image_idx, _) in enumerate(chosen, start=1):
        ax = fig.add_subplot(1, count, position)
        ax.imshow(x_data[image_idx].astype("uint8"))
        ax.set_title(classes[y_data[image_idx][0]])
        ax.axis("off")
    fig.tight_layout()
    # Unregister the figure from pyplot's global figure list so repeated calls
    # do not accumulate open figures; the Figure object itself stays valid
    # and is handed back to the caller.
    plt.close(fig)
    return fig


def get_cluster_details(cluster_number: int):
    """Build a pie chart of the class labels found in one cluster.

    Args:
        cluster_number: 1-based cluster number; it is shifted by one to
            obtain the key used in the module-level ``clusters`` mapping.

    Returns:
        A plotly pie figure with one slice per entry of ``classes``, sized by
        how many instances of the cluster carry that label in ``y_data``.
    """
    # Label of every instance assigned to this cluster, as plain ints.
    member_labels = [int(y_data[idx][0]) for idx, _ in clusters[cluster_number - 1]]

    # One tally per class id, in class-id order (zero when a class is absent).
    tallies = [member_labels.count(label) for label in range(num_classes)]

    count_df = pd.Series(dict(zip(classes, tallies))).to_frame()

    return px.pie(count_df, values=0, names=count_df.index, title='Number of class objects in cluster')


def get_cluster_info(cluster_number: int, num_images: int, image_mode: str):
    """Collect everything shown for one cluster.

    Args:
        cluster_number: 1-based cluster number.
        num_images: Number of example images to plot.
        image_mode: Image selection mode, passed through unchanged to
            ``get_images_from_cluster``.

    Returns:
        A three-item list: the cluster size sentence, the example-image
        figure, and the class-distribution pie chart, in that order.
    """
    return [
        get_cluster_size(cluster_number),
        get_images_from_cluster(cluster_number, num_images, image_mode),
        get_cluster_details(cluster_number),
    ]



# Footer text (mixed HTML and Markdown) passed to gr.Markdown after the output
# row; it credits the Space author and the keras.io example this demo follows.
# NOTE(review): the opening <center> tag is never closed inside this string.
article = """<center>

Authors: <a href='https://twitter.com/johko990' target='_blank'>Johannes Kolbe</a> after an example by [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/) on
<a href='https://keras.io/examples/vision/semantic_image_clustering/' target='_blank'>**keras.io**</a>"""



# Introductory Markdown passed to gr.Markdown as the first element of the
# page: title, a link to the SCAN paper, and short usage instructions.
# NOTE(review): the opening <center> tag is never closed inside this string.
description = """<center> 

# Semantic Image Clustering

This space is intended to give you insights to image clusters, created by a model trained with the [**Semantic Clustering by Adopting Nearest neighbors (SCAN)**](https://arxiv.org/abs/2005.12320)(Van Gansbeke et al., 2020) algorithm.

First choose one of the 20 clusters, and how many images you want to preview from it. There are two options for the images either *Random*, which as you might guess, 
gives you random images from the cluster or *High Similarity*, which gives you images that are similar according to the learned representations of the cluster.
"""

    
# Page layout: intro text, a row holding the trigger button next to the three
# input controls, a row with the three outputs, then the footer text.
with gr.Blocks() as demo:
    gr.Markdown(description)

    with gr.Row():
        btn = gr.Button("Get Cluster Info")
        with gr.Column():
            # Order matches the positional parameters of get_cluster_info:
            # cluster number, number of images, image selection mode.
            inp = [
                gr.Slider(minimum=1, maximum=20, step=1, label="Select Cluster"),
                gr.Slider(minimum=6, maximum=15, step=1, label="Number of Images to Show", value=8),
                gr.Radio(["Random Images from Cluster", "High Similarity Images"], label="Image Choice"),
            ]

    with gr.Row():
        with gr.Column():
            # Order matches the list returned by get_cluster_info.
            out1 = [
                gr.Text(label="Cluster Size"),
                gr.Plot(label="Image Examples"),
                gr.Plot(label="Class details"),
            ]

    gr.Markdown(article)

    btn.click(fn=get_cluster_info, inputs=inp, outputs=out1)

demo.launch()