Johannes Kolbe committed on
Commit
b4b75f2
1 Parent(s): 2ea85da

should work

Browse files
Files changed (2) hide show
  1. app.py +146 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ matplotlib.use('Agg')
3
+ import gradio as gr
4
+ import tensorflow as tf
5
+ from huggingface_hub import from_pretrained_keras
6
+ import numpy as np
7
+ from collections import defaultdict
8
+ import matplotlib.pyplot as plt
9
+ import plotly.express as px
10
+ from plotly import subplots
11
+ import pandas as pd
12
+ import random
13
+
14
# Load both CIFAR-10 splits and merge them into one image array and one
# label array, so every image can be clustered and browsed.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_data = np.concatenate((x_train, x_test))
y_data = np.concatenate((y_train, y_test))

# Human-readable class names; the functions below index this list with the
# integer labels stored in y_data.
classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
num_classes = 10
30
+
31
# Pretrained clustering model pulled from the Hugging Face Hub.
clustering_model = from_pretrained_keras("johko/semantic-image-clustering")

# Per-image probability distribution over the clusters.
clustering_probs = clustering_model.predict(x_data, batch_size=500, verbose=1)
# The most probable cluster for every image ...
cluster_assignments = tf.math.argmax(clustering_probs, axis=-1).numpy()
# ... and the probability of that cluster. Images with the highest
# clustering confidence are considered the 'prototypes' of their cluster.
cluster_confidence = tf.math.reduce_max(clustering_probs, axis=-1).numpy()

# Map each 0-based cluster id to its (image index, confidence) pairs.
clusters = defaultdict(list)
for idx, (c, confidence) in enumerate(zip(cluster_assignments, cluster_confidence)):
    clusters[c].append((idx, confidence))
45
+
46
def get_cluster_size(cluster_number: int):
    """Describe how many objects were assigned to a cluster.

    Args:
        cluster_number: 1-based cluster number as selected in the UI. It is
            shifted by one to index the 0-based ``clusters`` mapping.

    Returns:
        A sentence of the form ``"Cluster #N consists of M objects"``.
    """
    # Sliders may hand over a float (e.g. 3.0); normalise so the message
    # reads "#3" and the lookup key is a plain int.
    cluster_number = int(cluster_number)
    # ``clusters`` is a defaultdict: subscripting a missing key would insert
    # an empty list into the shared mapping, so read it with ``get``.
    cluster_size = len(clusters.get(cluster_number - 1, ()))
    return f"Cluster #{cluster_number} consists of {cluster_size} objects"
49
+
50
def get_images_from_cluster(cluster_number: int, num_images: int, image_mode: str):
    """Plot a single row of example images taken from one cluster.

    Args:
        cluster_number: 1-based cluster number; shifted by one to index the
            0-based ``clusters`` mapping.
        num_images: Maximum number of images to draw. Fewer are drawn when
            the cluster holds fewer instances.
        image_mode: ``"Random Images from Cluster"`` draws a random sample;
            any other value shows the instances with the highest clustering
            confidence first.

    Returns:
        A matplotlib figure with one subplot per selected image, each titled
        with the class name looked up from the image's label in ``y_data``.
    """
    # ``get`` avoids inserting an empty entry into the shared defaultdict.
    instances = clusters.get(int(cluster_number) - 1, [])
    # Never ask for more images than the cluster contains (the original
    # indexed past the end and raised IndexError for small clusters).
    num_shown = max(0, min(int(num_images), len(instances)))

    if image_mode == "Random Images from Cluster":
        # ``random.sample`` leaves the shared cluster list untouched, whereas
        # ``random.shuffle`` reordered it in place for every later request.
        selected = random.sample(instances, num_shown)
    else:
        ranked = sorted(instances, key=lambda kv: kv[1], reverse=True)
        selected = ranked[:num_shown]

    fig = plt.figure()
    # Draw through the figure's own axes instead of pyplot's global
    # "current figure" state.
    for position, (image_idx, _) in enumerate(selected, start=1):
        axis = fig.add_subplot(1, num_shown, position)
        axis.imshow(x_data[image_idx].astype("uint8"))
        axis.set_title(classes[y_data[image_idx][0]])
        axis.axis("off")
    fig.tight_layout()
    # Deregister the figure from pyplot so figures do not accumulate across
    # requests; the returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig
68
+
69
+ # labels = []
70
+ # images = []
71
+ # for j in range(num_images):
72
+ # image_idx = cluster_instances[j][0]
73
+ # images.append(x_data[image_idx].astype("uint8"))
74
+ # labels.append(classes[y_data[image_idx][0]])
75
+
76
+ # fig = subplots.make_subplots(rows=int(num_images/4)+1, cols=4, subplot_titles=labels)
77
+ # for j in range(num_images):
78
+ # fig.add_trace(px.imshow(images[j]).data[0], row=int(j/4)+1, col=j%4+1)
79
+
80
+ # fig.update_xaxes(visible=False)
81
+ # fig.update_yaxes(visible=False)
82
+
83
+ # return fig
84
+
85
+
86
def get_cluster_details(cluster_number: int):
    """Build a pie chart of the class labels found in one cluster.

    Args:
        cluster_number: 1-based cluster number; shifted by one to index the
            0-based ``clusters`` mapping.

    Returns:
        A plotly pie-chart figure with one slice per entry of ``classes``,
        sized by how many images of that class the cluster contains.
    """
    cluster_label_counts = [0] * num_classes
    # ``get`` avoids inserting an empty entry into the shared defaultdict
    # when an unassigned cluster is requested.
    for i, _ in clusters.get(int(cluster_number) - 1, []):
        cluster_label_counts[y_data[i][0]] += 1

    # Single-column frame (column name 0) indexed by class name.
    class_count_dict = dict(zip(classes, cluster_label_counts))
    count_df = pd.Series(class_count_dict).to_frame()

    fig_pie = px.pie(count_df, values=0, names=count_df.index, title='Number of class objects in cluster')
    return fig_pie
101
+
102
+
103
def get_cluster_info(cluster_number: int, num_images: int, image_mode: str):
    """Collect everything the UI shows for one cluster.

    Args:
        cluster_number: Cluster number forwarded to all three helpers.
        num_images: Number of example images to plot.
        image_mode: Image selection mode forwarded to the image plot.

    Returns:
        A three-element list: the size sentence, the example-image figure
        and the class-distribution figure, in that order.
    """
    return [
        get_cluster_size(cluster_number),
        get_images_from_cluster(cluster_number, num_images, image_mode),
        get_cluster_details(cluster_number),
    ]
109
+
110
+
111
+
112
# Footer shown below the demo. The text directly follows the opening
# <center> tag with no blank line, so a CommonMark renderer treats it all as
# one raw HTML block and does not process markdown inside it; links and
# emphasis are therefore written as HTML.
article = """<center>
Authors: <a href='https://twitter.com/johko990' target='_blank'>Johannes Kolbe</a> after an example by <a href='https://www.linkedin.com/in/khalid-salama-24403144/' target='_blank'>Khalid Salama</a> on
<a href='https://keras.io/examples/vision/semantic_image_clustering/' target='_blank'><b>keras.io</b></a>
</center>"""


# Header text shown above the controls. The blank line after <center> ends
# the HTML block, so the markdown that follows is rendered normally.
description = """<center>

# Semantic Image Clustering

This space is intended to give you insights to image clusters, created by a model trained with the [**Semantic Clustering by Adopting Nearest neighbors (SCAN)**](https://arxiv.org/abs/2005.12320)(Van Gansbeke et al., 2020) algorithm.

First choose one of the 20 clusters, and how many images you want to preview from it. There are two options for the images either *Random*, which as you might guess,
gives you random images from the cluster or *High Similarity*, which gives you images that are similar according to the learned representations of the cluster.
"""
127
+
128
+
129
# NOTE(review): the scraped diff lost all indentation; the nesting below is
# reconstructed from the statement order - confirm against the running app.
with gr.Blocks() as demo:
    gr.Markdown(description)

    # Controls: the trigger button next to a column holding the three inputs.
    with gr.Row():
        btn = gr.Button("Get Cluster Info")
        with gr.Column():
            cluster_slider = gr.Slider(minimum=1, maximum=20, step=1, label="Select Cluster")
            count_slider = gr.Slider(minimum=6, maximum=15, step=1, label="Number of Images to Show", value=8)
            mode_radio = gr.Radio(["Random Images from Cluster", "High Similarity Images"], label="Image Choice")
            inp = [cluster_slider, count_slider, mode_radio]

    # Results: size sentence, example images and class distribution.
    with gr.Row():
        with gr.Column():
            size_text = gr.Text(label="Cluster Size")
            image_plot = gr.Plot(label="Image Examples")
            detail_plot = gr.Plot(label="Class details")
            out1 = [size_text, image_plot, detail_plot]

    gr.Markdown(article)

    # The input/output lists line up with get_cluster_info's parameters and
    # with the three-element list it returns.
    btn.click(fn=get_cluster_info, inputs=inp, outputs=out1)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorflow >=2.6.0
2
+ gradio == 3.0.12
3
+ huggingface_hub
4
+ jinja2
5
+ matplotlib
6
+ plotly
7
+ pandas
8
+ numpy