apsys commited on
Commit
e53edb8
1 Parent(s): 8323d56

added files

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from normflows import nflow
4
+ import numpy as np
5
+ import seaborn as sns
6
+ import pandas as pd
7
+
8
+ uploaded_file = st.file_uploader("Choose original dataset")
9
+ bw = st.number_input('Scale',value=3.05)
10
+
11
+
12
+
13
+ def compute():
14
+ api = nflow(dim=8,latent=16,dataset=uploaded_file)
15
+ api.compile(optim=torch.optim.ASGD,bw=bw,lr=0.0001,wd=None)
16
+ api.train(iters=10000)
17
+ samples = np.array(api.model.sample(
18
+ torch.tensor(api.scaled).float()).detach())
19
+
20
+ # fig, ax = plt.subplots()
21
+ g = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind='kde',cmap=sns.color_palette("Blues", as_cmap=True),fill=True,label='Gaussian KDE',levels=50)
22
+
23
+ w = sns.scatterplot(x=api.scaled[:,0],y=api.scaled[:,1],ax=g.ax_joint,c='orange',marker='+',s=100,label='Real')
24
+ st.pyplot(w.get_figure())
25
+
26
+
27
+ def random_normal_samples(n, dim=2):
28
+ return torch.zeros(n, dim).normal_(mean=0, std=1)
29
+
30
+ samples = np.array(api.model.sample(torch.tensor(random_normal_samples(1000,api.scaled.shape[-1])).float()).detach())
31
+
32
+ return api.scaler.inverse_transform(samples)
33
+
34
+
35
+
36
+ if uploaded_file is not None:
37
+ samples=compute()
38
+ st.download_button('Download generated CSV', pd.DataFrame(samples).to_csv(), 'text/csv')