File size: 2,183 Bytes
e53edb8
 
 
 
 
 
 
 
5c89480
 
 
 
e53edb8
 
 
5c89480
 
 
75bf717
bfd5c12
75bf717
240432b
 
 
0595c23
154b5ae
240432b
 
e53edb8
 
240432b
e53edb8
 
 
 
 
d409f21
e53edb8
 
 
 
 
5a018a0
 
 
4f60d3d
5a018a0
 
 
 
 
 
 
 
e53edb8
5a018a0
 
e53edb8
c31f537
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st
import torch
from normflows import nflow
import numpy as np
import seaborn as sns
import pandas as pd

# Dataset upload widget; the value stays None until the user picks a file.
uploaded_file = st.file_uploader(label="Choose original dataset")

# Three side-by-side columns, one numeric training setting per column.
col1, col2, col3 = st.columns(3)
bw = col1.number_input(label="Scale", value=3.05)
wd = col2.number_input(label="Weight Decay", value=0.0002)
iters = col3.number_input(label="Iterations", value=400)



def compute(dim):
    """Fit an ``nflow`` model on the uploaded dataset, plot it, and sample.

    Reads the module-level widget values ``uploaded_file``, ``bw``, ``wd``
    and ``iters``. While training it advances a Streamlit progress bar, then
    renders a joint KDE plot of the model output for the scaled training
    data with the scaled training points overlaid.

    Args:
        dim: Forwarded to ``nflow`` as ``dim``; the caller derives it from
            the header line of the uploaded file.

    Returns:
        The result of ``api.scaler.inverse_transform`` applied to 1000
        samples that the model produces from standard-normal noise.
    """
    api = nflow(dim=dim, latent=16, dataset=uploaded_file)
    api.compile(optim=torch.optim.ASGD, bw=bw, lr=0.0001, wd=wd)

    my_bar = st.progress(0)
    for idx in api.train(iters=iters):
        # idx[0] is presumably the current iteration counter — TODO confirm.
        my_bar.progress(idx[0] / iters)
    # An int argument is a percentage, so this marks the bar as complete.
    my_bar.progress(100)

    # Run the model over the scaled training data exactly once. The original
    # evaluated this expression twice, so the argmin indices could refer to a
    # different draw than the array being filtered (and cost twice the work).
    scaled_input = torch.tensor(api.scaled).float()
    generated = np.array(api.model.sample(scaled_input).detach())

    # Drop the rows holding each column's minimum value before plotting.
    plotted = np.delete(generated, np.argmin(generated, axis=0), 0)

    g = sns.jointplot(
        x=plotted[:, 0],
        y=plotted[:, 1],
        kind='kde',
        cmap=sns.color_palette("Blues", as_cmap=True),
        fill=True,
        label='Gaussian KDE',
        levels=1000,
    )
    # Overlay the scaled real data points on the joint axes of the KDE plot.
    w = sns.scatterplot(
        x=api.scaled[:, 0],
        y=api.scaled[:, 1],
        ax=g.ax_joint,
        c='orange',
        marker='+',
        s=100,
        label='Real',
    )
    st.pyplot(w.get_figure())

    # Draw standard-normal noise with one column per feature of the scaled
    # data. torch.randn already returns a tensor, so it is passed directly
    # instead of being re-wrapped in torch.tensor (which emits a UserWarning).
    noise = torch.randn(1000, api.scaled.shape[-1]).float()
    samples = np.array(api.model.sample(noise).detach())

    return api.scaler.inverse_transform(samples)

# Token form: generation only runs on the script pass in which the form was
# submitted with an accepted token.
with st.form('login_form'):
    st.write('Token for generation:')
    token = st.text_input('Token')
    submit = st.form_submit_button('Submit')

# Reject an empty token explicitly: if st.secrets['tokens'] is a plain string,
# `'' in tokens` is True and would otherwise grant access.
# NOTE(review): if the secret is a string, `in` is still a substring test —
# confirm that st.secrets['tokens'] is a list of tokens.
if submit and token and token in st.secrets['tokens']:
    if uploaded_file is not None:
        # Count the comma-separated fields on the first line of the upload.
        # The -1 presumably skips an index/label column — TODO confirm.
        header = uploaded_file.getvalue().decode("utf-8").split('\n', 1)[0]
        dims = len(header.split(',')) - 1
        samples = compute(dims)
        # The third positional parameter of download_button is file_name, so
        # the MIME type must be passed by keyword; the original call used
        # 'text/csv' as the download's file name.
        st.download_button(
            'Download generated CSV',
            pd.DataFrame(samples).to_csv(),
            file_name='generated.csv',
            mime='text/csv',
        )
    else:
        st.write('Upload your file')
else:
    st.markdown('## :red[You dont have access]')
    # NOTE(review): link label says "advprop" but the URL says "adprop" —
    # confirm which handle is correct.
    st.markdown('Buy tokens here: [@advprop](https://adprop.t.me)')