bug fix
Browse files
app.py
CHANGED
@@ -19,14 +19,16 @@ def compute(dim):
|
|
19 |
|
20 |
my_bar = st.progress(0)
|
21 |
|
22 |
-
for idx in api.train(iters=
|
23 |
-
my_bar.progress(idx[0]/
|
24 |
-
|
25 |
samples = np.array(api.model.sample(
|
26 |
torch.tensor(api.scaled).float()).detach())
|
|
|
|
|
27 |
|
28 |
# fig, ax = plt.subplots()
|
29 |
-
g = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind='kde',cmap=sns.color_palette("Blues", as_cmap=True),fill=True,label='Gaussian KDE',levels=
|
30 |
|
31 |
w = sns.scatterplot(x=api.scaled[:,0],y=api.scaled[:,1],ax=g.ax_joint,c='orange',marker='+',s=100,label='Real')
|
32 |
st.pyplot(w.get_figure())
|
@@ -39,9 +41,8 @@ def compute(dim):
|
|
39 |
|
40 |
return api.scaler.inverse_transform(samples)
|
41 |
|
42 |
-
|
43 |
|
44 |
if uploaded_file is not None:
|
45 |
-
|
46 |
-
samples=compute(
|
47 |
st.download_button('Download generated CSV', pd.DataFrame(samples).to_csv(), 'text/csv')
|
|
|
19 |
|
20 |
my_bar = st.progress(0)
|
21 |
|
22 |
+
for idx in api.train(iters=iters):
|
23 |
+
my_bar.progress(idx[0]/iters)
|
24 |
+
my_bar.progress(100)
|
25 |
samples = np.array(api.model.sample(
|
26 |
torch.tensor(api.scaled).float()).detach())
|
27 |
+
|
28 |
+
|
29 |
|
30 |
# fig, ax = plt.subplots()
|
31 |
+
g = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind='kde',cmap=sns.color_palette("Blues", as_cmap=True),fill=True,label='Gaussian KDE',levels=1000)
|
32 |
|
33 |
w = sns.scatterplot(x=api.scaled[:,0],y=api.scaled[:,1],ax=g.ax_joint,c='orange',marker='+',s=100,label='Real')
|
34 |
st.pyplot(w.get_figure())
|
|
|
41 |
|
42 |
return api.scaler.inverse_transform(samples)
|
43 |
|
|
|
44 |
|
45 |
if uploaded_file is not None:
|
46 |
+
dims = len(uploaded_file.getvalue().decode("utf-8").split('\n')[0].split(','))-1
|
47 |
+
samples=compute(dims)
|
48 |
st.download_button('Download generated CSV', pd.DataFrame(samples).to_csv(), 'text/csv')
|