yoyolicoris committed on
Commit
c1b7223
·
1 Parent(s): 1291f86

feat: adjustable PEQ

Browse files
Files changed (1) hide show
  1. app.py +266 -24
app.py CHANGED
@@ -37,7 +37,7 @@ For example:
37
 
38
  SLIDER_MAX = 3
39
  SLIDER_MIN = -3
40
- NUMBER_OF_PCS = 10
41
  TEMPERATURE = 0.7
42
  CONFIG_PATH = "presets/rt_config.yaml"
43
  PCA_PARAM_FILE = "presets/internal/gaussian.npz"
@@ -56,14 +56,14 @@ pca_params = np.load(PCA_PARAM_FILE)
56
  mean = pca_params["mean"]
57
  cov = pca_params["cov"]
58
  eigvals, eigvecs = np.linalg.eigh(cov)
59
- eigvals = np.flip(eigvals, axis=0)[:75]
60
- eigvecs = np.flip(eigvecs, axis=1)[:, :75]
61
  U = eigvecs * np.sqrt(eigvals)
62
  U = torch.from_numpy(U).float()
63
  mean = torch.from_numpy(mean).float()
64
  feature_mask = torch.from_numpy(np.load(MASK_PATH))
65
  # Global latent variable
66
- z = torch.zeros(75)
67
 
68
  with open(INFO_PATH) as f:
69
  info = json.load(f)
@@ -107,17 +107,13 @@ def z2fx():
107
  return
108
 
109
 
110
- def fx2z(func):
111
- @torch.no_grad()
112
- def wrapper(*args, **kwargs):
113
- ret = func(*args, **kwargs)
114
- state_dict = fx.state_dict()
115
- flattened = torch.cat([state_dict[k].flatten() for k in param_keys])
116
- x = flattened[feature_mask]
117
- z.copy_(U.T @ (x - mean))
118
- return ret
119
-
120
- return wrapper
121
 
122
 
123
  @torch.no_grad()
@@ -166,12 +162,15 @@ def model2json():
166
  },
167
  "Cross Send (dB)": fx[7].params.sends_0.log10().mul(20).item(),
168
  }
169
- return json.dumps(
170
- {
171
- "Direct": results,
172
- "Sends": spatial_fx,
173
- }
174
  )
 
 
 
 
175
 
176
 
177
  @torch.no_grad()
@@ -283,7 +282,7 @@ def plot_t60():
283
  gamma = fdn.params.gamma.squeeze().numpy()
284
  delays = fdn.delays.numpy()
285
  w = np.linspace(0, 22050, gamma.size)
286
- t60 = -60 / (20 * np.log10(gamma) / np.min(delays)) / 44100
287
  ax.plot(w, t60, color="black", linestyle="-")
288
  ax.set_xlabel("Frequency (Hz)")
289
  ax.set_ylabel("T60 (s)")
@@ -294,6 +293,15 @@ def plot_t60():
294
  return fig
295
 
296
 
 
 
 
 
 
 
 
 
 
297
  with gr.Blocks() as demo:
298
  gr.Markdown(
299
  title_md,
@@ -328,10 +336,43 @@ with gr.Blocks() as demo:
328
  # value=False,
329
  # elem_id="randomise-checkbox",
330
  # )
331
- sliders = get_important_pcs(NUMBER_OF_PCS, value=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
  extra_pc_dropdown = gr.Dropdown(
334
- list(range(NUMBER_OF_PCS + 1, 76)),
335
  label=f"PC > {NUMBER_OF_PCS}",
336
  info="Select which extra PC to adjust",
337
  interactive=True,
@@ -348,9 +389,85 @@ with gr.Blocks() as demo:
348
  type="numpy", label="Output Audio", interactive=False, loop=True
349
  )
350
 
 
351
  peq_plot = gr.Plot(
352
  plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot"
353
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  comp_plot = gr.Plot(
355
  plot_comp(), label="Compressor Curve", elem_id="comp-plot"
356
  )
@@ -367,6 +484,36 @@ with gr.Blocks() as demo:
367
  model2json(), label="Effect Settings", max_height=800, open=True
368
  )
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  render_button.click(
371
  lambda *args: (
372
  lambda x: (
@@ -419,6 +566,18 @@ with gr.Blocks() as demo:
419
  plot_delay(),
420
  plot_reverb(),
421
  plot_t60(),
 
 
 
 
 
 
 
 
 
 
 
 
422
  ],
423
  ),
424
  inputs=extra_pc_dropdown,
@@ -431,6 +590,18 @@ with gr.Blocks() as demo:
431
  delay_plot,
432
  reverb_plot,
433
  t60_plot,
 
 
 
 
 
 
 
 
 
 
 
 
434
  ],
435
  )
436
  reset_button.click(
@@ -446,6 +617,18 @@ with gr.Blocks() as demo:
446
  plot_delay(),
447
  plot_reverb(),
448
  plot_t60(),
 
 
 
 
 
 
 
 
 
 
 
 
449
  ],
450
  )(None),
451
  # inputs=sliders + [extra_slider],
@@ -458,6 +641,18 @@ with gr.Blocks() as demo:
458
  delay_plot,
459
  reverb_plot,
460
  t60_plot,
 
 
 
 
 
 
 
 
 
 
 
 
461
  ],
462
  )
463
 
@@ -477,6 +672,16 @@ with gr.Blocks() as demo:
477
  plot_delay(),
478
  plot_reverb(),
479
  plot_t60(),
 
 
 
 
 
 
 
 
 
 
480
  ),
481
  ),
482
  inputs=slider,
@@ -487,6 +692,16 @@ with gr.Blocks() as demo:
487
  delay_plot,
488
  reverb_plot,
489
  t60_plot,
 
 
 
 
 
 
 
 
 
 
490
  ],
491
  )
492
  extra_slider.input(
@@ -500,10 +715,37 @@ with gr.Blocks() as demo:
500
  plot_delay(),
501
  plot_reverb(),
502
  plot_t60(),
 
 
 
 
 
 
 
 
 
 
503
  ),
504
  )(xs),
505
  inputs=[extra_slider, extra_pc_dropdown],
506
- outputs=[json_output, peq_plot, comp_plot, delay_plot, reverb_plot, t60_plot],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  )
508
 
509
  extra_pc_dropdown.input(
 
37
 
38
  SLIDER_MAX = 3
39
  SLIDER_MIN = -3
40
+ NUMBER_OF_PCS = 4
41
  TEMPERATURE = 0.7
42
  CONFIG_PATH = "presets/rt_config.yaml"
43
  PCA_PARAM_FILE = "presets/internal/gaussian.npz"
 
56
  mean = pca_params["mean"]
57
  cov = pca_params["cov"]
58
  eigvals, eigvecs = np.linalg.eigh(cov)
59
+ eigvals = np.flip(eigvals, axis=0)
60
+ eigvecs = np.flip(eigvecs, axis=1)
61
  U = eigvecs * np.sqrt(eigvals)
62
  U = torch.from_numpy(U).float()
63
  mean = torch.from_numpy(mean).float()
64
  feature_mask = torch.from_numpy(np.load(MASK_PATH))
65
  # Global latent variable
66
+ z = torch.zeros_like(mean)
67
 
68
  with open(INFO_PATH) as f:
69
  info = json.load(f)
 
107
  return
108
 
109
 
110
@torch.no_grad()
def fx2z():
    """Refresh the global latent vector ``z`` from the current ``fx`` state.

    Reads ``fx.state_dict()``, flattens the entries named in ``param_keys``
    into one vector, keeps the positions selected by ``feature_mask``,
    subtracts ``mean`` and projects the result with ``U.T``. The projection
    is written into the module-level ``z`` in place; nothing is returned.
    """
    params = fx.state_dict()
    flat = torch.cat([params[key].flatten() for key in param_keys])
    centred = flat[feature_mask] - mean
    # In-place copy so every holder of the global ``z`` sees the new values.
    z.copy_(U.T @ centred)
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
 
162
  },
163
  "Cross Send (dB)": fx[7].params.sends_0.log10().mul(20).item(),
164
  }
165
+ replace_neg_inf = lambda d: (
166
+ {k: (replace_neg_inf(v) if v != -np.inf else -1e500) for k, v in d.items()}
167
+ if isinstance(d, dict)
168
+ else d
 
169
  )
170
+ return {
171
+ "Direct": results,
172
+ "Sends": spatial_fx,
173
+ }
174
 
175
 
176
  @torch.no_grad()
 
282
  gamma = fdn.params.gamma.squeeze().numpy()
283
  delays = fdn.delays.numpy()
284
  w = np.linspace(0, 22050, gamma.size)
285
+ t60 = -60 / (20 * np.log10(gamma + 1e-10) / np.min(delays)) / 44100
286
  ax.plot(w, t60, color="black", linestyle="-")
287
  ax.set_xlabel("Frequency (Hz)")
288
  ax.set_ylabel("T60 (s)")
 
293
  return fig
294
 
295
 
296
@torch.no_grad()
def upatePEQ(eq, attr_name, value):
    """Overwrite one parameter of an EQ filter with a new value.

    Args:
        eq: Filter object that exposes its parameters under ``eq.params``.
        attr_name: Name of the attribute on ``eq.params`` to overwrite
            (callers in this file pass "freq", "gain" or "Q").
        value: New value; a plain Python number (e.g. from a slider) or
            anything ``torch.as_tensor`` / ``torch.tensor`` accepts.

    Returns:
        None. ``eq.params`` is modified as a side effect.
    """
    current = getattr(eq.params, attr_name)
    if isinstance(current, torch.nn.Parameter):
        # Tensor.copy_ only accepts a tensor source, so wrap plain numbers
        # before copying; the existing Parameter object is updated in place.
        current.data.copy_(torch.as_tensor(value))
    else:
        # Any other attribute kind is replaced through setattr with a
        # freshly built tensor.
        setattr(eq.params, attr_name, torch.tensor(value))
303
+
304
+
305
  with gr.Blocks() as demo:
306
  gr.Markdown(
307
  title_md,
 
336
  # value=False,
337
  # elem_id="randomise-checkbox",
338
  # )
339
+ # sliders = get_important_pcs(NUMBER_OF_PCS, value=0)
340
+ with gr.Row():
341
+ s1 = gr.Slider(
342
+ minimum=SLIDER_MIN,
343
+ maximum=SLIDER_MAX,
344
+ label="PC 1",
345
+ value=0,
346
+ interactive=True,
347
+ )
348
+ s2 = gr.Slider(
349
+ minimum=SLIDER_MIN,
350
+ maximum=SLIDER_MAX,
351
+ label="PC 2",
352
+ value=0,
353
+ interactive=True,
354
+ )
355
+
356
+ with gr.Row():
357
+ s3 = gr.Slider(
358
+ minimum=SLIDER_MIN,
359
+ maximum=SLIDER_MAX,
360
+ label="PC 3",
361
+ value=0,
362
+ interactive=True,
363
+ )
364
+ s4 = gr.Slider(
365
+ minimum=SLIDER_MIN,
366
+ maximum=SLIDER_MAX,
367
+ label="PC 4",
368
+ value=0,
369
+ interactive=True,
370
+ )
371
+
372
+ sliders = [s1, s2, s3, s4]
373
 
374
  extra_pc_dropdown = gr.Dropdown(
375
# PC numbers are 1-based (handlers index z[pc - 1]), so the upper bound is numel() + 1 to include the last PC.
list(range(NUMBER_OF_PCS + 1, mean.numel() + 1)),
376
  label=f"PC > {NUMBER_OF_PCS}",
377
  info="Select which extra PC to adjust",
378
  interactive=True,
 
389
  type="numpy", label="Output Audio", interactive=False, loop=True
390
  )
391
 
392
+ _ = gr.Markdown("## Parametric EQ")
393
  peq_plot = gr.Plot(
394
  plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot"
395
  )
396
+ with gr.Row():
397
+ with gr.Column():
398
+ _ = gr.Markdown("Peak filter 1")
399
+ pk1 = fx[0]
400
+ pk1_freq = gr.Slider(
401
+ minimum=33,
402
+ maximum=5400,
403
+ value=pk1.params.freq.item(),
404
+ interactive=True,
405
+ )
406
+ pk1_gain = gr.Slider(
407
+ minimum=-24,
408
+ maximum=24,
409
+ value=pk1.params.gain.item(),
410
+ interactive=True,
411
+ )
412
+ pk1_q = gr.Slider(
413
+ minimum=0.2,
414
+ maximum=20,
415
+ value=pk1.params.Q.item(),
416
+ interactive=True,
417
+ )
418
+ with gr.Column():
419
+ _ = gr.Markdown("Peak filter 2")
420
+ pk2 = fx[1]
421
+ pk2_freq = gr.Slider(
422
+ minimum=200,
423
+ maximum=17500,
424
+ value=pk2.params.freq.item(),
425
+ interactive=True,
426
+ )
427
+ pk2_gain = gr.Slider(
428
+ minimum=-24,
429
+ maximum=24,
430
+ value=pk2.params.gain.item(),
431
+ interactive=True,
432
+ )
433
+ pk2_q = gr.Slider(
434
+ minimum=0.2,
435
+ maximum=20,
436
+ value=pk2.params.Q.item(),
437
+ interactive=True,
438
+ )
439
+ with gr.Row():
440
+ with gr.Column():
441
+ _ = gr.Markdown("Low Shelf")
442
+ ls = fx[2]
443
+ ls_freq = gr.Slider(
444
+ minimum=30,
445
+ maximum=200,
446
+ value=ls.params.freq.item(),
447
+ interactive=True,
448
+ )
449
+ ls_gain = gr.Slider(
450
+ minimum=-24,
451
+ maximum=24,
452
+ value=ls.params.gain.item(),
453
+ interactive=True,
454
+ )
455
+ with gr.Column():
456
+ _ = gr.Markdown("High Shelf")
457
+ hs = fx[3]
458
+ hs_freq = gr.Slider(
459
+ minimum=750,
460
+ maximum=8300,
461
+ value=hs.params.freq.item(),
462
+ interactive=True,
463
+ )
464
+ hs_gain = gr.Slider(
465
+ minimum=-24,
466
+ maximum=24,
467
+ value=hs.params.gain.item(),
468
+ interactive=True,
469
+ )
470
+
471
  comp_plot = gr.Plot(
472
  plot_comp(), label="Compressor Curve", elem_id="comp-plot"
473
  )
 
484
  model2json(), label="Effect Settings", max_height=800, open=True
485
  )
486
 
487
+ for eq, s, attr_name in zip(
488
+ [fx[0]] * 3 + [fx[1]] * 3 + [fx[2]] * 2 + [fx[3]] * 2,
489
+ [
490
+ pk1_freq,
491
+ pk1_gain,
492
+ pk1_q,
493
+ pk2_freq,
494
+ pk2_gain,
495
+ pk2_q,
496
+ ls_freq,
497
+ ls_gain,
498
+ hs_freq,
499
+ hs_gain,
500
+ ],
501
+ ["freq", "gain", "Q"] * 2 + ["freq", "gain"] * 2,
502
+ ):
503
+ s.input(
504
+ lambda *args, eq=eq, attr_name=attr_name: chain_functions( # chain_functions(
505
+ lambda args: (upatePEQ(eq, attr_name, args[0]), args[1]),
506
+ lambda args: (fx2z(), args[1]),
507
+ lambda args: [plot_eq()]
508
+ + z[:NUMBER_OF_PCS].tolist()
509
+ + [z[args[1] - 1].item(), model2json()],
510
+ )(
511
+ args
512
+ ),
513
+ inputs=[s, extra_pc_dropdown],
514
+ outputs=[peq_plot] + sliders + [extra_slider, json_output],
515
+ )
516
+
517
  render_button.click(
518
  lambda *args: (
519
  lambda x: (
 
566
  plot_delay(),
567
  plot_reverb(),
568
  plot_t60(),
569
+ ]
570
+ + [
571
+ pk1.params.freq.item(),
572
+ pk1.params.gain.item(),
573
+ pk1.params.Q.item(),
574
+ pk2.params.freq.item(),
575
+ pk2.params.gain.item(),
576
+ pk2.params.Q.item(),
577
+ ls.params.freq.item(),
578
+ ls.params.gain.item(),
579
+ hs.params.freq.item(),
580
+ hs.params.gain.item(),
581
  ],
582
  ),
583
  inputs=extra_pc_dropdown,
 
590
  delay_plot,
591
  reverb_plot,
592
  t60_plot,
593
+ ]
594
+ + [
595
+ pk1_freq,
596
+ pk1_gain,
597
+ pk1_q,
598
+ pk2_freq,
599
+ pk2_gain,
600
+ pk2_q,
601
+ ls_freq,
602
+ ls_gain,
603
+ hs_freq,
604
+ hs_gain,
605
  ],
606
  )
607
  reset_button.click(
 
617
  plot_delay(),
618
  plot_reverb(),
619
  plot_t60(),
620
+ ]
621
+ + [
622
+ pk1.params.freq.item(),
623
+ pk1.params.gain.item(),
624
+ pk1.params.Q.item(),
625
+ pk2.params.freq.item(),
626
+ pk2.params.gain.item(),
627
+ pk2.params.Q.item(),
628
+ ls.params.freq.item(),
629
+ ls.params.gain.item(),
630
+ hs.params.freq.item(),
631
+ hs.params.gain.item(),
632
  ],
633
  )(None),
634
  # inputs=sliders + [extra_slider],
 
641
  delay_plot,
642
  reverb_plot,
643
  t60_plot,
644
+ ]
645
+ + [
646
+ pk1_freq,
647
+ pk1_gain,
648
+ pk1_q,
649
+ pk2_freq,
650
+ pk2_gain,
651
+ pk2_q,
652
+ ls_freq,
653
+ ls_gain,
654
+ hs_freq,
655
+ hs_gain,
656
  ],
657
  )
658
 
 
672
  plot_delay(),
673
  plot_reverb(),
674
  plot_t60(),
675
+ pk1.params.freq.item(),
676
+ pk1.params.gain.item(),
677
+ pk1.params.Q.item(),
678
+ pk2.params.freq.item(),
679
+ pk2.params.gain.item(),
680
+ pk2.params.Q.item(),
681
+ ls.params.freq.item(),
682
+ ls.params.gain.item(),
683
+ hs.params.freq.item(),
684
+ hs.params.gain.item(),
685
  ),
686
  ),
687
  inputs=slider,
 
692
  delay_plot,
693
  reverb_plot,
694
  t60_plot,
695
+ pk1_freq,
696
+ pk1_gain,
697
+ pk1_q,
698
+ pk2_freq,
699
+ pk2_gain,
700
+ pk2_q,
701
+ ls_freq,
702
+ ls_gain,
703
+ hs_freq,
704
+ hs_gain,
705
  ],
706
  )
707
  extra_slider.input(
 
715
  plot_delay(),
716
  plot_reverb(),
717
  plot_t60(),
718
+ pk1.params.freq.item(),
719
+ pk1.params.gain.item(),
720
+ pk1.params.Q.item(),
721
+ pk2.params.freq.item(),
722
+ pk2.params.gain.item(),
723
+ pk2.params.Q.item(),
724
+ ls.params.freq.item(),
725
+ ls.params.gain.item(),
726
+ hs.params.freq.item(),
727
+ hs.params.gain.item(),
728
  ),
729
  )(xs),
730
  inputs=[extra_slider, extra_pc_dropdown],
731
+ outputs=[
732
+ json_output,
733
+ peq_plot,
734
+ comp_plot,
735
+ delay_plot,
736
+ reverb_plot,
737
+ t60_plot,
738
+ pk1_freq,
739
+ pk1_gain,
740
+ pk1_q,
741
+ pk2_freq,
742
+ pk2_gain,
743
+ pk2_q,
744
+ ls_freq,
745
+ ls_gain,
746
+ hs_freq,
747
+ hs_gain,
748
+ ],
749
  )
750
 
751
  extra_pc_dropdown.input(