"""Streamlit app: decode a user-chosen latent vector and plot the time series.

The first few latent dimensions are exposed as sliders, the remaining ones are
held at zero, and the decoded drilling signal is plotted against time.
The model directory name suggests a beta-TC-VAE (TODO confirm).
"""
import numpy as np
import streamlit as st
import torch
import matplotlib.pyplot as plt

import disvae
import transforms as trans

P_MODEL = "model/drilling_ds_btcvae"
SAMPLING_TIME = 0.15  # time between consecutive decoded samples, in seconds
N_LATENT_DIMS = 10  # total latent vector length passed to the decoder
N_SLIDER_DIMS = 3  # leading latent dimensions controlled by sliders; rest stay 0
SLIDER_RANGE = 3.0  # sliders span [-SLIDER_RANGE, SLIDER_RANGE]

# Must be the first Streamlit command executed by the script.
st.set_page_config(page_title="Drilling VAE")


@st.cache_resource
def load_decode_function():
    """Build the latent-vector -> time-series decoding pipeline.

    Cached with ``st.cache_resource`` so the model and transforms are loaded
    once and reused across reruns instead of on every widget interaction.

    The pipeline is composed via ``trans.sequential_function`` from
    ``sorter.inv``, ``vae.decoder``, ``scaler.inv`` and ``imaging.inv``, in
    that order (presumably applied left to right — TODO confirm in
    ``transforms``).

    Returns:
        A callable ``decode(latent)`` that runs the pipeline through
        ``trans.np_sample`` under ``torch.no_grad()`` (no autograd graph is
        built, since this is inference only).
    """
    sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
    vae = disvae.load_model(P_MODEL)
    # NOTE(review): these bounds presumably have to match the ones used when
    # the model was trained — confirm before changing them.
    scaler = trans.MinMaxScaler(
        _min=torch.tensor([1.3]),
        _max=torch.tensor([4.0]),
        min_norm=0.3,
        max_norm=0.6,
    )
    imaging = trans.SumField()
    _dec = trans.sequential_function(
        sorter.inv, vae.decoder, scaler.inv, imaging.inv
    )

    def decode(latent):
        """Decode one latent vector without tracking gradients."""
        with torch.no_grad():
            return trans.np_sample(_dec)(latent)

    return decode


decode = load_decode_function()

col1, col2 = st.columns(2)

with col1:
    st.markdown("**Latent Space Parameters**")
    # Only the first N_SLIDER_DIMS dimensions are user-controlled; the
    # remaining dimensions are fixed at zero.
    latent_vector = np.zeros(N_LATENT_DIMS)
    latent_vector[:N_SLIDER_DIMS] = [
        st.slider(
            f"Latent Dimension {dim}",
            min_value=-SLIDER_RANGE,
            max_value=SLIDER_RANGE,
            value=0.0,
        )
        for dim in range(N_SLIDER_DIMS)
    ]

# Flatten first so the time axis is built from the same length that is
# plotted, whatever shape the decoder returns.
ts = decode(latent_vector).ravel()
# Integer arange scaled by the step always yields exactly len(ts) points;
# a float-step arange can produce one extra point from rounding error.
t_axis = np.arange(len(ts)) * SAMPLING_TIME

with col2:
    st.markdown("**Generated Time Series**")
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(t_axis, ts)
    ax.set_xlabel("Time t [s]")
    # NOTE(review): the "t" in this label looks copied from the x-axis
    # symbol — confirm the intended torque symbol.
    ax.set_ylabel("Spindle torque t [Nm]")
    # Upper limit is the same value as the scaler's _max (4.0) above.
    ax.set_ylim([0, 4])
    ax.grid()
    st.pyplot(fig)
    # st.pyplot has already rendered the figure; close it so pyplot's global
    # figure registry does not grow by one figure on every rerun.
    plt.close(fig)