|
|
|
|
|
import gradio as gr |
|
|
|
from sklearn.cluster import OPTICS, cluster_optics_dbscan |
|
import matplotlib.gridspec as gridspec |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
# Select matplotlib's "agg" backend before any figure is created; figures are
# only ever handed to Gradio for rendering, never shown in a GUI window.
plt.switch_backend("agg")

# Font fallback chain for the UI: Open Sans (fetched as a Google font) first,
# then progressively more generic families.
_FONT_STACK = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]

# Monochrome base theme with custom hues, small corner radius and the font
# stack above; passed to ``gr.Blocks`` when the demo is built.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_FONT_STACK,
)
|
|
|
|
|
def _draw_clusters(ax, points, labels, styles, title):
    """Scatter ``points`` on ``ax``, one marker style per cluster label.

    Every non-negative label up to ``labels.max()`` is drawn; ``styles`` is
    reused cyclically when there are more clusters than styles, so no cluster
    is silently left off the plot. Points labelled ``-1`` are drawn as faint
    black crosses.
    """
    for klass in range(int(labels.max()) + 1):
        members = points[labels == klass]
        ax.plot(
            members[:, 0],
            members[:, 1],
            styles[klass % len(styles)],
            alpha=0.3,
        )
    noise = points[labels == -1]
    ax.plot(noise[:, 0], noise[:, 1], "k+", alpha=0.1)
    ax.set_title(title)


def do_submit(n_points_per_cluster, min_samples, xi, min_cluster_size):
    """Run OPTICS on six synthetic blobs and plot the results.

    Parameters
    ----------
    n_points_per_cluster : number
        Points generated for each of the six Gaussian blobs (cast to int).
    min_samples : number
        Forwarded to ``OPTICS(min_samples=...)`` as an int.
    xi : number
        Forwarded to ``OPTICS(xi=...)`` as a float.
    min_cluster_size : number
        Forwarded to ``OPTICS(min_cluster_size=...)`` as a float.

    Returns
    -------
    matplotlib.figure.Figure
        A 2x3 grid: the reachability plot on the top row, and below it the
        OPTICS clustering and the DBSCAN-style cuts at eps=0.5 and eps=2.0.
        The figure is already deregistered from pyplot, so repeated calls do
        not accumulate open figures.
    """
    n_points_per_cluster = int(n_points_per_cluster)

    # A private generator seeded with 0 yields the same sample as
    # ``np.random.seed(0)`` followed by ``np.random.randn`` calls, but does
    # not reseed the process-wide RNG on every request.
    rng = np.random.RandomState(0)
    blobs = [
        ([-5, -2], 0.8),
        ([4, -1], 0.1),
        ([1, -2], 0.2),
        ([-2, 3], 0.3),
        ([3, -2], 1.6),
        ([5, 6], 2),
    ]
    X = np.vstack(
        [
            center + spread * rng.randn(n_points_per_cluster, 2)
            for center, spread in blobs
        ]
    )

    clust = OPTICS(
        min_samples=int(min_samples),
        xi=float(xi),
        min_cluster_size=float(min_cluster_size),
    )
    clust.fit(X)

    # Re-use the fitted reachability/ordering to extract DBSCAN-like
    # clusterings at two fixed epsilon thresholds.
    labels_050 = cluster_optics_dbscan(
        reachability=clust.reachability_,
        core_distances=clust.core_distances_,
        ordering=clust.ordering_,
        eps=0.5,
    )
    labels_200 = cluster_optics_dbscan(
        reachability=clust.reachability_,
        core_distances=clust.core_distances_,
        ordering=clust.ordering_,
        eps=2,
    )

    # Reachability values and labels in cluster-ordering for the top panel.
    space = np.arange(len(X))
    reachability = clust.reachability_[clust.ordering_]
    labels = clust.labels_[clust.ordering_]

    fig = plt.figure(figsize=(10, 6))
    try:
        # Build axes on this figure explicitly instead of through pyplot's
        # global "current figure".
        grid = gridspec.GridSpec(2, 3, figure=fig)
        ax1 = fig.add_subplot(grid[0, :])
        ax2 = fig.add_subplot(grid[1, 0])
        ax3 = fig.add_subplot(grid[1, 1])
        ax4 = fig.add_subplot(grid[1, 2])

        # Reachability plot, coloured by OPTICS label.
        styles = ["g.", "r.", "b.", "y.", "c."]
        for klass in range(int(labels.max()) + 1):
            mask = labels == klass
            ax1.plot(
                space[mask],
                reachability[mask],
                styles[klass % len(styles)],
                alpha=0.3,
            )
        ax1.plot(space[labels == -1], reachability[labels == -1], "k.", alpha=0.3)
        # Horizontal guides at the two epsilon cuts shown in the lower panels.
        ax1.plot(space, np.full_like(space, 2.0, dtype=float), "k-", alpha=0.5)
        ax1.plot(space, np.full_like(space, 0.5, dtype=float), "k-.", alpha=0.5)
        ax1.set_ylabel("Reachability (epsilon distance)")
        ax1.set_title("Reachability Plot")

        _draw_clusters(
            ax2,
            X,
            clust.labels_,
            ["g.", "r.", "b.", "y.", "c."],
            "Automatic Clustering\nOPTICS",
        )
        _draw_clusters(
            ax3,
            X,
            labels_050,
            ["g.", "r.", "b.", "c."],
            "Clustering at 0.5 epsilon cut\nDBSCAN",
        )
        _draw_clusters(
            ax4,
            X,
            labels_200,
            ["g.", "m.", "y.", "c."],
            "Clustering at 2.0 epsilon cut\nDBSCAN",
        )

        fig.tight_layout()
    finally:
        # Drop pyplot's reference so figures do not pile up across requests;
        # the Figure object itself remains usable by the caller.
        plt.close(fig)

    return fig
|
|
|
|
|
title = "Demo of OPTICS clustering algorithm"

# Build the UI: explanatory text, a column of parameter sliders, and the plot.
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_optics.html)"
    )

    gr.Markdown(
        "Finds core samples of high density and expands clusters from them. This example uses data that is "
        "generated so that the clusters have different densities. The [OPTICS](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html#sklearn.cluster.OPTICS) is first used with its Xi cluster detection "
        "method, and then setting specific thresholds on the reachability, which corresponds to [DBSCAN](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html#sklearn.cluster.DBSCAN). We can see that "
        "the different clusters of OPTICS’s Xi method can be recovered with different choices of thresholds in DBSCAN."
    )

    # NOTE(review): ``.style(...)`` and a fractional ``scale`` are kept as in
    # the original; newer Gradio releases may expect ``gr.Row(equal_height=True)``
    # and an integer scale -- confirm against the pinned Gradio version.
    # NOTE(review): original indentation was lost; the plot is assumed to sit
    # inside the Row, beside the slider column -- confirm.
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=0.75):
            n_points_per_cluster = gr.Slider(
                minimum=200,
                maximum=500,
                label="Number of points per cluster",
                step=50,
                value=250,
            )
            with gr.Row(visible=False):
                gr.Markdown("##")

            min_samples = gr.Slider(
                minimum=10,
                maximum=100,
                label="OPTICS - Minimum number of samples",
                step=5,
                value=50,
                info="The number of samples in a neighborhood for a point to be considered as a core point.",
            )
            with gr.Row(visible=False):
                gr.Markdown("##")

            xi = gr.Slider(
                minimum=0,
                maximum=0.2,
                label="OPTICS - Xi",
                step=0.01,
                value=0.05,
                info="Determines the minimum steepness on the reachability plot that constitutes a cluster boundary. ",
            )
            with gr.Row(visible=False):
                gr.Markdown("##")
            min_cluster_size = gr.Slider(
                minimum=0.01,
                maximum=0.1,
                label="OPTICS - Minimum cluster size",
                step=0.01,
                value=0.05,
                info="Minimum number of samples in an OPTICS cluster, expressed as an absolute number or a fraction of the number of samples (rounded to be at least 2).",
            )

        plt_out = gr.Plot()

    # Every slider feeds the same callback with the same full set of inputs,
    # so register the handler once per control instead of four copies.
    controls = [n_points_per_cluster, min_samples, xi, min_cluster_size]
    for control in controls:
        control.change(do_submit, inputs=controls, outputs=plt_out)

    # Render the default configuration on page load so the plot is not empty
    # until the first slider move.
    demo.load(do_submit, inputs=controls, outputs=plt_out)


if __name__ == "__main__":
    demo.launch()
|
|
|
|