caliex committed
Commit 1c92b51 · 1 Parent(s): e2757a0

Create app.py

Files changed (1)
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import gradio as gr
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ from sklearn.datasets import load_digits
+ from sklearn.neighbors import KernelDensity
+ from sklearn.decomposition import PCA
+ from sklearn.model_selection import GridSearchCV
+
+ def generate_digits(bandwidth, num_samples):
+
+     # convert bandwidth to an integer
+     bandwidth = int(bandwidth)
+
+     # convert num_samples to an integer
+     num_samples = int(num_samples)
+
+     # load the data
+     digits = load_digits()
+
+     # project the 64-dimensional data to a lower dimension
+     pca = PCA(n_components=15, whiten=False)
+     data = pca.fit_transform(digits.data)
+
+     # use grid search cross-validation to estimate a good bandwidth
+     params = {"bandwidth": np.logspace(-1, 1, 20)}
+     grid = GridSearchCV(KernelDensity(), params)
+     grid.fit(data)
+
+     # use the user-specified bandwidth to compute the kernel density estimate
+     kde = KernelDensity(bandwidth=bandwidth)
+     kde.fit(data)
+
+     # sample new points from the fitted model and project them back to 64 dimensions
+     new_data = kde.sample(num_samples, random_state=0)
+     new_data = pca.inverse_transform(new_data)
+
+     # flatten each sample to a 64-pixel vector
+     new_data = new_data.reshape((num_samples, 64))
+     real_data = digits.data[:num_samples].reshape((num_samples, 64))
+
+     # create the plot: rows 0-3 show real digits, row 4 is hidden, rows 5-8 show generated digits
+     fig, ax = plt.subplots(9, 11, subplot_kw=dict(xticks=[], yticks=[]))
+     for j in range(11):
+         ax[4, j].set_visible(False)
+         for i in range(4):
+             index = i * 11 + j  # position in the 4x11 panel grid
+             if index < num_samples:
+                 im = ax[i, j].imshow(
+                     real_data[index].reshape((8, 8)), cmap=plt.cm.binary, interpolation="nearest"
+                 )
+                 im.set_clim(0, 16)
+                 im = ax[i + 5, j].imshow(
+                     new_data[index].reshape((8, 8)), cmap=plt.cm.binary, interpolation="nearest"
+                 )
+                 im.set_clim(0, 16)
+             else:
+                 ax[i, j].axis("off")
+                 ax[i + 5, j].axis("off")
+
+     ax[0, 5].set_title("Selection from the input data")
+     ax[5, 5].set_title('"New" digits drawn from the kernel density model')
+
+
+     # save the plot to a file
+     plt.savefig("digits_plot.png")
+
+     # return the path to the generated plot
+     return "digits_plot.png"
+
+ # create the Gradio interface
+ inputs = [
+     gr.inputs.Slider(minimum=1, maximum=10, step=1, label="Bandwidth"),
+     gr.inputs.Number(default=44, label="Number of Samples")
+ ]
+ output = gr.outputs.Image(type="pil")
+
+ title = "Kernel Density Estimation"
+ description = "This example shows how kernel density estimation (KDE), a powerful non-parametric density estimation technique, can be used to learn a generative model for a dataset. With this generative model in place, new samples can be drawn. These new samples reflect the underlying model of the data. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/neighbors/plot_digits_kde_sampling.html"
+ examples = [
+     [1, 44],
+     [8, 22],
+     [7, 51]
+ ]
+
+ gr.Interface(generate_digits, inputs, output, title=title, description=description, examples=examples).launch()
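Note on the grid search inside generate_digits: the cross-validated bandwidth found by GridSearchCV is computed but never applied; the kernel density estimate is always refit with the slider value. The sketch below is a minimal illustration (not part of the commit) of how the search result could be used instead, assuming the same 15-component PCA projection as app.py; variable names mirror the file.

import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity

# project the digits to 15 dimensions, as in app.py
digits = load_digits()
pca = PCA(n_components=15, whiten=False)
data = pca.fit_transform(digits.data)

# cross-validate the bandwidth over the same logarithmic grid
params = {"bandwidth": np.logspace(-1, 1, 20)}
grid = GridSearchCV(KernelDensity(), params)
grid.fit(data)
print("best bandwidth:", grid.best_params_["bandwidth"])

# grid.best_estimator_ is already refit on the full data (refit=True by default),
# so it can be sampled directly and projected back to pixel space
kde = grid.best_estimator_
new_data = pca.inverse_transform(kde.sample(44, random_state=0))

To try the app locally, installing gradio, scikit-learn, and matplotlib and running python app.py should suffice; note that the legacy gr.inputs / gr.outputs namespaces used here are not available in the newest Gradio releases, so an older Gradio version may be required.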