# normflows / app.py
# Source snapshot: commit 5a018a0 by apsys (commit message: "upl"), 2.11 kB.
import streamlit as st
import torch
from normflows import nflow
import numpy as np
import seaborn as sns
import pandas as pd
# Page controls: the dataset to learn from plus three training hyperparameters
# that compute() reads as module-level globals.
uploaded_file = st.file_uploader("Choose original dataset")
col1, col2, col3 = st.columns(3)
bw = col1.number_input('Scale', value=3.05)
# Streamlit displays floats with "%0.2f" and steps by 0.01 by default, which
# would render the 0.0002 default as "0.00". Show enough decimals, use a
# matching step, and forbid negative values (presumably forwarded to the torch
# optimizer as weight_decay, which rejects negatives — confirm in nflow.compile).
wd = col2.number_input(
    'Weight Decay',
    value=0.0002,
    min_value=0.0,
    step=0.0001,
    format="%.6f",
)
# At least one iteration: compute() divides by `iters` to report progress.
iters = col3.number_input('Iterations', value=400, min_value=1)
def compute(dim, n_samples=1000):
    """Train a normalizing flow on the uploaded dataset and generate rows.

    Reads the module-level ``uploaded_file``, ``bw``, ``wd`` and ``iters``
    widget values. Training progress is reported on a Streamlit progress bar.
    Afterwards the model's output on the scaled training data is shown as a
    KDE joint plot with the real scaled points overlaid.

    Args:
        dim: Forwarded as ``dim=`` to ``nflow``. The caller passes the number
            of CSV header fields minus one.
        n_samples: Number of synthetic rows to generate. The default of 1000
            matches the previously hard-coded value.

    Returns:
        The generated rows mapped back through
        ``api.scaler.inverse_transform``.
    """
    api = nflow(dim=dim, latent=16, dataset=uploaded_file)
    api.compile(optim=torch.optim.ASGD, bw=bw, lr=0.0001, wd=wd)

    my_bar = st.progress(0)
    for idx in api.train(iters=iters):
        # Float progress must lie in [0.0, 1.0]. Clamp in case the trainer
        # reports a step count at or past `iters`.
        my_bar.progress(min(idx[0] / iters, 1.0))
    my_bar.progress(100)

    # Run the model once and reuse the result. Previously it was called twice
    # on the same input. That wastes work, and if sampling is stochastic, the
    # argmin indices came from a different draw than the deleted-from array.
    real = torch.tensor(api.scaled).float()
    mapped = np.array(api.model.sample(real).detach())
    # Drop the row(s) holding each column's minimum (presumably to trim
    # outliers from the plot).
    samples = np.delete(mapped, np.argmin(mapped, axis=0), 0)

    g = sns.jointplot(
        x=samples[:, 0],
        y=samples[:, 1],
        kind='kde',
        cmap=sns.color_palette("Blues", as_cmap=True),
        fill=True,
        label='Gaussian KDE',
        levels=1000,
    )
    w = sns.scatterplot(
        x=api.scaled[:, 0],
        y=api.scaled[:, 1],
        ax=g.ax_joint,
        c='orange',
        marker='+',
        s=100,
        label='Real',
    )
    st.pyplot(w.get_figure())

    def random_normal_samples(n, dim=3):
        # Standard-normal draws of shape (n, dim).
        return torch.zeros(n, dim).normal_(mean=0, std=1)

    # random_normal_samples already returns a tensor; wrapping it in
    # torch.tensor() again only made a copy and triggered a UserWarning.
    latent = random_normal_samples(n_samples, api.scaled.shape[-1]).float()
    generated = np.array(api.model.sample(latent).detach())
    return api.scaler.inverse_transform(generated)
# Token-gated generation: the user submits an access token. If it is listed in
# st.secrets['tokens'] and a dataset has been uploaded, a flow is trained and
# the generated rows are offered as a CSV download.
with st.form('login_form'):
    # `label` is a required argument of st.text_input; calling it with no
    # arguments raises TypeError. type='password' keeps the token masked.
    token = st.text_input('Token for generation:', type='password')
    submit = st.form_submit_button('Submit')

# Only judge the token after the form has been submitted, so the access-denied
# message is not shown on the initial page load.
if submit:
    if token and token in st.secrets['tokens']:
        if uploaded_file is not None:
            # Number of header fields minus one (assumes a leading index
            # column — TODO confirm against the expected dataset format).
            header = uploaded_file.getvalue().decode("utf-8").split('\n')[0]
            dims = len(header.split(',')) - 1
            samples = compute(dims)
            # The third positional parameter of st.download_button is
            # file_name, not mime, so pass both explicitly by keyword.
            st.download_button(
                'Download generated CSV',
                pd.DataFrame(samples).to_csv(),
                file_name='generated.csv',
                mime='text/csv',
            )
        else:
            st.write('Upload your file')
    else:
        st.markdown('## :red[You dont have access]')