"""Gradio demo for browsing SCAN semantic image clusters of CIFAR-10 images."""
import random
from collections import defaultdict

import matplotlib

# Pick the non-interactive Agg backend before pyplot or any other library
# that might pull in pyplot is imported.
matplotlib.use('Agg')

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from plotly import subplots

# Pool the CIFAR-10 train and test splits: images go into x_data and labels
# into y_data, train rows first, then test rows.
train_split, test_split = tf.keras.datasets.cifar10.load_data()
x_data, y_data = (np.concatenate(parts) for parts in zip(train_split, test_split))

# Human-readable CIFAR-10 class names, indexed by label id.
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
num_classes = len(classes)

clustering_model = from_pretrained_keras("keras-io/semantic-image-clustering")

# Per-image probability distribution over the clusters.
clustering_probs = clustering_model.predict(x_data, verbose=1, batch_size=500)

# Hard assignment: each image goes to its most probable cluster.
cluster_assignments = tf.math.argmax(clustering_probs, axis=-1).numpy()

# Store the clustering confidence.
# Images with the highest clustering confidence are considered the 'prototypes'
# of the clusters.
# Highest cluster probability per image; used to rank members within a cluster.
cluster_confidence = tf.math.reduce_max(clustering_probs, axis=-1).numpy()

# 0-based cluster id -> list of (image index, confidence) pairs.
clusters = defaultdict(list)
for idx, c in enumerate(cluster_assignments):
    # Normalise numpy integer ids to plain Python ints for the dict keys.
    clusters[int(c)].append((idx, cluster_confidence[idx]))


def _cluster_members(cluster_number):
    """Return the (image index, confidence) pairs of a 1-based cluster number.

    The UI numbers clusters starting at 1, while the model's cluster ids start
    at 0. An unknown cluster yields an empty list; ``dict.get`` is used so no
    key is inserted into the ``clusters`` defaultdict. The returned list is the
    shared per-cluster list, so callers must not mutate it.
    """
    return clusters.get(int(cluster_number) - 1, [])


def get_cluster_size(cluster_number: int):
    """Return a sentence stating how many images belong to the given cluster.

    Args:
        cluster_number: 1-based cluster number as selected in the UI.
    """
    cluster_size = len(_cluster_members(cluster_number))
    return f"Cluster #{int(cluster_number)} consists of {cluster_size} objects"


def get_images_from_cluster(cluster_number: int, num_images: int, image_mode: str):
    """Plot up to ``num_images`` images of a cluster in a single row.

    Args:
        cluster_number: 1-based cluster number as selected in the UI.
        num_images: Maximum number of images to show. If the cluster is
            smaller, all of its images are shown.
        image_mode: ``"Random Images from Cluster"`` draws a random sample.
            Any other value shows the images with the highest cluster
            confidence first.

    Returns:
        A matplotlib Figure titled per image with its CIFAR-10 class name.
    """
    members = _cluster_members(cluster_number)
    count = min(int(num_images), len(members))

    if image_mode == "Random Images from Cluster":
        # random.sample returns a new list, so the shared per-cluster list is
        # not reordered (random.shuffle would mutate it in place).
        selected = random.sample(members, count)
    else:
        # Highest-confidence members first: the cluster "prototypes".
        selected = sorted(members, key=lambda member: member[1], reverse=True)[:count]

    fig = plt.figure()
    for position, (image_idx, _) in enumerate(selected, start=1):
        # Draw on this figure's own axes rather than pyplot's "current" figure.
        ax = fig.add_subplot(1, count, position)
        ax.imshow(x_data[image_idx].astype("uint8"))
        ax.set_title(classes[y_data[image_idx][0]])
        ax.axis("off")
    fig.tight_layout()
    # Remove the figure from pyplot's registry so figures don't pile up across
    # requests. The Figure object itself can still be rendered by the caller.
    plt.close(fig)
    return fig


def get_cluster_details(cluster_number: int):
    """Return a plotly pie chart of CIFAR-10 class counts within a cluster.

    Args:
        cluster_number: 1-based cluster number as selected in the UI.
    """
    cluster_label_counts = [0] * num_classes
    for image_idx, _ in _cluster_members(cluster_number):
        cluster_label_counts[y_data[image_idx][0]] += 1
    count_df = pd.Series(dict(zip(classes, cluster_label_counts))).to_frame()
    fig_pie = px.pie(count_df, values=0, names=count_df.index, title='Number of class objects in cluster')
    return fig_pie


def get_cluster_info(cluster_number: int, num_images: int, image_mode: str):
    """Gradio callback returning [size text, image figure, class pie chart]."""
    cluster_size = get_cluster_size(cluster_number)
    img_fig = get_images_from_cluster(cluster_number, num_images, image_mode)
    detail_fig = get_cluster_details(cluster_number)
    return [cluster_size, img_fig, detail_fig]


article = """
Authors: Johannes Kolbe after an example by [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/) on **keras.io**"""

# The heading needs its own line; otherwise Markdown renders the whole
# paragraph as a heading.
description = """
# Semantic Image Clustering
This space is intended to give you insights into image clusters, created by a model trained with the [**Semantic Clustering by Adopting Nearest neighbors (SCAN)**](https://arxiv.org/abs/2005.12320)(Van Gansbeke et al., 2020) algorithm.

First choose one of the 20 clusters, and how many images you want to preview from it. There are two options for the images either *Random*, which as you might guess, gives you random images from the cluster or *High Similarity*, which gives you images that are similar according to the learned representations of the cluster.
"""

demo = gr.Blocks()

with demo:
    gr.Markdown(description)
    with gr.Row():
        btn = gr.Button("Get Cluster Info")
        with gr.Column():
            inp = [
                gr.Slider(minimum=1, maximum=20, step=1, value=1, label="Select Cluster"),
                gr.Slider(minimum=6, maximum=15, step=1, label="Number of Images to Show", value=8),
                # An explicit default replaces the implicit None, which already
                # fell through to the high-similarity branch.
                gr.Radio(
                    ["Random Images from Cluster", "High Similarity Images"],
                    value="High Similarity Images",
                    label="Image Choice",
                ),
            ]
    with gr.Row():
        with gr.Column():
            out1 = [
                gr.Text(label="Cluster Size"),
                gr.Plot(label="Image Examples"),
                gr.Plot(label="Class details"),
            ]
    gr.Markdown(article)
    btn.click(fn=get_cluster_info, inputs=inp, outputs=out1)

demo.launch()