# Uploaded by Hnabil — commit 10d6c31 "Add application files" (3.28 kB).
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from matplotlib.ticker import NullFormatter
from sklearn import datasets, manifold
# SEED is passed as ``random_state`` to the dataset generators below;
# N_COMPONENTS is the dimensionality of the t-SNE embedding (2, so the first
# two output columns can be scatter-plotted).
SEED, N_COMPONENTS = 0, 2

# Additionally seed NumPy's global RNG as a fallback for any randomness that
# is not given an explicit seed.
np.random.seed(seed=SEED)
def get_circles(n_samples):
    """Return two noisy concentric circles and their class labels.

    Args:
        n_samples: Number of points passed to ``make_circles``.

    Returns:
        ``(X, color)`` exactly as produced by ``sklearn.datasets.make_circles``:
        the point coordinates and the label used to colour each point.
    """
    # Inner circle at half the outer radius, small Gaussian noise, fixed seed.
    circle_params = dict(
        n_samples=n_samples,
        factor=0.5,
        noise=0.05,
        random_state=SEED,
    )
    points, labels = datasets.make_circles(**circle_params)
    return points, labels
def get_s_curve(n_samples):
    """Return an S-curve point cloud with its 2nd and 3rd coordinates swapped.

    Args:
        n_samples: Number of points passed to ``make_s_curve``.

    Returns:
        ``(X, color)`` where ``X`` is the ``make_s_curve`` output with columns
        1 and 2 exchanged, and ``color`` is the generator's position value.
    """
    points, position = datasets.make_s_curve(
        n_samples=n_samples, random_state=SEED
    )
    # Exchange columns 1 and 2 in place (presumably so the S profile lands in
    # the first two columns, which the "Original data" plot draws). The
    # fancy-indexed right-hand side is a copy, so no column is clobbered
    # before it has been read.
    points[:, [1, 2]] = points[:, [2, 1]]
    return points, position
def get_uniform_grid(n_samples):
    """Return a regular 2-D grid on the unit square, coloured by x-coordinate.

    The side length is ``int(sqrt(n_samples))``, so the grid holds the largest
    perfect square not exceeding *n_samples* (e.g. 150 -> 12 x 12 = 144).

    Returns:
        ``(X, color)`` with ``X`` of shape ``(side**2, 2)`` holding ``(x, y)``
        pairs in row-major order, and ``color`` equal to the x-coordinates.
    """
    side = int(np.sqrt(n_samples))
    ticks = np.linspace(0, 1, side)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    flat_x = grid_x.ravel()
    flat_y = grid_y.ravel()
    points = np.hstack([flat_x[:, np.newaxis], flat_y[:, np.newaxis]])
    return points, flat_x
# Dataset name shown in the UI -> generator taking ``n_samples`` and returning
# ``(X, color)``. Insertion order is significant: ``list(DATA_MAPPING)`` supplies
# the radio-button choices in this order.
DATA_MAPPING = dict(
    [
        ("circles", get_circles),
        ("s-curve", get_s_curve),
        ("uniform grid", get_uniform_grid),
    ]
)
def plot_data(dataset: str, perplexity: int, n_samples: int, tsne: bool):
    """Generate *dataset* and return a scatter plot of it, optionally via t-SNE.

    Args:
        dataset: Key of ``DATA_MAPPING`` selecting the data generator.
        perplexity: t-SNE perplexity. It may arrive as a plain number or as a
            dict with a ``"value"`` entry. It is clamped below the number of
            points, because scikit-learn rejects ``perplexity >= n_samples``.
        n_samples: Number of points requested from the generator. The
            uniform grid may return fewer.
        tsne: If True, plot the 2-D t-SNE embedding of the data; otherwise
            plot the first two raw coordinates.

    Returns:
        A 7x7-inch ``matplotlib.figure.Figure`` with tick labels hidden.
    """
    # The control value may come through either as a bare number or wrapped
    # in a dict; normalise both to int.
    if isinstance(perplexity, dict):
        perplexity = perplexity["value"]
    perplexity = int(perplexity)
    n_samples = int(n_samples)

    X, color = DATA_MAPPING[dataset](n_samples)

    if tsne:
        # The UI allows e.g. 100 points with perplexity 100, which TSNE
        # refuses with a ValueError; clamp instead of crashing the handler.
        perplexity = min(perplexity, X.shape[0] - 1)
        # ``n_iter`` was renamed ``max_iter`` in scikit-learn 1.5 and removed
        # in 1.7; use whichever name the installed version accepts.
        iter_param = (
            "max_iter" if "max_iter" in manifold.TSNE().get_params() else "n_iter"
        )
        embedder = manifold.TSNE(
            n_components=N_COMPONENTS,
            init="random",
            random_state=SEED,
            perplexity=perplexity,
            **{iter_param: 400},
        )
        Y = embedder.fit_transform(X)
    else:
        Y = X

    # Build the figure without pyplot: pyplot keeps every figure alive in a
    # global registry (one leak per slider event here) and is not thread-safe
    # when the web server runs handlers concurrently.
    fig = Figure(figsize=(7, 7))
    ax = fig.subplots()
    ax.scatter(Y[:, 0], Y[:, 1], c=color)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis("tight")
    return fig
title = "t-SNE: The effect of various perplexity values on the shape"
# Adjacent string literals are joined with no separator, so each piece that
# continues onto the next line must carry its own trailing space.
description = (
    "An illustration of t-SNE on the two concentric circles, the S-curve "
    "and the uniform grid datasets for different perplexity values."
)
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)
    input_data = gr.Radio(
        list(DATA_MAPPING),
        value="circles",
        label="dataset",
    )
    n_samples = gr.Slider(
        minimum=100,
        maximum=1000,
        value=150,
        step=25,
        label='Number of Samples',
    )
    perplexity = gr.Slider(
        minimum=2,
        maximum=100,
        value=5,
        step=1,
        label='Perplexity',
    )

    # Values are handed to the callback positionally, so this order must
    # match plot_data(dataset, perplexity, n_samples); ``tsne`` is bound
    # separately for each plot below.
    inputs = [input_data, perplexity, n_samples]

    with gr.Row():
        for plot_label, use_tsne in (("Original data", False), ("t-SNE", True)):
            with gr.Column():
                plot = gr.Plot(label=plot_label)
                fn = partial(plot_data, tsne=use_tsne)
                # Redraw whenever any control changes.
                for control in inputs:
                    control.change(fn=fn, inputs=inputs, outputs=plot)
                # Draw once with the default settings on page load; otherwise
                # the plots stay empty until the user touches a control.
                demo.load(fn=fn, inputs=inputs, outputs=plot)

if __name__ == "__main__":
    demo.launch()