Spaces:
Sleeping
Sleeping
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from matplotlib.ticker import NullFormatter
from sklearn import datasets, manifold
# Single seed shared by every data generator so reruns draw identical points.
SEED = 0
# Dimensionality of the t-SNE embedding; 2 so it can be drawn as a scatter plot.
N_COMPONENTS = 2
# Also seed NumPy's legacy global RNG for any code that does not take an
# explicit random_state.
np.random.seed(seed=SEED)
def get_circles(n_samples):
    """Generate two noisy concentric circles in the plane.

    The inner circle's radius is half of the outer one (``factor=0.5``) and
    Gaussian noise with standard deviation 0.05 is added to every point.

    Args:
        n_samples: Total number of points across both circles.

    Returns:
        ``(X, color)``: the 2-D coordinates and the integer class labels from
        ``make_circles``, used to colour the scatter plot.
    """
    return datasets.make_circles(
        n_samples=n_samples,
        noise=0.05,
        factor=0.5,
        random_state=SEED,
    )
def get_s_curve(n_samples):
    """Generate a 3-D S-curve with its second and third columns exchanged.

    The swap presumably makes the first two columns (the ones ``plot_data``
    scatters for the untransformed view) show the S profile.

    Args:
        n_samples: Number of points on the curve.

    Returns:
        ``(X, color)``: the 3-D coordinates (columns 1 and 2 swapped) and the
        univariate position along the manifold from ``make_s_curve``.
    """
    points, position = datasets.make_s_curve(n_samples=n_samples, random_state=SEED)
    # The fancy-indexed right-hand side is a copy, so the two columns are
    # exchanged without one overwriting the other mid-assignment.
    points[:, [1, 2]] = points[:, [2, 1]]
    return points, position
def get_uniform_grid(n_samples):
    """Build a square, evenly spaced grid on the unit square.

    The side length is ``int(sqrt(n_samples))``, so the grid holds the
    largest perfect square not exceeding ``n_samples``.

    Args:
        n_samples: Upper bound on the number of grid points.

    Returns:
        ``(X, color)``: an ``(side**2, 2)`` array of ``(x, y)`` points in
        row-major order, and the x-coordinate of each point for colouring.
    """
    side = int(np.sqrt(n_samples))
    ticks = np.linspace(0, 1, side)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    flat_x = grid_x.ravel()
    points = np.column_stack((flat_x, grid_y.ravel()))
    return points, flat_x
# Radio-button label -> generator returning ``(X, color)``.
# Insertion order matters: ``list(DATA_MAPPING)`` supplies the radio choices.
DATA_MAPPING = {
    "circles": get_circles,
    "s-curve": get_s_curve,
    "uniform grid": get_uniform_grid,
}
def plot_data(dataset: str, perplexity: int, n_samples: int, tsne: bool):
    """Draw one scatter plot of *dataset*, optionally after a t-SNE embedding.

    Args:
        dataset: Key of ``DATA_MAPPING`` selecting the data generator.
        perplexity: t-SNE perplexity. A ``{"value": ...}`` dict is also
            accepted (presumably the payload shape of some Gradio versions).
            It is clamped to ``len(X) - 1`` because TSNE rejects a perplexity
            that is not smaller than the number of points.
        n_samples: Number of points requested from the generator. Coerced to
            ``int`` since slider values may arrive as floats, which
            scikit-learn's parameter validation may reject.
        tsne: If true, plot the 2-D t-SNE embedding of the data; otherwise
            plot the first two raw coordinates.

    Returns:
        A ``matplotlib.figure.Figure`` containing a single scatter plot whose
        points are coloured by the generator's ``color`` output.
    """
    if isinstance(perplexity, dict):
        perplexity = perplexity["value"]
    perplexity = int(perplexity)
    n_samples = int(n_samples)

    X, color = DATA_MAPPING[dataset](n_samples)

    if tsne:
        # The sliders allow perplexity == n_samples, and the uniform grid
        # returns fewer points than requested (e.g. 150 -> 144), so clamp
        # instead of letting TSNE raise a ValueError.
        tsne_params = dict(
            n_components=N_COMPONENTS,
            init="random",
            random_state=SEED,
            perplexity=min(perplexity, len(X) - 1),
        )
        try:
            # scikit-learn >= 1.5 calls the iteration budget ``max_iter``
            # (``n_iter`` was deprecated there and later removed) ...
            model = manifold.TSNE(max_iter=400, **tsne_params)
        except TypeError:
            # ... while older releases only accept ``n_iter``.
            model = manifold.TSNE(n_iter=400, **tsne_params)
        Y = model.fit_transform(X)
    else:
        Y = X

    # Build the figure without pyplot: figures from plt.subplots() stay
    # registered in pyplot's global state until explicitly closed, so each
    # slider change leaked one. A bare Figure is garbage-collected normally.
    fig = Figure(figsize=(7, 7))
    ax = fig.subplots()
    ax.scatter(Y[:, 0], Y[:, 1], c=color)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis("tight")
    return fig
title = "t-SNE: The effect of various perplexity values on the shape"
# Implicit string concatenation: each fragment must carry its own separating
# space (the previous text rendered as "theS-curve").
description = (
    "An illustration of t-SNE on the two concentric circles, the S-curve "
    "and the uniform grid datasets for different perplexity values."
)
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)
    input_data = gr.Radio(
        list(DATA_MAPPING),
        value="circles",
        label="dataset"
    )
    n_samples = gr.Slider(
        minimum=100,
        maximum=1000,
        value=150,
        step=25,
        label='Number of Samples'
    )
    perplexity = gr.Slider(
        minimum=2,
        maximum=100,
        value=5,
        step=1,
        label='Perplexity'
    )

    # Argument order must match plot_data(dataset, perplexity, n_samples, ...).
    controls = [input_data, perplexity, n_samples]

    with gr.Row():
        # One column per view: raw data on the left, t-SNE embedding on the right.
        for plot_label, use_tsne in (("Original data", False), ("t-SNE", True)):
            with gr.Column():
                plot = gr.Plot(label=plot_label)
                fn = partial(plot_data, tsne=use_tsne)
                # Render once when the page loads; otherwise both plots stay
                # empty until the user first touches a control.
                demo.load(fn=fn, inputs=controls, outputs=plot)
                # Any control change re-renders this plot.
                for control in controls:
                    control.change(fn=fn, inputs=controls, outputs=plot)

demo.launch()