Hnabil commited on
Commit
10d6c31
1 Parent(s): 0b1d824

Add application files

Browse files
Files changed (2) hide show
  1. app.py +126 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ from matplotlib.ticker import NullFormatter
6
+ import numpy as np
7
+ from sklearn import datasets, manifold
8
+
9
+
10
+ SEED = 0
11
+ N_COMPONENTS = 2
12
+ np.random.seed(SEED)
13
+
14
+
15
+ def get_circles(n_samples):
16
+ X, color = datasets.make_circles(
17
+ n_samples=n_samples,
18
+ factor=0.5,
19
+ noise=0.05,
20
+ random_state=SEED
21
+ )
22
+ return X, color
23
+
24
+
25
+ def get_s_curve(n_samples):
26
+ X, color = datasets.make_s_curve(n_samples=n_samples, random_state=SEED)
27
+ X[:, 1], X[:, 2] = X[:, 2], X[:, 1].copy()
28
+ return X, color
29
+
30
+
31
+ def get_uniform_grid(n_samples):
32
+ x = np.linspace(0, 1, int(np.sqrt(n_samples)))
33
+ xx, yy = np.meshgrid(x, x)
34
+ X = np.hstack(
35
+ [
36
+ xx.ravel().reshape(-1, 1),
37
+ yy.ravel().reshape(-1, 1),
38
+ ]
39
+ )
40
+ color = xx.ravel()
41
+ return X, color
42
+
43
+
44
+ DATA_MAPPING = {
45
+ 'circles': get_circles,
46
+ 's-curve': get_s_curve,
47
+ 'uniform grid': get_uniform_grid,
48
+ }
49
+
50
+
51
+
52
+ def plot_data(dataset: str, perplexity: int, n_samples: int, tsne: bool):
53
+ if isinstance(perplexity, dict):
54
+ perplexity = perplexity['value']
55
+ else:
56
+ perplexity = int(perplexity)
57
+
58
+ X, color = DATA_MAPPING[dataset](n_samples)
59
+ if tsne:
60
+ tsne = manifold.TSNE(
61
+ n_components=N_COMPONENTS,
62
+ init="random",
63
+ random_state=0,
64
+ perplexity=perplexity,
65
+ n_iter=400,
66
+ )
67
+ Y = tsne.fit_transform(X)
68
+ else:
69
+ Y = X
70
+
71
+ fig, ax = plt.subplots(figsize=(7, 7))
72
+
73
+ ax.scatter(Y[:, 0], Y[:, 1], c=color)
74
+ ax.xaxis.set_major_formatter(NullFormatter())
75
+ ax.yaxis.set_major_formatter(NullFormatter())
76
+ ax.axis("tight")
77
+
78
+ return fig
79
+
80
+
81
+ title = "t-SNE: The effect of various perplexity values on the shape"
82
+ description = (
83
+ "An illustration of t-SNE on the two concentric circles and the"
84
+ "S-curve datasets for different perplexity values."
85
+ )
86
+
87
+
88
+ with gr.Blocks(title=title) as demo:
89
+ gr.HTML(f"<b>{title}</b>")
90
+ gr.Markdown(description)
91
+
92
+ input_data = gr.Radio(
93
+ list(DATA_MAPPING),
94
+ value="circles",
95
+ label="dataset"
96
+ )
97
+ n_samples = gr.Slider(
98
+ minimum=100,
99
+ maximum=1000,
100
+ value=150,
101
+ step=25,
102
+ label='Number of Samples'
103
+ )
104
+ perplexity = gr.Slider(
105
+ minimum=2,
106
+ maximum=100,
107
+ value=5,
108
+ step=1,
109
+ label='Perplexity'
110
+ )
111
+ with gr.Row():
112
+ with gr.Column():
113
+ plot = gr.Plot(label="Original data")
114
+ fn = partial(plot_data, tsne=False)
115
+ input_data.change(fn=fn, inputs=[input_data, perplexity, n_samples], outputs=plot)
116
+ perplexity.change(fn=fn, inputs=[input_data, perplexity, n_samples], outputs=plot)
117
+ n_samples.change(fn=fn, inputs=[input_data, perplexity, n_samples], outputs=plot)
118
+ with gr.Column():
119
+ plot = gr.Plot(label="t-SNE")
120
+ fn = partial(plot_data, tsne=True)
121
+ input_data.change(fn=fn, inputs=[input_data, perplexity, n_samples], outputs=plot)
122
+ perplexity.change(fn=fn, inputs=[input_data, perplexity, n_samples], outputs=plot)
123
+ n_samples.change(fn=fn, inputs=[input_data, perplexity, n_samples], outputs=plot)
124
+
125
+
126
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ matplotlib
2
+ numpy
3
+ scikit-learn