# normflows / app.py
# Author: apsys · commit e53edb8 ("added files") · 1.27 kB
import streamlit as st
import torch
from normflows import nflow
import numpy as np
import seaborn as sns
import pandas as pd
# Page inputs: the dataset file to model, and the "Scale" value that
# compute() forwards to the flow as ``bw``.
uploaded_file = st.file_uploader(label="Choose original dataset")
bw = st.number_input(label='Scale', value=3.05)
def compute(dataset=None, scale=None, *, iters=10000, n_samples=1000):
    """Train a normalizing flow on the uploaded dataset and generate new rows.

    As a side effect, renders a joint KDE of flow samples overlaid with the
    scaled real data onto the Streamlit page.

    Args:
        dataset: Object passed to ``nflow(dataset=...)``. Defaults to the
            module-level ``uploaded_file`` widget value.
        scale: Value forwarded as ``bw`` to ``api.compile``. Defaults to the
            module-level ``bw`` widget value.
        iters: Number of training iterations passed to ``api.train``.
        n_samples: Number of standard-normal noise rows pushed through the
            flow to produce the returned samples.

    Returns:
        The generated samples after ``api.scaler.inverse_transform``
        (presumably shape ``(n_samples, n_features)`` -- depends on
        ``api.model.sample``; confirm).
    """
    if dataset is None:
        dataset = uploaded_file
    if scale is None:
        scale = bw

    api = nflow(dim=8, latent=16, dataset=dataset)
    api.compile(optim=torch.optim.ASGD, bw=scale, lr=0.0001, wd=None)
    api.train(iters=iters)

    def _to_numpy(tensor):
        # Detach from autograd and move to host memory; .cpu() is a no-op for
        # CPU tensors but np.array() on a GPU tensor would raise.
        return tensor.detach().cpu().numpy()

    # Diagnostic plot: samples obtained by feeding the scaled training data to
    # model.sample, overlaid with the real (scaled) points.
    # NOTE(review): confirm feeding api.scaled into model.sample is intended.
    real = torch.tensor(api.scaled, dtype=torch.float32)
    plotted = _to_numpy(api.model.sample(real))
    grid = sns.jointplot(
        x=plotted[:, 0],
        y=plotted[:, 1],
        kind='kde',
        cmap=sns.color_palette("Blues", as_cmap=True),
        fill=True,
        label='Gaussian KDE',
        levels=50,
    )
    ax = sns.scatterplot(
        x=api.scaled[:, 0],
        y=api.scaled[:, 1],
        ax=grid.ax_joint,
        c='orange',
        marker='+',
        s=100,
        label='Real',
    )
    # NOTE(review): the figure is never closed, so each Streamlit rerun leaves
    # one more figure registered with matplotlib.
    st.pyplot(ax.get_figure())

    # Generate new rows by pushing standard-normal noise through the flow.
    # torch.randn replaces torch.tensor(<tensor>), which warned on copy.
    noise = torch.randn(n_samples, api.scaled.shape[-1], dtype=torch.float32)
    generated = _to_numpy(api.model.sample(noise))
    return api.scaler.inverse_transform(generated)
# Only train and offer a download once a dataset has been uploaded.
if uploaded_file is not None:
    samples = compute()
    # st.download_button's third positional parameter is file_name, not mime;
    # passing 'text/csv' positionally named the downloaded file "text/csv".
    st.download_button(
        'Download generated CSV',
        data=pd.DataFrame(samples).to_csv(),
        file_name='generated.csv',
        mime='text/csv',
    )