Spaces:
Sleeping
Sleeping
Maximilian Noichl
commited on
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""01_clustering_methods.ipynb
|
| 3 |
+
|
| 4 |
+
Automatically generated by Colaboratory.
|
| 5 |
+
|
| 6 |
+
Original file is located at
|
| 7 |
+
https://colab.research.google.com/drive/1mqAGInsaItbKYVUlP9muYz3fpdGBWFz5
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import seaborn as sns
|
| 16 |
+
import sklearn.cluster as cluster
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import colormaps as cmaps
|
| 20 |
+
import opinionated
|
| 21 |
+
plt.style.use("opinionated_rc")
|
| 22 |
+
from opinionated.core import download_googlefont
|
| 23 |
+
download_googlefont('Quicksand', add_to_cache=True)
|
| 24 |
+
plt.rc('font', family='Quicksand')
|
| 25 |
+
|
| 26 |
+
!wget https://github.com/scikit-learn-contrib/hdbscan/raw/master/notebooks/clusterable_data.npy
|
| 27 |
+
!wget https://github.com/mwaskom/seaborn-data/raw/master/penguins.csv
|
| 28 |
+
|
| 29 |
+
hdbscan_example_data = np.load('clusterable_data.npy')
|
| 30 |
+
penguins_dataset = pd.read_csv('penguins.csv')[['bill_length_mm','bill_depth_mm','flipper_length_mm']].dropna().values
|
| 31 |
+
|
| 32 |
+
from sklearn.preprocessing import StandardScaler
|
| 33 |
+
|
| 34 |
+
scaler = StandardScaler()
|
| 35 |
+
penguins_dataset_standardized = scaler.fit_transform(penguins_dataset)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
import gradio as gr
|
| 50 |
+
import numpy as np
|
| 51 |
+
import matplotlib.pyplot as plt
|
| 52 |
+
from sklearn.datasets import make_blobs, make_moons, load_iris
|
| 53 |
+
import seaborn as sns
|
| 54 |
+
import pandas as pd
|
| 55 |
+
import matplotlib.colors as mcolors
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
from sklearn.cluster import KMeans
|
| 59 |
+
from sklearn.cluster import AgglomerativeClustering
|
| 60 |
+
from sklearn.mixture import GaussianMixture
|
| 61 |
+
import hdbscan
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
import genieclust
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Pre-defined datasets
|
| 71 |
+
blobs_X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
|
| 72 |
+
moons_X, _ = make_moons(n_samples=300, noise=0.05, random_state=0)
|
| 73 |
+
|
| 74 |
+
# Penguins dataset (3D example)
|
| 75 |
+
# For the purpose of this example, let's simulate the Penguins dataset with iris for simplicity
|
| 76 |
+
iris_X, _ = load_iris(return_X_y=True)
|
| 77 |
+
# Assuming iris_X to be a placeholder for the Penguins dataset with numerical features
|
| 78 |
+
|
| 79 |
+
datasets = {
|
| 80 |
+
"Blobs": blobs_X,
|
| 81 |
+
"Moons": moons_X,
|
| 82 |
+
"Penguins": penguins_dataset_standardized, # Placeholder for Penguins dataset
|
| 83 |
+
"hDBSCAN sample": hdbscan_example_data
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
# Function for plotting the unclustered dataset
|
| 87 |
+
def plot_unclustered(dataset_name):
|
| 88 |
+
X = datasets[dataset_name] # Fetch dataset from the dictionary
|
| 89 |
+
|
| 90 |
+
# Check if the dataset has more than 2 dimensions
|
| 91 |
+
if X.shape[1] > 2:
|
| 92 |
+
# Convert dataset to DataFrame for seaborn pairplot
|
| 93 |
+
df = pd.DataFrame(X)
|
| 94 |
+
fig = sns.pairplot(df, plot_kws={'color': 'grey','alpha':0.7}, diag_kws={'color': 'grey'}).fig
|
| 95 |
+
else:
|
| 96 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 97 |
+
ax.scatter(X[:, 0], X[:, 1], color='gray', marker='.',alpha=.7)
|
| 98 |
+
ax.set_xlabel("Feature 1")
|
| 99 |
+
ax.set_ylabel("Feature 2")
|
| 100 |
+
ax.grid(True)
|
| 101 |
+
plt.tight_layout()
|
| 102 |
+
plt.close(fig)
|
| 103 |
+
|
| 104 |
+
return fig
|
| 105 |
+
|
| 106 |
+
def plot_clustered(dataset_name, clustering_method, kmeans_n_clusters, agg_n_clusters, agg_linkage, gmm_n_clusters, covariance_type,
|
| 107 |
+
genie_n_clusters, gini_threshold, M,hdbscan_min_cluster_size, hdbscan_min_samples):
|
| 108 |
+
X = datasets[dataset_name]
|
| 109 |
+
|
| 110 |
+
# Determine the clustering method and fit the model accordingly
|
| 111 |
+
if clustering_method == "K-Means":
|
| 112 |
+
model = KMeans(n_clusters=kmeans_n_clusters)
|
| 113 |
+
model.fit(X)
|
| 114 |
+
labels = model.labels_ # For K-Means, labels are in .labels_
|
| 115 |
+
|
| 116 |
+
elif clustering_method == "Agglomerative":
|
| 117 |
+
model = AgglomerativeClustering(n_clusters=agg_n_clusters, linkage=agg_linkage)
|
| 118 |
+
model.fit(X)
|
| 119 |
+
labels = model.labels_ # For Agglomerative Clustering, labels are in .labels_
|
| 120 |
+
|
| 121 |
+
elif clustering_method == "Gaussian Mixture":
|
| 122 |
+
model = GaussianMixture(n_components=gmm_n_clusters, covariance_type=covariance_type)
|
| 123 |
+
model.fit(X)
|
| 124 |
+
labels = model.predict(X) # For Gaussian Mixture, use .predict() to get labels
|
| 125 |
+
|
| 126 |
+
elif clustering_method == "Genie":
|
| 127 |
+
model = genieclust.Genie(n_clusters=genie_n_clusters, gini_threshold=gini_threshold, M=M)
|
| 128 |
+
labels = model.fit_predict(X) # GenieClust uses fit_predict directly for both fitting and label prediction
|
| 129 |
+
|
| 130 |
+
elif clustering_method == "h-DBSCAN":
|
| 131 |
+
clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size, min_samples=hdbscan_min_samples).fit(X)
|
| 132 |
+
labels = clusterer.labels_
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
n_clusters= len(np.unique([x for x in labels if x >= 0]))
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
if n_clusters <= 10:
|
| 140 |
+
original_cmap = cmaps.greenorange_12
|
| 141 |
+
colors = original_cmap([x for x in range(n_clusters)])
|
| 142 |
+
# Create a new listed colormap with the extracted colors
|
| 143 |
+
new_cmap = mcolors.ListedColormap(colors)
|
| 144 |
+
else:
|
| 145 |
+
new_cmap = cmaps.cet_g_bw_minc
|
| 146 |
+
|
| 147 |
+
cluster_colors = [new_cmap(x) if x >= 0
|
| 148 |
+
else (0.5, 0.5, 0.5)
|
| 149 |
+
for x in labels]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Check if the dataset has more than 2 dimensions
|
| 153 |
+
if X.shape[1] > 2:
|
| 154 |
+
# Convert dataset to DataFrame for seaborn pairplot
|
| 155 |
+
df = pd.DataFrame(X)
|
| 156 |
+
# df['cluster'] = labels
|
| 157 |
+
# fig = sns.pairplot(df, color = cluster_colors, cmap=new_cmap).fig
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Create bins for each variable
|
| 161 |
+
n_bins = 10
|
| 162 |
+
bins = {column: np.linspace(df[column].min(), df[column].max(), n_bins+1) for column in df.columns}
|
| 163 |
+
|
| 164 |
+
# Create a figure and axes
|
| 165 |
+
n = len(df.columns)
|
| 166 |
+
fig, axes = plt.subplots(nrows=n, ncols=n, figsize=(n*2.3, n*2.3))
|
| 167 |
+
|
| 168 |
+
for i in range(n):
|
| 169 |
+
for j in range(n):
|
| 170 |
+
ax = axes[i, j]
|
| 171 |
+
ax.grid(True, which='both', linestyle='--', linewidth=0.5)
|
| 172 |
+
|
| 173 |
+
if i != j:
|
| 174 |
+
ax.scatter(df[df.columns[j]], df[df.columns[i]], c=cluster_colors, alpha=0.8, marker='o',s = 10)
|
| 175 |
+
else: # Diagonal - Stacked Bar Charts
|
| 176 |
+
data = df[df.columns[i]]
|
| 177 |
+
counts = np.zeros((n_bins, n_clusters))
|
| 178 |
+
for cluster in range(n_clusters):
|
| 179 |
+
cluster_data = data[labels == cluster]
|
| 180 |
+
hist, _ = np.histogram(cluster_data, bins=bins[df.columns[i]])
|
| 181 |
+
counts[:, cluster] = hist
|
| 182 |
+
for cluster in range(n_clusters):
|
| 183 |
+
ax.bar(range(n_bins), counts[:, cluster], width=1, align='center',
|
| 184 |
+
bottom=np.sum(counts[:, :cluster], axis=1), color=cluster_colors[list(labels).index(cluster)] )
|
| 185 |
+
|
| 186 |
+
# Explicit axis lines at the bottom and left
|
| 187 |
+
ax.spines['top'].set_visible(False)
|
| 188 |
+
ax.spines['right'].set_visible(False)
|
| 189 |
+
ax.spines['bottom'].set_visible(True)
|
| 190 |
+
ax.spines['left'].set_visible(True)
|
| 191 |
+
|
| 192 |
+
# Hide axis marks for inner plots and adjust label size
|
| 193 |
+
if i < n - 1:
|
| 194 |
+
ax.tick_params(labelbottom=False) # Hide x-axis labels for all but bottom row
|
| 195 |
+
else:
|
| 196 |
+
ax.tick_params(axis='x', labelsize=8) # Smaller labels for x-axis
|
| 197 |
+
if j > 0:
|
| 198 |
+
ax.tick_params(labelleft=False) # Hide y-axis labels for all but first column
|
| 199 |
+
else:
|
| 200 |
+
ax.tick_params(axis='y', labelsize=8) # Smaller labels for y-axis
|
| 201 |
+
|
| 202 |
+
# Set labels for outer plots only
|
| 203 |
+
if i == n - 1:
|
| 204 |
+
ax.set_xlabel(df.columns[j], rotation=0, fontsize=12)
|
| 205 |
+
if j == 0:
|
| 206 |
+
ax.set_ylabel(df.columns[i], fontsize=12)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
else:
|
| 212 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 213 |
+
ax.scatter(X[:, 0], X[:, 1], c=cluster_colors, marker='.')
|
| 214 |
+
ax.grid(True)
|
| 215 |
+
plt.tight_layout()
|
| 216 |
+
plt.close(fig)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
return fig
|
| 221 |
+
|
| 222 |
+
intro_md = """
|
| 223 |
+
# Cluster-algorithm-explorer
|
| 224 |
+
|
| 225 |
+
_by [Max Noichl](https://homepage.univie.ac.at/maximilian.noichl/), for the clustering & data-visualization-workshop, Bremen, 2024_
|
| 226 |
+
|
| 227 |
+
Below you can test a number of clustering-algorithms on several easier and harder datasets.
|
| 228 |
+
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# Gradio interface setup remains the same
|
| 234 |
+
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
| 235 |
+
with gr.Column():
|
| 236 |
+
gr.Markdown(intro_md)
|
| 237 |
+
with gr.Row():
|
| 238 |
+
|
| 239 |
+
with gr.Column():
|
| 240 |
+
gr.Markdown("# Choose your dataset:")
|
| 241 |
+
dataset_dropdown = gr.Dropdown(label="Select a dataset", choices=list(datasets.keys()), value="Blobs")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
gr.Markdown("# Choose your Clustering algorithm & Parameters:")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# Update the dropdown for clustering method to include "Genie"
|
| 249 |
+
clustering_method_dropdown = gr.Dropdown(label="Select a clustering method", choices=["K-Means", "Agglomerative", "Gaussian Mixture", "Genie", "h-DBSCAN"], value="K-Means")
|
| 250 |
+
|
| 251 |
+
# K-Means parameters
|
| 252 |
+
with gr.Group(visible=True) as kmeans_params_group:
|
| 253 |
+
kmeans_n_clusters_slider = gr.Slider(minimum=2, maximum=10, step=1, label="Number of Clusters (K-Means)", value=4)
|
| 254 |
+
|
| 255 |
+
# Agglomerative Clustering parameters
|
| 256 |
+
with gr.Group(visible=False) as agglomerative_params_group:
|
| 257 |
+
agg_n_clusters_slider = gr.Slider(minimum=2, maximum=10, step=1, label="Number of Clusters (Agglomerative)", value=4)
|
| 258 |
+
agg_linkage_dropdown = gr.Dropdown(label="Linkage Type", choices=["ward", "complete", "average", "single"], value="ward")
|
| 259 |
+
|
| 260 |
+
# Gaussian Mixture Model parameters
|
| 261 |
+
with gr.Group(visible=False) as gmm_params_group:
|
| 262 |
+
gmm_n_clusters_slider = gr.Slider(minimum=2, maximum=10, step=1, label="Number of Components (GMM)", value=4)
|
| 263 |
+
covariance_type_dropdown = gr.Dropdown(label="Covariance Type", choices=["full", "tied", "diag", "spherical"], value="full")
|
| 264 |
+
|
| 265 |
+
# GenieClust parameters
|
| 266 |
+
with gr.Group(visible=False) as genie_params_group:
|
| 267 |
+
genie_n_clusters_slider = gr.Slider(minimum=2, maximum=10, step=1, label="Number of Clusters (Genie)", value=4)
|
| 268 |
+
gini_threshold_slider = gr.Slider(minimum=0.0, maximum=1.05, step=0.05, label="Gini Threshold (Genie)", value=.3)
|
| 269 |
+
M_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, label="M Parameter (Genie)", value=1.0)
|
| 270 |
+
|
| 271 |
+
with gr.Group(visible=False) as hdbscan_params_group:
|
| 272 |
+
hdbscan_min_cluster_size = gr.Slider(minimum=2, maximum=200, step=1, label="Minimal Cluster Size (hDBSCAN)", value=10)
|
| 273 |
+
hdbscan_min_samples = gr.Slider(minimum=2, maximum=200, step=1, label="Min. Samples (hDBSCAN)", value=10)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# Update the function that changes visible parameter groups based on selected clustering method
|
| 278 |
+
def update_method_params(clustering_method):
|
| 279 |
+
return {
|
| 280 |
+
kmeans_params_group: gr.Group(visible=clustering_method == "K-Means"),
|
| 281 |
+
agglomerative_params_group: gr.Group(visible=clustering_method == "Agglomerative"),
|
| 282 |
+
gmm_params_group: gr.Group(visible=clustering_method == "Gaussian Mixture"),
|
| 283 |
+
genie_params_group: gr.Group(visible=clustering_method == "Genie"),
|
| 284 |
+
hdbscan_params_group: gr.Group(visible=clustering_method == "h-DBSCAN"),
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
clustering_method_dropdown.change(update_method_params, inputs=[clustering_method_dropdown], outputs=[kmeans_params_group, agglomerative_params_group,
|
| 291 |
+
gmm_params_group, genie_params_group,hdbscan_params_group])
|
| 292 |
+
|
| 293 |
+
button = gr.Button("Run Clustering!")
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
with gr.Column():
|
| 297 |
+
unclustered_plot_output = gr.Plot(label=None)
|
| 298 |
+
clustered_plot_output = gr.Plot(label=None)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
dataset_dropdown.change(plot_unclustered, inputs=[dataset_dropdown], outputs=[unclustered_plot_output])
|
| 302 |
+
demo.load(plot_unclustered, inputs=[dataset_dropdown], outputs=[unclustered_plot_output])
|
| 303 |
+
# Update the button click event to include new parameters for GenieClust
|
| 304 |
+
button.click(
|
| 305 |
+
plot_clustered,
|
| 306 |
+
inputs=[
|
| 307 |
+
dataset_dropdown,
|
| 308 |
+
clustering_method_dropdown,
|
| 309 |
+
kmeans_n_clusters_slider,
|
| 310 |
+
agg_n_clusters_slider,
|
| 311 |
+
agg_linkage_dropdown,
|
| 312 |
+
gmm_n_clusters_slider,
|
| 313 |
+
covariance_type_dropdown,
|
| 314 |
+
genie_n_clusters_slider, # Add Genie parameters
|
| 315 |
+
gini_threshold_slider,
|
| 316 |
+
M_slider,
|
| 317 |
+
hdbscan_min_cluster_size,
|
| 318 |
+
hdbscan_min_samples
|
| 319 |
+
],
|
| 320 |
+
outputs=[clustered_plot_output]
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
if __name__ == "__main__":
|
| 324 |
+
demo.launch(debug=True)
|
| 325 |
+
|