Johannes Kolbe
footer change
64dd9de
raw
history blame contribute delete
No virus
5.22 kB
import matplotlib
matplotlib.use('Agg')
import gradio as gr
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import plotly.express as px
from plotly import subplots
import pandas as pd
import random
# Load CIFAR-10 and merge its train and test splits into a single pool:
# x_data holds the images and y_data the matching labels. The whole pool is
# what the clustering model is run on below. Labels are read elsewhere as
# y_data[i][0], i.e. each label is stored as a one-element row.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_data = np.concatenate([x_train, x_test])
y_data = np.concatenate([y_train, y_test])
# Human-readable label names, indexed by the integer label stored in y_data
# (e.g. classes[y_data[i][0]] is the name of image i's label).
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
# Derived from the list instead of hard-coded, so the class count can never
# disagree with the names it is used alongside.
num_classes = len(classes)
# Pretrained clustering model (trained with SCAN, per the description below)
# downloaded from the Hugging Face Hub.
clustering_model = from_pretrained_keras("keras-io/semantic-image-clustering")
# Get the cluster probability distribution of the input images.
# This runs once, at startup, over every image in x_data.
clustering_probs = clustering_model.predict(x_data, batch_size=500, verbose=1)
# Get the cluster of the highest probability: a 0-based cluster index per image.
cluster_assignments = tf.math.argmax(clustering_probs, axis=-1).numpy()
# Store the clustering confidence (the probability of the chosen cluster).
# Images with the highest clustering confidence are considered the 'prototypes'
# of the clusters.
cluster_confidence = tf.math.reduce_max(clustering_probs, axis=-1).numpy()
# Group images by cluster: 0-based cluster index -> list of
# (index into x_data/y_data, confidence) tuples. The UI works with 1-based
# cluster numbers, which the handlers shift down by one before looking up here.
clusters = defaultdict(list)
for idx, c in enumerate(cluster_assignments):
    clusters[c].append((idx, cluster_confidence[idx]))
def get_cluster_size(cluster_number: int):
    """Return a short message stating how many images fall in a cluster.

    Args:
        cluster_number: 1-based cluster number as chosen in the UI; it is
            shifted to the 0-based key used by ``clusters``.

    Returns:
        A string of the form "Cluster #<n> consists of <count> objects".
    """
    # The UI slider may hand over a float such as 3.0; normalise it so the
    # message reads "Cluster #3" and the lookup key is a plain int.
    cluster_number = int(cluster_number)
    # .get() rather than [] so looking up an empty cluster does not insert a
    # new key into the shared ``clusters`` defaultdict as a side effect.
    cluster_size = len(clusters.get(cluster_number - 1, ()))
    return f"Cluster #{cluster_number} consists of {cluster_size} objects"
def get_images_from_cluster(cluster_number: int, num_images: int, image_mode: str):
    """Plot a single row of example images taken from one cluster.

    Args:
        cluster_number: 1-based cluster number from the UI; it maps to key
            ``cluster_number - 1`` in ``clusters``.
        num_images: how many images to draw. The value is clamped to the
            cluster's size, so a small cluster cannot cause an IndexError.
        image_mode: "Random Images from Cluster" draws a random sample. Any
            other value shows the images with the highest cluster confidence
            first.

    Returns:
        A matplotlib Figure with one subplot per image. Each subplot is titled
        with that image's class name from ``classes``.
    """
    instances = clusters.get(int(cluster_number) - 1, [])
    # The slider may deliver a float; also never ask for more images than exist.
    count = max(0, min(int(num_images), len(instances)))
    if image_mode == "Random Images from Cluster":
        # random.sample returns a new list, leaving the shared cluster list
        # untouched (random.shuffle previously reordered it in place).
        selected = random.sample(instances, count)
    else:
        # Most confident assignments first: the cluster's 'prototypes'.
        selected = sorted(instances, key=lambda kv: kv[1], reverse=True)[:count]

    fig = plt.figure()
    for position, (image_idx, _) in enumerate(selected, start=1):
        # Draw on this figure explicitly rather than on pyplot's implicit
        # "current figure", which is global state shared between requests.
        ax = fig.add_subplot(1, count, position)
        ax.imshow(x_data[image_idx].astype("uint8"))
        ax.set_title(classes[y_data[image_idx][0]])
        ax.axis("off")
    fig.tight_layout()
    # Unregister the figure from pyplot so figures do not pile up across calls.
    # The Figure object itself stays usable for rendering by the caller.
    plt.close(fig)
    return fig
# labels = []
# images = []
# for j in range(num_images):
# image_idx = cluster_instances[j][0]
# images.append(x_data[image_idx].astype("uint8"))
# labels.append(classes[y_data[image_idx][0]])
# fig = subplots.make_subplots(rows=int(num_images/4)+1, cols=4, subplot_titles=labels)
# for j in range(num_images):
# fig.add_trace(px.imshow(images[j]).data[0], row=int(j/4)+1, col=j%4+1)
# fig.update_xaxes(visible=False)
# fig.update_yaxes(visible=False)
# return fig
def get_cluster_details(cluster_number: int):
    """Build a pie chart of the ground-truth label mix inside one cluster.

    Args:
        cluster_number: 1-based cluster number from the UI; it maps to key
            ``cluster_number - 1`` in ``clusters``.

    Returns:
        A plotly pie figure with one slice per name in ``classes``. Each slice
        is sized by how many of the cluster's images carry that label in
        ``y_data``.
    """
    # .get() so an empty cluster does not insert a key into the shared dict.
    instances = clusters.get(int(cluster_number) - 1, [])
    # One counter per class name, so counts and names always line up.
    label_counts = [0] * len(classes)
    for image_idx, _ in instances:
        label_counts[y_data[image_idx][0]] += 1
    # Named columns give readable hover labels ("class", "count") in the chart.
    count_df = pd.DataFrame({"class": classes, "count": label_counts})
    return px.pie(
        count_df,
        values="count",
        names="class",
        title='Number of class objects in cluster',
    )
def get_cluster_info(cluster_number: int, num_images: int, image_mode: str):
    """Gather every output the UI displays for one cluster.

    Returns a three-item list, in the same order as the button's outputs:
    the cluster-size message, the figure of example images, and the pie chart
    of class counts.
    """
    return [
        get_cluster_size(cluster_number),
        get_images_from_cluster(cluster_number, num_images, image_mode),
        get_cluster_details(cluster_number),
    ]
# Footer shown under the demo. It is written as plain HTML because Markdown
# syntax placed inside a raw HTML block such as <center> is not converted by
# CommonMark-style renderers, and would otherwise show up as literal text.
article = """<center>
Authors: <a href='https://twitter.com/johko990' target='_blank'>Johannes Kolbe</a> after an example by <a href='https://www.linkedin.com/in/khalid-salama-24403144/' target='_blank'>Khalid Salama</a> on
<a href='https://keras.io/examples/vision/semantic_image_clustering/' target='_blank'><b>keras.io</b></a>
</center>"""
# Header shown above the demo. It is written as plain HTML because Markdown
# syntax placed inside a raw HTML block such as <center> is not converted by
# CommonMark-style renderers, so headings, links and emphasis would appear as
# literal text.
description = """<center>
<h1>Semantic Image Clustering</h1>
<p>This space is intended to give you insights into image clusters created by a model trained with the <a href='https://arxiv.org/abs/2005.12320' target='_blank'><b>Semantic Clustering by Adopting Nearest neighbors (SCAN)</b></a> algorithm (Van Gansbeke et al., 2020).</p>
<p>First choose one of the 20 clusters and how many images you want to preview from it. There are two options for the images: either <i>Random</i>, which, as you might guess, gives you random images from the cluster, or <i>High Similarity</i>, which gives you the images the model assigns to the cluster with the highest confidence, i.e. its most representative examples.</p>
</center>"""
# Gradio UI: three inputs (cluster number, image count, image mode), a button
# that calls get_cluster_info, and three outputs for its three return values.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        btn = gr.Button("Get Cluster Info")
        with gr.Column():
            inp = [
                # One slider position per cluster the model outputs: the width
                # of the last axis of clustering_probs.
                gr.Slider(
                    minimum=1,
                    maximum=clustering_probs.shape[-1],
                    step=1,
                    label="Select Cluster",
                ),
                gr.Slider(
                    minimum=6,
                    maximum=15,
                    step=1,
                    label="Number of Images to Show",
                    value=8,
                ),
                # With no selection the handler already fell through to the
                # high-similarity branch; showing it as the default makes that
                # visible to the user.
                gr.Radio(
                    ["Random Images from Cluster", "High Similarity Images"],
                    label="Image Choice",
                    value="High Similarity Images",
                ),
            ]
    with gr.Row():
        with gr.Column():
            out1 = [
                gr.Text(label="Cluster Size"),
                gr.Plot(label="Image Examples"),
                gr.Plot(label="Class details"),
            ]
    gr.Markdown(article)
    btn.click(fn=get_cluster_info, inputs=inp, outputs=out1)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()