# rashmi's picture
# Update app.py
# 7e7a15b
# Scikit learn example https://scikit-learn.org/stable/auto_examples/cluster/plot_optics.html
import gradio as gr
from sklearn.cluster import OPTICS, cluster_optics_dbscan
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
# Switch pyplot to the non-interactive Agg backend so figures can be drawn
# without a display attached.
plt.switch_backend("agg")

# Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
# Font stack: the "Open Sans" Google font first, then generic fallbacks.
_font_stack = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_font_stack,
)
def _make_blobs(n_points_per_cluster):
    """Generate the six 2-D Gaussian blobs of the scikit-learn OPTICS example.

    Each blob has ``n_points_per_cluster`` points around its own centre with
    its own spread, so the clusters have different densities. A private
    ``np.random.RandomState(0)`` draws exactly the samples that
    ``np.random.seed(0)`` followed by ``np.random.randn`` would, without
    reseeding NumPy's global random generator.

    Returns an array of shape ``(6 * n_points_per_cluster, 2)``.
    """
    rng = np.random.RandomState(0)
    # (centre, scale) of each blob, in the order the samples are drawn.
    blobs = [
        ([-5, -2], 0.8),
        ([4, -1], 0.1),
        ([1, -2], 0.2),
        ([-2, 3], 0.3),
        ([3, -2], 1.6),
        ([5, 6], 2),
    ]
    return np.vstack(
        [center + scale * rng.randn(n_points_per_cluster, 2) for center, scale in blobs]
    )


def _plot_cluster_points(ax, X, labels, colors, title):
    """Scatter the 2-D points ``X`` on ``ax``, coloured by cluster label.

    Every cluster label ``0..labels.max()`` is drawn. Colours cycle through
    ``colors`` when there are more clusters than colours, so no cluster is
    left off the plot. Noise points (label ``-1``) are drawn as faint black
    crosses.
    """
    for klass in range(labels.max() + 1):
        members = X[labels == klass]
        ax.plot(members[:, 0], members[:, 1], colors[klass % len(colors)], alpha=0.3)
    noise = X[labels == -1]
    ax.plot(noise[:, 0], noise[:, 1], "k+", alpha=0.1)
    ax.set_title(title)


def do_submit(n_points_per_cluster, min_samples, xi, min_cluster_size):
    """Fit OPTICS on synthetic blobs and plot the resulting clusterings.

    Parameters
    ----------
    n_points_per_cluster : number
        Points drawn per blob (converted with ``int``); six blobs are generated.
    min_samples : number
        Passed to ``OPTICS(min_samples=...)`` after ``int`` conversion.
    xi : number
        Passed to ``OPTICS(xi=...)`` after ``float`` conversion.
    min_cluster_size : number
        Passed to ``OPTICS(min_cluster_size=...)`` after ``float`` conversion.

    Returns
    -------
    matplotlib.figure.Figure
        The reachability plot on top. Below it, three scatter plots: the
        OPTICS Xi clustering, and the DBSCAN-style clusterings extracted from
        the same fit at ``eps=0.5`` and ``eps=2``.
    """
    X = _make_blobs(int(n_points_per_cluster))

    clust = OPTICS(
        min_samples=int(min_samples),
        xi=float(xi),
        min_cluster_size=float(min_cluster_size),
    )
    # Run the fit
    clust.fit(X)

    # Re-cut the fitted reachability graph at two fixed eps thresholds.
    labels_050 = cluster_optics_dbscan(
        reachability=clust.reachability_,
        core_distances=clust.core_distances_,
        ordering=clust.ordering_,
        eps=0.5,
    )
    labels_200 = cluster_optics_dbscan(
        reachability=clust.reachability_,
        core_distances=clust.core_distances_,
        ordering=clust.ordering_,
        eps=2,
    )

    # Reachability and labels in OPTICS processing order, for the top plot.
    space = np.arange(len(X))
    reachability = clust.reachability_[clust.ordering_]
    labels = clust.labels_[clust.ordering_]

    # Draw on an explicit Figure rather than pyplot's implicit "current
    # figure", so the returned plot does not depend on global pyplot state.
    fig = plt.figure(figsize=(10, 6))
    G = gridspec.GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(G[0, :])
    ax2 = fig.add_subplot(G[1, 0])
    ax3 = fig.add_subplot(G[1, 1])
    ax4 = fig.add_subplot(G[1, 2])

    # Reachability plot: one colour per cluster (cycling), noise in black,
    # plus horizontal guides at the two eps cuts used above.
    colors = ["g.", "r.", "b.", "y.", "c."]
    for klass in range(labels.max() + 1):
        ax1.plot(
            space[labels == klass],
            reachability[labels == klass],
            colors[klass % len(colors)],
            alpha=0.3,
        )
    ax1.plot(space[labels == -1], reachability[labels == -1], "k.", alpha=0.3)
    ax1.plot(space, np.full_like(space, 2.0, dtype=float), "k-", alpha=0.5)
    ax1.plot(space, np.full_like(space, 0.5, dtype=float), "k-.", alpha=0.5)
    ax1.set_ylabel("Reachability (epsilon distance)")
    ax1.set_title("Reachability Plot")

    # OPTICS
    _plot_cluster_points(
        ax2, X, clust.labels_, ["g.", "r.", "b.", "y.", "c."],
        "Automatic Clustering\nOPTICS",
    )
    # DBSCAN at 0.5
    _plot_cluster_points(
        ax3, X, labels_050, ["g.", "r.", "b.", "c."],
        "Clustering at 0.5 epsilon cut\nDBSCAN",
    )
    # DBSCAN at 2.
    _plot_cluster_points(
        ax4, X, labels_200, ["g.", "m.", "y.", "c."],
        "Clustering at 2.0 epsilon cut\nDBSCAN",
    )

    fig.tight_layout()
    # Unregister the figure from pyplot. Otherwise every slider change leaves
    # another open figure behind. The Figure object itself can still be saved
    # by the caller.
    plt.close(fig)
    return fig
title = "Demo of OPTICS clustering algorithm"
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_optics.html)"
    )
    # Adjacent literals are concatenated, so this is one string with no
    # indentation leaking into it.
    gr.Markdown(
        "Finds core samples of high density and expands clusters from them. This example uses data that is "
        "generated so that the clusters have different densities. The [OPTICS](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html#sklearn.cluster.OPTICS) is first used with its Xi cluster detection "
        "method, and then setting specific thresholds on the reachability, which corresponds to [DBSCAN](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html#sklearn.cluster.DBSCAN). We can see that "
        "the different clusters of OPTICS’s Xi method can be recovered with different choices of thresholds in DBSCAN."
    )
    # NOTE(review): `.style(...)` and a float `scale` are gradio 3.x APIs that
    # later gradio releases removed; confirm against the pinned gradio version.
    with gr.Row().style(equal_height=True):
        # Left column: the OPTICS / data-generation controls.
        with gr.Column(scale=0.75):
            n_points_per_cluster = gr.Slider(
                minimum=200,
                maximum=500,
                label="Number of points per cluster",
                step=50,
                value=250,
            )
            # Hidden rows kept from the original layout.
            with gr.Row(visible=False):
                gr.Markdown("##")
            min_samples = gr.Slider(
                minimum=10,
                maximum=100,
                label="OPTICS - Minimum number of samples",
                step=5,
                value=50,
                info="The number of samples in a neighborhood for a point to be considered as a core point.",
            )
            with gr.Row(visible=False):
                gr.Markdown("##")
            xi = gr.Slider(
                minimum=0,
                maximum=0.2,
                label="OPTICS - Xi",
                step=0.01,
                value=0.05,
                info="Determines the minimum steepness on the reachability plot that constitutes a cluster boundary. ",
            )
            with gr.Row(visible=False):
                gr.Markdown("##")
            min_cluster_size = gr.Slider(
                minimum=0.01,
                maximum=0.1,
                label="OPTICS - Minimum cluster size",
                step=0.01,
                value=0.05,
                info="Minimum number of samples in an OPTICS cluster, expressed as an absolute number or a fraction of the number of samples (rounded to be at least 2).",
            )
        plt_out = gr.Plot()

    # Every control re-runs do_submit with the full, current parameter set.
    controls = [n_points_per_cluster, min_samples, xi, min_cluster_size]
    for control in controls:
        control.change(do_submit, inputs=controls, outputs=plt_out)
    # Also render once on page load, so the plot is not blank until the user
    # first moves a slider.
    demo.load(do_submit, inputs=controls, outputs=plt_out)

if __name__ == "__main__":
    demo.launch()