yoyolicoris commited on
Commit
4059958
·
1 Parent(s): 8a83a8e

add preset selection dropdown, fix incorrect x2z and json not updated

Browse files
Files changed (1) hide show
  1. app.py +48 -14
app.py CHANGED
@@ -47,6 +47,8 @@ CONFIG_PATH = "presets/rt_config.yaml"
47
  PCA_PARAM_FILE = "presets/internal/gaussian.npz"
48
  INFO_PATH = "presets/internal/info.json"
49
  MASK_PATH = "presets/internal/feature_mask.npy"
 
 
50
 
51
 
52
  with open(CONFIG_PATH) as fp:
@@ -56,16 +58,21 @@ with open(CONFIG_PATH) as fp:
56
  global_fx = instantiate(fx_config)
57
  global_fx.eval()
58
 
 
 
 
 
 
59
  pca_params = np.load(PCA_PARAM_FILE)
60
  mean = pca_params["mean"]
61
  cov = pca_params["cov"]
62
  eigvals, eigvecs = np.linalg.eigh(cov)
63
  eigvals = np.flip(eigvals, axis=0)
64
  eigvecs = np.flip(eigvecs, axis=1)
65
- U = eigvecs * np.sqrt(eigvals)
66
- U = torch.from_numpy(U).float()
67
  mean = torch.from_numpy(mean).float()
68
- feature_mask = torch.from_numpy(np.load(MASK_PATH))
69
  # Global latent variable
70
  # z = torch.zeros_like(mean)
71
 
@@ -104,7 +111,7 @@ meter = pyln.Meter(44100)
104
  def z2x(z):
105
  # close all figures to avoid too many open figures
106
  plt.close("all")
107
- x = U @ z + mean
108
  # # print(z)
109
  # fx.load_state_dict(vec2dict(x), strict=False)
110
  # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
@@ -123,7 +130,7 @@ def fx2x(fx):
123
  @torch.no_grad()
124
  def x2z(x):
125
  z = U.T @ (x - mean)
126
- return z
127
 
128
 
129
  @torch.no_grad()
@@ -393,13 +400,23 @@ with gr.Blocks() as demo:
393
 
394
  sliders = [s1, s2, s3, s4]
395
 
396
- extra_pc_dropdown = gr.Dropdown(
397
- list(range(NUMBER_OF_PCS + 1, mean.numel() + 1)),
398
- label=f"PC > {NUMBER_OF_PCS}",
399
- info="Select which extra PC to adjust",
400
- interactive=True,
401
- )
402
- extra_slider = default_pc_slider(label="Extra PC")
 
 
 
 
 
 
 
 
 
 
403
 
404
  with gr.Column():
405
  audio_output = default_audio_block(label="Output Audio", interactive=False)
@@ -853,8 +870,15 @@ with gr.Blocks() as demo:
853
  t60_plot,
854
  ]
855
 
856
- update_all = lambda z, fx, i: update_pc(z, i) + update_fx(fx) + update_plots(fx)
857
- update_all_outputs = update_pc_outputs + update_fx_outputs + update_plots_outputs
 
 
 
 
 
 
 
858
 
859
  random_button.click(
860
  chain_functions(
@@ -912,4 +936,14 @@ with gr.Blocks() as demo:
912
  outputs=extra_slider,
913
  )
914
 
 
 
 
 
 
 
 
 
 
 
915
  demo.launch()
 
47
  PCA_PARAM_FILE = "presets/internal/gaussian.npz"
48
  INFO_PATH = "presets/internal/info.json"
49
  MASK_PATH = "presets/internal/feature_mask.npy"
50
+ PRESET_PATH = "presets/internal/raw_params.npy"
51
+ TRAIN_INDEX_PATH = "presets/internal/train_index.npy"
52
 
53
 
54
  with open(CONFIG_PATH) as fp:
 
58
  global_fx = instantiate(fx_config)
59
  global_fx.eval()
60
 
61
+ raw_params = torch.from_numpy(np.load(PRESET_PATH))
62
+ train_index = torch.from_numpy(np.load(TRAIN_INDEX_PATH))
63
+ feature_mask = torch.from_numpy(np.load(MASK_PATH))
64
+ presets = raw_params[train_index][:, feature_mask].contiguous()
65
+
66
  pca_params = np.load(PCA_PARAM_FILE)
67
  mean = pca_params["mean"]
68
  cov = pca_params["cov"]
69
  eigvals, eigvecs = np.linalg.eigh(cov)
70
  eigvals = np.flip(eigvals, axis=0)
71
  eigvecs = np.flip(eigvecs, axis=1)
72
+ eigsqrt = torch.from_numpy(eigvals.copy()).float().sqrt()
73
+ U = torch.from_numpy(eigvecs.copy()).float()
74
  mean = torch.from_numpy(mean).float()
75
+
76
  # Global latent variable
77
  # z = torch.zeros_like(mean)
78
 
 
111
  def z2x(z):
112
  # close all figures to avoid too many open figures
113
  plt.close("all")
114
+ x = U @ (z * eigsqrt) + mean
115
  # # print(z)
116
  # fx.load_state_dict(vec2dict(x), strict=False)
117
  # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
 
130
  @torch.no_grad()
131
  def x2z(x):
132
  z = U.T @ (x - mean)
133
+ return z / eigsqrt
134
 
135
 
136
  @torch.no_grad()
 
400
 
401
  sliders = [s1, s2, s3, s4]
402
 
403
+ with gr.Row():
404
+ with gr.Column():
405
+ extra_pc_dropdown = gr.Dropdown(
406
+ list(range(NUMBER_OF_PCS + 1, mean.numel() + 1)),
407
+ label=f"PC > {NUMBER_OF_PCS}",
408
+ info="Select which extra PC to adjust",
409
+ interactive=True,
410
+ )
411
+ extra_slider = default_pc_slider(label="Extra PC")
412
+
413
+ preset_dropdown = gr.Dropdown(
414
+ ["none"] + list(range(1, presets.shape[0] + 1)),
415
+ value="none",
416
+ label=f"Select Preset (1-{presets.shape[0]})",
417
+ info="Select a preset to load (this will override the current settings)",
418
+ interactive=True,
419
+ )
420
 
421
  with gr.Column():
422
  audio_output = default_audio_block(label="Output Audio", interactive=False)
 
870
  t60_plot,
871
  ]
872
 
873
+ update_all = (
874
+ lambda z, fx, i: update_pc(z, i)
875
+ + update_fx(fx)
876
+ + update_plots(fx)
877
+ + [model2json(fx)]
878
+ )
879
+ update_all_outputs = (
880
+ update_pc_outputs + update_fx_outputs + update_plots_outputs + [json_output]
881
+ )
882
 
883
  random_button.click(
884
  chain_functions(
 
936
  outputs=extra_slider,
937
  )
938
 
939
+ preset_dropdown.input(
940
+ chain_functions(
941
+ lambda i, _: (mean if i == "none" else presets[i - 1], _),
942
+ lambda x, i: (x2z(x), x, vec2fx(x), i),
943
+ lambda z, x, fx, i: [z, x] + update_all(z, fx, i),
944
+ ),
945
+ inputs=[preset_dropdown, extra_pc_dropdown],
946
+ outputs=[z, fx_params] + update_all_outputs,
947
+ )
948
+
949
  demo.launch()