yoyolicoris committed
Commit 0529094 · 1 Parent(s): e3bff8a

feat: add plotting functionality for PEQ frequency response

Files changed (1)
  1. app.py +91 -29
app.py CHANGED
@@ -1,16 +1,17 @@
 import gradio as gr
 import numpy as np
+import matplotlib.pyplot as plt
 import torch
 import yaml
 import json
 import pyloudnorm as pyln
 from hydra.utils import instantiate
-from random import normalvariate
 from soxr import resample
 from functools import partial
 
 from modules.utils import chain_functions, vec2statedict, get_chunks
 from modules.fx import clip_delay_eq_Q
+from plot_utils import get_log_mags_from_eq
 
 
 title_md = "# Vocal Effects Generator"
@@ -41,11 +42,13 @@ TEMPERATURE = 0.7
 CONFIG_PATH = "presets/rt_config.yaml"
 PCA_PARAM_FILE = "presets/internal/gaussian.npz"
 INFO_PATH = "presets/internal/info.json"
+MASK_PATH = "presets/internal/feature_mask.npy"
 
 
 with open(CONFIG_PATH) as fp:
     fx_config = yaml.safe_load(fp)["model"]
 
+# Global effect
 fx = instantiate(fx_config)
 fx.eval()
 
@@ -58,6 +61,8 @@ eigvecs = np.flip(eigvecs, axis=1)[:, :75]
 U = eigvecs * np.sqrt(eigvals)
 U = torch.from_numpy(U).float()
 mean = torch.from_numpy(mean).float()
+feature_mask = torch.from_numpy(np.load(MASK_PATH))
+# Global latent variable
 z = torch.zeros(75)
 
 with open(INFO_PATH) as f:
@@ -85,11 +90,35 @@ vec2dict = partial(
         )
     ),
 )
+fx.load_state_dict(vec2dict(mean), strict=False)
 
 
 meter = pyln.Meter(44100)
 
 
+@torch.no_grad()
+def z2fx():
+    # close all figures to avoid too many open figures
+    plt.close("all")
+    x = U @ z + mean
+    # print(z)
+    fx.load_state_dict(vec2dict(x), strict=False)
+    return
+
+
+def fx2z(func):
+    @torch.no_grad()
+    def wrapper(*args, **kwargs):
+        ret = func(*args, **kwargs)
+        state_dict = fx.state_dict()
+        flattened = torch.cat([state_dict[k].flatten() for k in param_keys])
+        x = flattened[feature_mask]
+        z.copy_(U.T @ (x - mean))
+        return ret
+
+    return wrapper
+
+
 @torch.no_grad()
 def inference(audio):
     sr, y = audio
@@ -107,21 +136,6 @@ def inference(audio):
     if y.shape[1] != 1:
         y = y.mean(dim=1, keepdim=True)
 
-    # M = eigvals.shape[0]
-    # z = torch.cat(
-    #     [
-    #         torch.tensor([float(x) for x in pcs]),
-    #         (
-    #             torch.randn(M - len(pcs)) * TEMPERATURE
-    #             if randomise_rest
-    #             else torch.zeros(M - len(pcs))
-    #         ),
-    #     ]
-    # )
-    x = U @ z + mean
-    # print(z)
-
-    fx.load_state_dict(vec2dict(x), strict=False)
     fx.apply(partial(clip_delay_eq_Q, Q=0.707))
 
     rendered = fx(y).squeeze(0).T.numpy()
@@ -161,6 +175,23 @@ def model2json():
     )
 
 
+@torch.no_grad()
+def plot_eq():
+    fig, ax = plt.subplots(figsize=(8, 4))
+    w, eq_log_mags = get_log_mags_from_eq(fx[:6])
+    ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
+    for i, eq_log_mag in enumerate(eq_log_mags):
+        ax.plot(w, eq_log_mag, "k-", alpha=0.3)
+        ax.fill_between(w, eq_log_mag, 0, facecolor="gray", edgecolor="none", alpha=0.1)
+    ax.set_xlabel("Frequency (Hz)")
+    ax.set_ylabel("Magnitude (dB)")
+    ax.set_xlim(20, 20000)
+    ax.set_ylim(-40, 20)
+    ax.set_xscale("log")
+    ax.grid()
+    return fig
+
+
 with gr.Blocks() as demo:
     gr.Markdown(
         title_md,
@@ -214,17 +245,23 @@ with gr.Blocks() as demo:
     audio_output = gr.Audio(
         type="numpy", label="Output Audio", interactive=False, loop=True
     )
-    json_output = gr.JSON(label="Effect Settings", max_height=800, open=True)
+
+    peq_plot = gr.Plot(
+        plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot"
+    )
+
+    with gr.Row():
+        json_output = gr.JSON(label="Effect Settings", max_height=800, open=True)
 
     render_button.click(
-        lambda *args: (lambda x: (x, model2json()))(inference(*args)),
+        lambda *args: (lambda x: (x, model2json(), plot_eq()))(inference(*args)),
         inputs=[
             audio_input,
             # random_rest_checkbox,
         ]
         # + sliders,
         ,
-        outputs=[audio_output, json_output],
+        outputs=[audio_output, json_output, peq_plot],
     )
 
     random_button.click(
@@ -235,16 +272,27 @@ with gr.Blocks() as demo:
         # )(normalvariate(0, 1))
        # for _ in range(len(xs))
        # ],
-        lambda i: (lambda x: x[:NUMBER_OF_PCS].tolist() + [x[i - 1].item()])(
-            z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX)
+        # lambda i: (lambda x: x[:NUMBER_OF_PCS].tolist() + [x[i - 1].item()])(
+        #     z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX)
+        # ),
+        chain_functions(
+            lambda i: (z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX), i),
+            lambda args: args + (z2fx(),),
+            lambda args: args[0][:NUMBER_OF_PCS].tolist()
+            + [args[0][args[1] - 1].item(), plot_eq()],
        ),
        inputs=extra_pc_dropdown,
-        outputs=sliders + [extra_slider],
+        outputs=sliders + [extra_slider, peq_plot],
    )
    reset_button.click(
-        lambda *xs: (lambda _: [0 for _ in range(len(xs))])(z.zero_()),
-        inputs=sliders + [extra_slider],
-        outputs=sliders + [extra_slider],
+        # lambda: (lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)])(z.zero_()),
+        lambda: chain_functions(
+            lambda _: z.zero_(),
+            lambda _: z2fx(),
+            lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)] + [plot_eq()],
+        )(None),
+        # inputs=sliders + [extra_slider],
+        outputs=sliders + [extra_slider, peq_plot],
    )
 
    def update_z(s, i):
@@ -252,12 +300,26 @@ with gr.Blocks() as demo:
         return
 
     for i, slider in enumerate(sliders):
-        slider.change(partial(update_z, i=i), inputs=slider)
-    extra_slider.change(
-        lambda _, i: update_z(_, i - 1), inputs=[extra_slider, extra_pc_dropdown]
+        slider.input(
+            chain_functions(
+                partial(update_z, i=i),
+                lambda _: z2fx(),
+                lambda _: plot_eq(),
+            ),
+            inputs=slider,
+            outputs=peq_plot,
+        )
+    extra_slider.input(
+        lambda *xs: chain_functions(
+            lambda args: update_z(args[0], args[1] - 1),
+            lambda _: z2fx(),
+            lambda _: plot_eq(),
+        )(xs),
+        inputs=[extra_slider, extra_pc_dropdown],
+        outputs=peq_plot,
     )
 
-    extra_pc_dropdown.change(
+    extra_pc_dropdown.input(
         lambda i: z[i - 1].item(),
         inputs=extra_pc_dropdown,
         outputs=extra_slider,
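
Note on the imported helper: get_log_mags_from_eq lives in plot_utils and is not touched by this commit. Judging from its use in plot_eq, it takes the first six stages of the effect chain (the PEQ bands) and returns a shared frequency axis in Hz plus one log-magnitude curve per band, in dB, so that summing the curves gives the cascade response. A minimal sketch under those assumptions (the get_biquad_coeffs accessor is hypothetical; the real modules.fx interface is not shown here):

import numpy as np
from scipy.signal import freqz


def get_log_mags_from_eq(eq_bands, sr=44100, n_points=1024):
    # Log-spaced axis matching the 20 Hz - 20 kHz range that plot_eq displays.
    w = np.logspace(np.log10(20.0), np.log10(20000.0), n_points)
    eq_log_mags = []
    for band in eq_bands:
        # Hypothetical accessor: assumes each PEQ band can export its
        # biquad coefficients (b, a) at the working sample rate.
        b, a = band.get_biquad_coeffs()
        _, h = freqz(b, a, worN=w, fs=sr)
        # Magnitude in dB; summing these per-band curves yields the
        # cascade response that plot_eq draws in black.
        eq_log_mags.append(20.0 * np.log10(np.abs(h) + 1e-12))
    return w, eq_log_mags

Returning per-band dB curves rather than complex responses is what lets plot_eq shade each band with fill_between and overlay sum(eq_log_mags) as the overall curve. Similarly, the callbacks above rely on chain_functions (from modules.utils) composing its arguments left to right, i.e. chain_functions(f, g)(x) == g(f(x)), threading update_z -> z2fx -> plot_eq.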