Spaces:
Runtime error
Runtime error
Commit
•
9b89eaf
1
Parent(s):
8fd22b4
Upload folder using huggingface_hub
Browse files- requirements.txt +1 -1
- run.ipynb +1 -1
- run.py +2 -1
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
https://gradio-builds.s3.amazonaws.com/
|
2 |
matplotlib>=3.5.2
|
3 |
scikit-learn>=1.0.1
|
|
|
1 |
+
https://gradio-builds.s3.amazonaws.com/1d5b15a2d24387154f2cfb40a36de25b331471d3/gradio-3.47.1-py3-none-any.whl
|
2 |
matplotlib>=3.5.2
|
3 |
scikit-learn>=1.0.1
|
run.ipynb
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"cells":
|
|
|
1 |
+
{"cells":[{"cell_type":"markdown","metadata":{},"source":["# Gradio Demo: clustering\n","### This demo built with Blocks generates 9 plots based on the input.\n"," "]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q gradio matplotlib>=3.5.2 scikit-learn>=1.0.1 "]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import gradio as gr\n","import math\n","from functools import partial\n","import matplotlib.pyplot as plt\n","import numpy as np\n","from sklearn.cluster import (\n"," AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth\n",")\n","from sklearn.datasets import make_blobs, make_circles, make_moons\n","from sklearn.mixture import GaussianMixture\n","from sklearn.neighbors import kneighbors_graph\n","from sklearn.preprocessing import StandardScaler\n","\n","plt.style.use('seaborn-v0_8')\n","SEED = 0\n","MAX_CLUSTERS = 10\n","N_SAMPLES = 1000\n","N_COLS = 3\n","FIGSIZE = 7, 7 # does not affect size in webpage\n","COLORS = [\n"," 'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'\n","]\n","if len(COLORS) <= MAX_CLUSTERS:\n"," raise ValueError(\"Not enough different colors for all clusters\")\n","np.random.seed(SEED)\n","\n","\n","def normalize(X):\n"," return StandardScaler().fit_transform(X)\n","\n","\n","def get_regular(n_clusters):\n"," # spiral pattern\n"," centers = [\n"," [0, 0],\n"," [1, 0],\n"," [1, 1],\n"," [0, 1],\n"," [-1, 1],\n"," [-1, 0],\n"," [-1, -1],\n"," [0, -1],\n"," [1, -1],\n"," [2, -1],\n"," ][:n_clusters]\n"," assert len(centers) == n_clusters\n"," X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers,\n"," cluster_std=0.25, random_state=SEED)\n"," return normalize(X), labels\n","\n","\n","def get_circles(n_clusters):\n"," X, labels = make_circles(\n"," n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)\n"," return normalize(X), labels\n","\n","\n","def get_moons(n_clusters):\n"," X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)\n"," return normalize(X), labels\n","\n","\n","def get_noise(n_clusters):\n"," np.random.seed(SEED)\n"," X, labels = np.random.rand(N_SAMPLES, 2), np.random.randint(\n"," 0, n_clusters, size=(N_SAMPLES,))\n"," return normalize(X), labels\n","\n","\n","def get_anisotropic(n_clusters):\n"," X, labels = make_blobs(n_samples=N_SAMPLES,\n"," centers=n_clusters, random_state=170)\n"," transformation = [[0.6, -0.6], [-0.4, 0.8]]\n"," X = np.dot(X, transformation)\n"," return X, labels\n","\n","\n","def get_varied(n_clusters):\n"," cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5,\n"," 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters]\n"," assert len(cluster_std) == n_clusters\n"," X, labels = make_blobs(\n"," n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED\n"," )\n"," return normalize(X), labels\n","\n","\n","def get_spiral(n_clusters):\n"," # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html\n"," np.random.seed(SEED)\n"," t = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))\n"," x = t * np.cos(t)\n"," y = t * np.sin(t)\n"," X = np.concatenate((x, y))\n"," X += 0.7 * np.random.randn(2, N_SAMPLES)\n"," X = np.ascontiguousarray(X.T)\n","\n"," labels = np.zeros(N_SAMPLES, dtype=int)\n"," return normalize(X), labels\n","\n","\n","DATA_MAPPING = {\n"," 'regular': get_regular,\n"," 'circles': get_circles,\n"," 'moons': get_moons,\n"," 'spiral': get_spiral,\n"," 'noise': get_noise,\n"," 'anisotropic': get_anisotropic,\n"," 'varied': get_varied,\n","}\n","\n","\n","def get_groundtruth_model(X, labels, n_clusters, **kwargs):\n"," # dummy model to show true label distribution\n"," class Dummy:\n"," def __init__(self, y):\n"," self.labels_ = labels\n","\n"," return Dummy(labels)\n","\n","\n","def get_kmeans(X, labels, n_clusters, **kwargs):\n"," model = KMeans(init=\"k-means++\", n_clusters=n_clusters,\n"," n_init=10, random_state=SEED)\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_dbscan(X, labels, n_clusters, **kwargs):\n"," model = DBSCAN(eps=0.3)\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_agglomerative(X, labels, n_clusters, **kwargs):\n"," connectivity = kneighbors_graph(\n"," X, n_neighbors=n_clusters, include_self=False\n"," )\n"," # make connectivity symmetric\n"," connectivity = 0.5 * (connectivity + connectivity.T)\n"," model = AgglomerativeClustering(\n"," n_clusters=n_clusters, linkage=\"ward\", connectivity=connectivity\n"," )\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_meanshift(X, labels, n_clusters, **kwargs):\n"," bandwidth = estimate_bandwidth(X, quantile=0.25)\n"," model = MeanShift(bandwidth=bandwidth, bin_seeding=True)\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_spectral(X, labels, n_clusters, **kwargs):\n"," model = SpectralClustering(\n"," n_clusters=n_clusters,\n"," eigen_solver=\"arpack\",\n"," affinity=\"nearest_neighbors\",\n"," )\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_optics(X, labels, n_clusters, **kwargs):\n"," model = OPTICS(\n"," min_samples=7,\n"," xi=0.05,\n"," min_cluster_size=0.1,\n"," )\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_birch(X, labels, n_clusters, **kwargs):\n"," model = Birch(n_clusters=n_clusters)\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","def get_gaussianmixture(X, labels, n_clusters, **kwargs):\n"," model = GaussianMixture(\n"," n_components=n_clusters, covariance_type=\"full\", random_state=SEED,\n"," )\n"," model.set_params(**kwargs)\n"," return model.fit(X)\n","\n","\n","MODEL_MAPPING = {\n"," 'True labels': get_groundtruth_model,\n"," 'KMeans': get_kmeans,\n"," 'DBSCAN': get_dbscan,\n"," 'MeanShift': get_meanshift,\n"," 'SpectralClustering': get_spectral,\n"," 'OPTICS': get_optics,\n"," 'Birch': get_birch,\n"," 'GaussianMixture': get_gaussianmixture,\n"," 'AgglomerativeClustering': get_agglomerative,\n","}\n","\n","\n","def plot_clusters(ax, X, labels):\n"," set_clusters = set(labels)\n"," # -1 signifiies outliers, which we plot separately\n"," set_clusters.discard(-1)\n"," for label, color in zip(sorted(set_clusters), COLORS):\n"," idx = labels == label\n"," if not sum(idx):\n"," continue\n"," ax.scatter(X[idx, 0], X[idx, 1], color=color)\n","\n"," # show outliers (if any)\n"," idx = labels == -1\n"," if sum(idx):\n"," ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x')\n","\n"," ax.grid(None)\n"," ax.set_xticks([])\n"," ax.set_yticks([])\n"," return ax\n","\n","\n","def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):\n"," if isinstance(n_clusters, dict):\n"," n_clusters = n_clusters['value']\n"," else:\n"," n_clusters = int(n_clusters)\n","\n"," X, labels = DATA_MAPPING[dataset](n_clusters)\n"," model = MODEL_MAPPING[clustering_algorithm](\n"," X, labels, n_clusters=n_clusters)\n"," if hasattr(model, \"labels_\"):\n"," y_pred = model.labels_.astype(int)\n"," else:\n"," y_pred = model.predict(X)\n","\n"," fig, ax = plt.subplots(figsize=FIGSIZE)\n","\n"," plot_clusters(ax, X, y_pred)\n"," ax.set_title(clustering_algorithm, fontsize=16)\n","\n"," return fig\n","\n","\n","title = \"Clustering with Scikit-learn\"\n","description = (\n"," \"This example shows how different clustering algorithms work. Simply pick \"\n"," \"the dataset and the number of clusters to see how the clustering algorithms work. \"\n"," \"Colored circles are (predicted) labels and black x are outliers.\"\n",")\n","\n","\n","def iter_grid(n_rows, n_cols):\n"," # create a grid using gradio Block\n"," for _ in range(n_rows):\n"," with gr.Row():\n"," for _ in range(n_cols):\n"," with gr.Column():\n"," yield\n","\n","\n","with gr.Blocks(title=title) as demo:\n"," gr.HTML(f\"<b>{title}</b>\")\n"," gr.Markdown(description)\n","\n"," input_models = list(MODEL_MAPPING)\n"," input_data = gr.Radio(\n"," list(DATA_MAPPING),\n"," value=\"regular\",\n"," label=\"dataset\"\n"," )\n"," input_n_clusters = gr.Slider(\n"," minimum=1,\n"," maximum=MAX_CLUSTERS,\n"," value=4,\n"," step=1,\n"," label='Number of clusters'\n"," )\n"," n_rows = int(math.ceil(len(input_models) / N_COLS))\n"," counter = 0\n"," for _ in iter_grid(n_rows, N_COLS):\n"," if counter >= len(input_models):\n"," break\n","\n"," input_model = input_models[counter]\n"," plot = gr.Plot(label=input_model)\n"," fn = partial(cluster, clustering_algorithm=input_model)\n"," input_data.change(\n"," fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n"," input_n_clusters.change(\n"," fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)\n"," counter += 1\n","\n","demo.launch()"]}],"metadata":{"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":5}
|
run.py
CHANGED
@@ -20,7 +20,8 @@ FIGSIZE = 7, 7 # does not affect size in webpage
|
|
20 |
COLORS = [
|
21 |
'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
|
22 |
]
|
23 |
-
|
|
|
24 |
np.random.seed(SEED)
|
25 |
|
26 |
|
|
|
20 |
COLORS = [
|
21 |
'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
|
22 |
]
|
23 |
+
if len(COLORS) <= MAX_CLUSTERS:
|
24 |
+
raise ValueError("Not enough different colors for all clusters")
|
25 |
np.random.seed(SEED)
|
26 |
|
27 |
|