Jonas Becker commited on
Commit
46cd050
1 Parent(s): 7f19394

caching model generation

Browse files
Files changed (1) hide show
  1. app.py +22 -27
app.py CHANGED
@@ -8,32 +8,33 @@ import transforms as trans
8
 
9
  P_MODEL = "model/drilling_ds_btcvae"
10
 
11
- # Decode Funktion --------------------------------------------------
12
- sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
13
- vae = disvae.load_model(P_MODEL)
14
- scaler = trans.MinMaxScaler(_min=torch.tensor([1.3]),_max=torch.tensor([4.0]),min_norm=0.3,max_norm=0.6)
15
- imaging = trans.SumField()
16
-
17
- _dec = trans.sequential_function(
18
- sorter.inv,
19
- vae.decoder,
20
- scaler.inv
21
- )
22
-
23
- def decode(latent):
24
- with torch.no_grad():
25
- return trans.np_sample(_dec)(latent)
 
 
 
 
26
 
27
- img2ts = trans.np_sample(imaging.inv)
28
 
29
- # GUI -----------------------------------------------------------
30
 
31
  latent_vector = np.array([st.slider(f"L{l}",min_value=-3.0,max_value=3.0,value=0.0) for l in range(3)])
32
  latent_vector = np.concatenate([latent_vector,np.zeros(7)],axis=0)
33
 
34
- value = decode(latent_vector)
35
-
36
- ts = img2ts(value)
37
 
38
  df = pd.DataFrame({
39
  "x":np.arange(len(ts)),
@@ -41,10 +42,4 @@ df = pd.DataFrame({
41
  }
42
  )
43
 
44
- st.line_chart(df,x="x",y="y")
45
- st.write(ts)
46
- # st.write(value)
47
- # st.image(value, use_column_width="always")
48
-
49
- # x = st.slider("Select a value")
50
- # st.write(x, "squared is", x * x)
 
8
 
9
  P_MODEL = "model/drilling_ds_btcvae"
10
 
11
+ @st.cache_resource
12
+ def load_decode_function():
13
+ sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
14
+ vae = disvae.load_model(P_MODEL)
15
+ scaler = trans.MinMaxScaler(_min=torch.tensor([1.3]),_max=torch.tensor([4.0]),min_norm=0.3,max_norm=0.6)
16
+ imaging = trans.SumField()
17
+
18
+ _dec = trans.sequential_function(
19
+ sorter.inv,
20
+ vae.decoder,
21
+ scaler.inv,
22
+ imaging.inv
23
+ )
24
+
25
+ def decode(latent):
26
+ with torch.no_grad():
27
+ return trans.np_sample(_dec)(latent)
28
+
29
+ return decode
30
 
31
+ decode = load_decode_function()
32
 
 
33
 
34
  latent_vector = np.array([st.slider(f"L{l}",min_value=-3.0,max_value=3.0,value=0.0) for l in range(3)])
35
  latent_vector = np.concatenate([latent_vector,np.zeros(7)],axis=0)
36
 
37
+ ts = decode(latent_vector)
 
 
38
 
39
  df = pd.DataFrame({
40
  "x":np.arange(len(ts)),
 
42
  }
43
  )
44
 
45
+ st.line_chart(df,x="x",y="y")