Jonas Becker commited on
Commit
bc73fb3
1 Parent(s): f67aafa

Updated Layout

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -7,9 +7,13 @@ import disvae
7
  import transforms as trans
8
 
9
  P_MODEL = "model/drilling_ds_btcvae"
 
 
 
10
 
11
  @st.cache_resource
12
  def load_decode_function():
 
13
  sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
14
  vae = disvae.load_model(P_MODEL)
15
  scaler = trans.MinMaxScaler(_min=torch.tensor([1.3]),_max=torch.tensor([4.0]),min_norm=0.3,max_norm=0.6)
@@ -30,17 +34,25 @@ def load_decode_function():
30
 
31
  decode = load_decode_function()
32
 
33
- st.markdown("**Latent Space Parameters**")
34
- latent_vector = np.array([st.slider(f"Latent Dimension {l}",min_value=-3.0,max_value=3.0,value=0.0) for l in range(3)])
35
- latent_vector = np.concatenate([latent_vector,np.zeros(7)],axis=0)
 
 
 
36
 
37
- ts = decode(latent_vector)
38
 
39
- st.markdown("**Generated Time Series**")
 
40
 
41
- fig, ax = plt.subplots(figsize=(8,4))
42
 
43
- ax.plot(ts.ravel())
44
- ax.set_ylim([0,4])
 
 
 
 
45
 
46
- st.pyplot(fig)
 
7
  import transforms as trans
8
 
9
  P_MODEL = "model/drilling_ds_btcvae"
10
+ SAMPLING_TIME = 0.15
11
+
12
+ st.set_page_config(page_title="Drilling VAE")
13
 
14
  @st.cache_resource
15
  def load_decode_function():
16
+
17
  sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
18
  vae = disvae.load_model(P_MODEL)
19
  scaler = trans.MinMaxScaler(_min=torch.tensor([1.3]),_max=torch.tensor([4.0]),min_norm=0.3,max_norm=0.6)
 
34
 
35
  decode = load_decode_function()
36
 
37
+ col1,col2 = st.columns(2)
38
+
39
+ with col1:
40
+ st.markdown("**Latent Space Parameters**")
41
+ latent_vector = np.array([st.slider(f"Latent Dimension {l}",min_value=-3.0,max_value=3.0,value=0.0) for l in range(3)])
42
+ latent_vector = np.concatenate([latent_vector,np.zeros(7)],axis=0)
43
 
44
+ ts = decode(latent_vector)
45
 
46
+ with col2:
47
+ st.markdown("**Generated Time Series**")
48
 
49
+ fig, ax = plt.subplots(figsize=(4,3))
50
 
51
+ time = np.arange(0,len(ts)*SAMPLING_TIME,SAMPLING_TIME)
52
+ ax.plot(time,ts.ravel())
53
+ ax.set_xlabel("Time t [s]")
54
+ ax.set_ylabel("Spindle torque t [Nm]")
55
+ ax.set_ylim([0,4])
56
+ ax.grid()
57
 
58
+ st.pyplot(fig)