yoyolicoris commited on
Commit
a1b4214
·
1 Parent(s): 466b233

feat: refactor global effect handling, enable user session states

Browse files
Files changed (1) hide show
  1. app.py +296 -209
app.py CHANGED
@@ -9,6 +9,7 @@ from hydra.utils import instantiate
9
  from soxr import resample
10
  from functools import partial
11
  from torchcomp import coef2ms, ms2coef
 
12
 
13
  from modules.utils import chain_functions, vec2statedict, get_chunks
14
  from modules.fx import clip_delay_eq_Q
@@ -50,8 +51,8 @@ with open(CONFIG_PATH) as fp:
50
  fx_config = yaml.safe_load(fp)["model"]
51
 
52
  # Global effect
53
- fx = instantiate(fx_config)
54
- fx.eval()
55
 
56
  pca_params = np.load(PCA_PARAM_FILE)
57
  mean = pca_params["mean"]
@@ -64,7 +65,7 @@ U = torch.from_numpy(U).float()
64
  mean = torch.from_numpy(mean).float()
65
  feature_mask = torch.from_numpy(np.load(MASK_PATH))
66
  # Global latent variable
67
- z = torch.zeros_like(mean)
68
 
69
  with open(INFO_PATH) as f:
70
  info = json.load(f)
@@ -91,35 +92,40 @@ vec2dict = partial(
91
  )
92
  ),
93
  )
94
- fx.load_state_dict(vec2dict(mean), strict=False)
95
 
96
 
97
  meter = pyln.Meter(44100)
98
 
99
 
100
  @torch.no_grad()
101
- def z2fx():
102
  # close all figures to avoid too many open figures
103
  plt.close("all")
104
  x = U @ z + mean
105
- # print(z)
106
- fx.load_state_dict(vec2dict(x), strict=False)
107
- fx.apply(partial(clip_delay_eq_Q, Q=0.707))
108
- return
109
 
110
 
111
  @torch.no_grad()
112
- def fx2z():
113
  plt.close("all")
114
  state_dict = fx.state_dict()
115
  flattened = torch.cat([state_dict[k].flatten() for k in param_keys])
116
  x = flattened[feature_mask]
117
- z.copy_(U.T @ (x - mean))
118
- return
 
 
 
 
 
119
 
120
 
121
  @torch.no_grad()
122
- def inference(audio):
123
  sr, y = audio
124
  if sr != 44100:
125
  y = resample(y, sr, 44100)
@@ -163,7 +169,7 @@ def get_important_pcs(n=10, **kwargs):
163
  return sliders
164
 
165
 
166
- def model2json():
167
  fx_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
168
  results = {k: v.toJSON() for k, v in zip(fx_names, fx)} | {
169
  "Panner": fx[7].pan.toJSON()
@@ -190,7 +196,7 @@ def model2json():
190
 
191
 
192
  @torch.no_grad()
193
- def plot_eq():
194
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
195
  w, eq_log_mags = get_log_mags_from_eq(fx[:6])
196
  ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
@@ -207,14 +213,14 @@ def plot_eq():
207
 
208
 
209
  @torch.no_grad()
210
- def plot_comp():
211
  fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
212
  comp = fx[6]
213
- cmp_th = comp.params.cmp_th.item()
214
- exp_th = comp.params.exp_th.item()
215
- cmp_ratio = comp.params.cmp_ratio.item()
216
- exp_ratio = comp.params.exp_ratio.item()
217
- make_up = comp.params.make_up.item()
218
  # print(cmp_ratio, cmp_th, exp_ratio, exp_th, make_up)
219
 
220
  comp_in = np.linspace(-80, 0, 100)
@@ -242,16 +248,16 @@ def plot_comp():
242
 
243
 
244
  @torch.no_grad()
245
- def plot_delay():
246
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
247
  delay = fx[7].effects[0]
248
- w, eq_log_mags = get_log_mags_from_eq([delay.eq])
249
- log_gain = delay.params.gain.log10().item() * 20
250
- d = delay.params.delay.item() / 1000
251
  log_mag = sum(eq_log_mags)
252
  ax.plot(w, log_mag + log_gain, color="black", linestyle="-")
253
 
254
- log_feedback = delay.params.feedback.log10().item() * 20
255
  for i in range(1, 10):
256
  feedback_log_mag = log_mag * (i + 1) + log_feedback * i + log_gain
257
  ax.plot(
@@ -272,7 +278,7 @@ def plot_delay():
272
 
273
 
274
  @torch.no_grad()
275
- def plot_reverb():
276
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
277
  fdn = fx[7].effects[1]
278
  w, eq_log_mags = get_log_mags_from_eq(fdn.eq)
@@ -292,7 +298,7 @@ def plot_reverb():
292
 
293
 
294
  @torch.no_grad()
295
- def plot_t60():
296
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
297
  fdn = fx[7].effects[1]
298
  gamma = fdn.params.gamma.squeeze().numpy()
@@ -311,19 +317,39 @@ def plot_t60():
311
 
312
  @torch.no_grad()
313
  def update_param(m, attr_name, value):
314
- match type(getattr(m.params, attr_name)):
315
  case torch.nn.Parameter:
316
- getattr(m.params, attr_name).data.copy_(value)
317
  case _:
318
- setattr(m.params, attr_name, torch.tensor(value))
319
 
320
 
321
  @torch.no_grad()
322
  def update_atrt(comp, attr_name, value):
323
- setattr(comp.params, attr_name, ms2coef(torch.tensor(value), 44100))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
 
326
  with gr.Blocks() as demo:
 
 
 
 
327
  gr.Markdown(
328
  title_md,
329
  elem_id="title",
@@ -352,12 +378,6 @@ with gr.Blocks() as demo:
352
  render_button = gr.Button(
353
  "Run", elem_id="render-button", variant="primary"
354
  )
355
- # random_rest_checkbox = gr.Checkbox(
356
- # label=f"Randomise PCs > {NUMBER_OF_PCS} (default to zeros)",
357
- # value=False,
358
- # elem_id="randomise-checkbox",
359
- # )
360
- # sliders = get_important_pcs(NUMBER_OF_PCS, value=0)
361
  with gr.Row():
362
  s1 = gr.Slider(
363
  minimum=SLIDER_MIN,
@@ -417,7 +437,7 @@ with gr.Blocks() as demo:
417
  )
418
 
419
  _ = gr.Markdown("## Parametric EQ")
420
- peq_plot = gr.Plot(plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot")
421
  with gr.Row():
422
  with gr.Column(min_width=160):
423
  _ = gr.Markdown("High Pass")
@@ -425,14 +445,14 @@ with gr.Blocks() as demo:
425
  hp_freq = gr.Slider(
426
  minimum=16,
427
  maximum=5300,
428
- value=hp.params.freq.item(),
429
  interactive=True,
430
  label="Frequency (Hz)",
431
  )
432
  hp_q = gr.Slider(
433
  minimum=0.5,
434
  maximum=10,
435
- value=hp.params.Q.item(),
436
  interactive=True,
437
  label="Q",
438
  )
@@ -443,14 +463,14 @@ with gr.Blocks() as demo:
443
  ls_freq = gr.Slider(
444
  minimum=30,
445
  maximum=200,
446
- value=ls.params.freq.item(),
447
  interactive=True,
448
  label="Frequency (Hz)",
449
  )
450
  ls_gain = gr.Slider(
451
  minimum=-12,
452
  maximum=12,
453
- value=ls.params.gain.item(),
454
  interactive=True,
455
  label="Gain (dB)",
456
  )
@@ -461,21 +481,21 @@ with gr.Blocks() as demo:
461
  pk1_freq = gr.Slider(
462
  minimum=33,
463
  maximum=5400,
464
- value=pk1.params.freq.item(),
465
  interactive=True,
466
  label="Frequency (Hz)",
467
  )
468
  pk1_gain = gr.Slider(
469
  minimum=-12,
470
  maximum=12,
471
- value=pk1.params.gain.item(),
472
  interactive=True,
473
  label="Gain (dB)",
474
  )
475
  pk1_q = gr.Slider(
476
  minimum=0.2,
477
  maximum=20,
478
- value=pk1.params.Q.item(),
479
  interactive=True,
480
  label="Q",
481
  )
@@ -485,21 +505,21 @@ with gr.Blocks() as demo:
485
  pk2_freq = gr.Slider(
486
  minimum=200,
487
  maximum=17500,
488
- value=pk2.params.freq.item(),
489
  interactive=True,
490
  label="Frequency (Hz)",
491
  )
492
  pk2_gain = gr.Slider(
493
  minimum=-12,
494
  maximum=12,
495
- value=pk2.params.gain.item(),
496
  interactive=True,
497
  label="Gain (dB)",
498
  )
499
  pk2_q = gr.Slider(
500
  minimum=0.2,
501
  maximum=20,
502
- value=pk2.params.Q.item(),
503
  interactive=True,
504
  label="Q",
505
  )
@@ -510,14 +530,14 @@ with gr.Blocks() as demo:
510
  hs_freq = gr.Slider(
511
  minimum=750,
512
  maximum=8300,
513
- value=hs.params.freq.item(),
514
  interactive=True,
515
  label="Frequency (Hz)",
516
  )
517
  hs_gain = gr.Slider(
518
  minimum=-12,
519
  maximum=12,
520
- value=hs.params.gain.item(),
521
  interactive=True,
522
  label="Gain (dB)",
523
  )
@@ -527,14 +547,14 @@ with gr.Blocks() as demo:
527
  lp_freq = gr.Slider(
528
  minimum=200,
529
  maximum=18000,
530
- value=lp.params.freq.item(),
531
  interactive=True,
532
  label="Frequency (Hz)",
533
  )
534
  lp_q = gr.Slider(
535
  minimum=0.5,
536
  maximum=10,
537
- value=lp.params.Q.item(),
538
  interactive=True,
539
  label="Q",
540
  )
@@ -546,55 +566,55 @@ with gr.Blocks() as demo:
546
  cmp_th = gr.Slider(
547
  minimum=-60,
548
  maximum=0,
549
- value=comp.params.cmp_th.item(),
550
  interactive=True,
551
- label="Comp. Threshold (dB)",
552
  )
553
  cmp_ratio = gr.Slider(
554
  minimum=1,
555
  maximum=20,
556
- value=comp.params.cmp_ratio.item(),
557
  interactive=True,
558
- label="Comp. Ratio",
559
  )
560
  make_up = gr.Slider(
561
  minimum=-12,
562
  maximum=12,
563
- value=comp.params.make_up.item(),
564
  interactive=True,
565
  label="Make Up (dB)",
566
  )
567
  attack_time = gr.Slider(
568
  minimum=0.1,
569
  maximum=100,
570
- value=coef2ms(comp.params.at, 44100).item(),
571
  interactive=True,
572
  label="Attack Time (ms)",
573
  )
574
  release_time = gr.Slider(
575
  minimum=50,
576
  maximum=1000,
577
- value=coef2ms(comp.params.rt, 44100).item(),
578
  interactive=True,
579
  label="Release Time (ms)",
580
  )
581
  exp_ratio = gr.Slider(
582
  minimum=0,
583
  maximum=1,
584
- value=comp.params.exp_ratio.item(),
585
  interactive=True,
586
  label="Exp. Ratio",
587
  )
588
  exp_th = gr.Slider(
589
  minimum=-80,
590
  maximum=0,
591
- value=comp.params.exp_th.item(),
592
  interactive=True,
593
  label="Exp. Threshold (dB)",
594
  )
595
  with gr.Column():
596
  comp_plot = gr.Plot(
597
- plot_comp(), label="Compressor Curve", elem_id="comp-plot"
598
  )
599
 
600
  _ = gr.Markdown("## Ping-Pong Delay")
@@ -604,160 +624,215 @@ with gr.Blocks() as demo:
604
  delay_time = gr.Slider(
605
  minimum=100,
606
  maximum=1000,
607
- value=delay.params.delay.item(),
608
  interactive=True,
609
  label="Delay Time (ms)",
610
  )
611
  feedback = gr.Slider(
612
  minimum=0,
613
  maximum=1,
614
- value=delay.params.feedback.item(),
615
  interactive=True,
616
  label="Feedback",
617
  )
618
  delay_gain = gr.Slider(
619
  minimum=-80,
620
  maximum=0,
621
- value=delay.params.gain.log10().item() * 20,
622
  interactive=True,
623
  label="Gain (dB)",
624
  )
625
  odd_pan = gr.Slider(
626
  minimum=-100,
627
  maximum=100,
628
- value=delay.odd_pan.params.pan.item() * 200 - 100,
629
  interactive=True,
630
  label="Odd Delay Pan",
631
  )
632
  even_pan = gr.Slider(
633
  minimum=-100,
634
  maximum=100,
635
- value=delay.even_pan.params.pan.item() * 200 - 100,
636
  interactive=True,
637
  label="Even Delay Pan",
638
  )
639
  delay_lp_freq = gr.Slider(
640
  minimum=200,
641
  maximum=16000,
642
- value=delay.eq.params.freq.item(),
643
  interactive=True,
644
  label="Low Pass Frequency (Hz)",
645
  )
646
  with gr.Column():
647
  delay_plot = gr.Plot(
648
- plot_delay(), label="Delay Frequency Response", elem_id="delay-plot"
649
  )
650
 
651
  with gr.Row():
652
  reverb_plot = gr.Plot(
653
- plot_reverb(),
654
  label="Reverb Tone Correction PEQ",
655
  elem_id="reverb-plot",
656
  min_width=160,
657
  )
658
  t60_plot = gr.Plot(
659
- plot_t60(), label="Reverb T60", elem_id="t60-plot", min_width=160
660
  )
661
 
662
  with gr.Row():
663
  json_output = gr.JSON(
664
- model2json(), label="Effect Settings", max_height=800, open=True
665
  )
666
 
667
- update_pc = lambda i: z[:NUMBER_OF_PCS].tolist() + [z[i - 1].item()]
668
  update_pc_outputs = sliders + [extra_slider]
669
 
670
- for eq, s, attr_name in zip(
671
- [fx[0]] * 3
672
- + [fx[1]] * 3
673
- + [fx[2]] * 2
674
- + [fx[3]] * 2
675
- + [fx[4]] * 2
676
- + [fx[5]] * 2,
677
- [
678
- pk1_freq,
679
- pk1_gain,
680
- pk1_q,
681
- pk2_freq,
682
- pk2_gain,
683
- pk2_q,
684
- ls_freq,
685
- ls_gain,
686
- hs_freq,
687
- hs_gain,
688
- lp_freq,
689
- lp_q,
690
- hp_freq,
691
- hp_q,
692
- ],
693
- ["freq", "gain", "Q"] * 2 + ["freq", "gain"] * 2 + ["freq", "Q"] * 2,
694
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  s.input(
696
- lambda *args, eq=eq, attr_name=attr_name: chain_functions( # chain_functions(
697
- lambda args: (update_param(eq, attr_name, args[0]), args[1]),
698
- lambda args: (fx2z(), args[1]),
699
- lambda args: args[1],
700
- lambda i: update_pc(i) + [model2json(), plot_eq()],
 
 
 
 
 
 
 
701
  )(
702
  args
703
  ),
704
- inputs=[s, extra_pc_dropdown],
705
- outputs=update_pc_outputs + [json_output, peq_plot],
706
  )
707
 
708
- for f, s, attr_name in zip(
709
- [update_param] * 5 + [update_atrt] * 2,
710
- [
711
- cmp_th,
712
- cmp_ratio,
713
- make_up,
714
- exp_ratio,
715
- exp_th,
716
- attack_time,
717
- release_time,
718
- ],
719
- ["cmp_th", "cmp_ratio", "make_up", "exp_ratio", "exp_th", "at", "rt"],
720
- ):
721
  s.input(
722
  lambda *args, attr_name=attr_name, f=f: chain_functions(
723
- lambda args: (f(comp, attr_name, args[0]), args[1]),
724
- lambda args: (fx2z(), args[1]),
725
- lambda args: args[1],
726
- lambda i: update_pc(i) + [model2json(), plot_comp()],
 
 
 
 
 
 
 
727
  )(args),
728
- inputs=[s, extra_pc_dropdown],
729
- outputs=update_pc_outputs + [json_output, comp_plot],
730
  )
731
 
732
- for m, f, s, attr_name, update_plot in zip(
733
- [delay] * 2 + [delay.eq] + [delay, delay.odd_pan, delay.even_pan],
734
- [update_param] * 3
735
- + [
736
- lambda m, a, v: update_param(m, a, 10 ** (v / 20)),
737
- lambda m, a, v: update_param(m, a, (v + 100) / 200),
738
- lambda m, a, v: update_param(m, a, (v + 100) / 200),
739
- ],
740
- [delay_time, feedback, delay_lp_freq, delay_gain, odd_pan, even_pan],
741
- ["delay", "feedback", "freq", "gain", "pan", "pan"],
742
- [True] * 4 + [False] * 2,
743
  ):
744
  s.input(
745
- lambda *args, f=f, m=m, attr_name=attr_name, update_plot=update_plot: chain_functions(
746
- lambda args: (f(m, attr_name, args[0]), args[1]),
747
- lambda args: (fx2z(), args[1]),
748
- lambda args: args[1],
749
- lambda i: (
750
- update_pc(i)
751
- + [model2json()]
752
- + ([plot_delay()] if update_plot else [])
 
 
 
 
 
 
 
753
  ),
754
  )(
755
  args
756
  ),
757
- inputs=[s, extra_pc_dropdown],
758
- outputs=update_pc_outputs
759
  + [json_output]
760
- + ([delay_plot] if update_plot else []),
 
761
  )
762
 
763
  render_button.click(
@@ -767,10 +842,17 @@ with gr.Blocks() as demo:
767
  # model2json(),
768
  # )
769
  # )(inference(*args)),
770
- inference,
 
 
 
 
 
771
  inputs=[
772
  audio_input,
773
- ],
 
 
774
  outputs=[
775
  audio_output,
776
  direct_output,
@@ -778,34 +860,34 @@ with gr.Blocks() as demo:
778
  ],
779
  )
780
 
781
- update_fx = lambda: [
782
- pk1.params.freq.item(),
783
- pk1.params.gain.item(),
784
- pk1.params.Q.item(),
785
- pk2.params.freq.item(),
786
- pk2.params.gain.item(),
787
- pk2.params.Q.item(),
788
- ls.params.freq.item(),
789
- ls.params.gain.item(),
790
- hs.params.freq.item(),
791
- hs.params.gain.item(),
792
- lp.params.freq.item(),
793
- lp.params.Q.item(),
794
- hp.params.freq.item(),
795
- hp.params.Q.item(),
796
- comp.params.cmp_th.item(),
797
- comp.params.cmp_ratio.item(),
798
- comp.params.make_up.item(),
799
- comp.params.exp_th.item(),
800
- comp.params.exp_ratio.item(),
801
- coef2ms(comp.params.at, 44100).item(),
802
- coef2ms(comp.params.rt, 44100).item(),
803
- delay.params.delay.item(),
804
- delay.params.feedback.item(),
805
- delay.params.gain.log10().item() * 20,
806
- delay.eq.params.freq.item(),
807
- delay.odd_pan.params.pan.item() * 200 - 100,
808
- delay.even_pan.params.pan.item() * 200 - 100,
809
  ]
810
  update_fx_outputs = [
811
  pk1_freq,
@@ -836,12 +918,12 @@ with gr.Blocks() as demo:
836
  odd_pan,
837
  even_pan,
838
  ]
839
- update_plots = lambda: [
840
- plot_eq(),
841
- plot_comp(),
842
- plot_delay(),
843
- plot_reverb(),
844
- plot_t60(),
845
  ]
846
  update_plots_outputs = [
847
  peq_plot,
@@ -851,56 +933,61 @@ with gr.Blocks() as demo:
851
  t60_plot,
852
  ]
853
 
854
- update_all = lambda i: update_pc(i) + update_fx() + update_plots()
855
  update_all_outputs = update_pc_outputs + update_fx_outputs + update_plots_outputs
856
 
857
  random_button.click(
858
  chain_functions(
859
- lambda i: (z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX), i),
860
- lambda args: (z2fx(), args[1]),
861
- lambda args: args[1],
862
- update_all,
863
  ),
864
  inputs=extra_pc_dropdown,
865
- outputs=update_all_outputs,
866
  )
867
  reset_button.click(
868
  # lambda: (lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)])(z.zero_()),
869
  lambda: chain_functions(
870
- lambda _: z.zero_(),
871
- lambda _: z2fx(),
872
- lambda _: update_all(NUMBER_OF_PCS),
873
  )(None),
874
- outputs=update_all_outputs,
875
  )
876
 
877
- def update_z(s, i):
878
  z[i] = s
879
- return
880
 
881
  for i, slider in enumerate(sliders):
882
  slider.input(
883
- chain_functions(
884
- partial(update_z, i=i),
885
- lambda _: z2fx(),
886
- lambda _: update_fx() + update_plots() + [model2json()],
887
- ),
888
- inputs=slider,
889
- outputs=update_fx_outputs + update_plots_outputs + [json_output],
 
 
 
890
  )
891
  extra_slider.input(
892
  lambda *xs: chain_functions(
893
- lambda args: update_z(args[0], args[1] - 1),
894
- lambda _: z2fx(),
895
- lambda _: update_fx() + update_plots() + [model2json()],
 
 
 
896
  )(xs),
897
- inputs=[extra_slider, extra_pc_dropdown],
898
- outputs=update_fx_outputs + update_plots_outputs + [json_output],
899
  )
900
 
901
  extra_pc_dropdown.input(
902
- lambda i: z[i - 1].item(),
903
- inputs=extra_pc_dropdown,
904
  outputs=extra_slider,
905
  )
906
 
 
9
  from soxr import resample
10
  from functools import partial
11
  from torchcomp import coef2ms, ms2coef
12
+ from copy import deepcopy
13
 
14
  from modules.utils import chain_functions, vec2statedict, get_chunks
15
  from modules.fx import clip_delay_eq_Q
 
51
  fx_config = yaml.safe_load(fp)["model"]
52
 
53
  # Global effect
54
+ global_fx = instantiate(fx_config)
55
+ global_fx.eval()
56
 
57
  pca_params = np.load(PCA_PARAM_FILE)
58
  mean = pca_params["mean"]
 
65
  mean = torch.from_numpy(mean).float()
66
  feature_mask = torch.from_numpy(np.load(MASK_PATH))
67
  # Global latent variable
68
+ # z = torch.zeros_like(mean)
69
 
70
  with open(INFO_PATH) as f:
71
  info = json.load(f)
 
92
  )
93
  ),
94
  )
95
+ global_fx.load_state_dict(vec2dict(mean), strict=False)
96
 
97
 
98
  meter = pyln.Meter(44100)
99
 
100
 
101
  @torch.no_grad()
102
+ def z2x(z):
103
  # close all figures to avoid too many open figures
104
  plt.close("all")
105
  x = U @ z + mean
106
+ # # print(z)
107
+ # fx.load_state_dict(vec2dict(x), strict=False)
108
+ # fx.apply(partial(clip_delay_eq_Q, Q=0.707))
109
+ return x
110
 
111
 
112
  @torch.no_grad()
113
+ def fx2x(fx):
114
  plt.close("all")
115
  state_dict = fx.state_dict()
116
  flattened = torch.cat([state_dict[k].flatten() for k in param_keys])
117
  x = flattened[feature_mask]
118
+ return x
119
+
120
+
121
+ @torch.no_grad()
122
+ def x2z(x):
123
+ z = U.T @ (x - mean)
124
+ return z
125
 
126
 
127
  @torch.no_grad()
128
+ def inference(audio, fx):
129
  sr, y = audio
130
  if sr != 44100:
131
  y = resample(y, sr, 44100)
 
169
  return sliders
170
 
171
 
172
+ def model2json(fx):
173
  fx_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
174
  results = {k: v.toJSON() for k, v in zip(fx_names, fx)} | {
175
  "Panner": fx[7].pan.toJSON()
 
196
 
197
 
198
  @torch.no_grad()
199
+ def plot_eq(fx):
200
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
201
  w, eq_log_mags = get_log_mags_from_eq(fx[:6])
202
  ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
 
213
 
214
 
215
  @torch.no_grad()
216
+ def plot_comp(fx):
217
  fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
218
  comp = fx[6]
219
+ cmp_th = fx[6].params.cmp_th.item()
220
+ exp_th = fx[6].params.exp_th.item()
221
+ cmp_ratio = fx[6].params.cmp_ratio.item()
222
+ exp_ratio = fx[6].params.exp_ratio.item()
223
+ make_up = fx[6].params.make_up.item()
224
  # print(cmp_ratio, cmp_th, exp_ratio, exp_th, make_up)
225
 
226
  comp_in = np.linspace(-80, 0, 100)
 
248
 
249
 
250
  @torch.no_grad()
251
+ def plot_delay(fx):
252
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
253
  delay = fx[7].effects[0]
254
+ w, eq_log_mags = get_log_mags_from_eq([fx[7].effects[0].eq])
255
+ log_gain = fx[7].effects[0].params.gain.log10().item() * 20
256
+ d = fx[7].effects[0].params.delay.item() / 1000
257
  log_mag = sum(eq_log_mags)
258
  ax.plot(w, log_mag + log_gain, color="black", linestyle="-")
259
 
260
+ log_feedback = fx[7].effects[0].params.feedback.log10().item() * 20
261
  for i in range(1, 10):
262
  feedback_log_mag = log_mag * (i + 1) + log_feedback * i + log_gain
263
  ax.plot(
 
278
 
279
 
280
  @torch.no_grad()
281
+ def plot_reverb(fx):
282
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
283
  fdn = fx[7].effects[1]
284
  w, eq_log_mags = get_log_mags_from_eq(fdn.eq)
 
298
 
299
 
300
  @torch.no_grad()
301
+ def plot_t60(fx):
302
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
303
  fdn = fx[7].effects[1]
304
  gamma = fdn.params.gamma.squeeze().numpy()
 
317
 
318
  @torch.no_grad()
319
  def update_param(m, attr_name, value):
320
+ match type(getattr(m, attr_name)):
321
  case torch.nn.Parameter:
322
+ getattr(m, attr_name).data.copy_(value)
323
  case _:
324
+ setattr(m, attr_name, torch.tensor(value))
325
 
326
 
327
  @torch.no_grad()
328
  def update_atrt(comp, attr_name, value):
329
+ setattr(comp, attr_name, ms2coef(torch.tensor(value), 44100))
330
+
331
+
332
+ def vec2fx(x):
333
+ fx = deepcopy(global_fx)
334
+ fx.load_state_dict(vec2dict(x), strict=False)
335
+ fx.apply(partial(clip_delay_eq_Q, Q=0.707))
336
+ return fx
337
+
338
+
339
+ get_last_attribute = lambda m, attr_name: (
340
+ (m, attr_name)
341
+ if "." not in attr_name
342
+ else (lambda x, *remain: get_last_attribute(getattr(m, x), ".".join(remain)))(
343
+ *attr_name.split(".")
344
+ )
345
+ )
346
 
347
 
348
  with gr.Blocks() as demo:
349
+ z = gr.State(torch.zeros_like(mean))
350
+ fx_params = gr.State(mean)
351
+ fx = vec2fx(fx_params.value)
352
+
353
  gr.Markdown(
354
  title_md,
355
  elem_id="title",
 
378
  render_button = gr.Button(
379
  "Run", elem_id="render-button", variant="primary"
380
  )
 
 
 
 
 
 
381
  with gr.Row():
382
  s1 = gr.Slider(
383
  minimum=SLIDER_MIN,
 
437
  )
438
 
439
  _ = gr.Markdown("## Parametric EQ")
440
+ peq_plot = gr.Plot(plot_eq(fx), label="PEQ Frequency Response", elem_id="peq-plot")
441
  with gr.Row():
442
  with gr.Column(min_width=160):
443
  _ = gr.Markdown("High Pass")
 
445
  hp_freq = gr.Slider(
446
  minimum=16,
447
  maximum=5300,
448
+ value=fx[5].params.freq.item(),
449
  interactive=True,
450
  label="Frequency (Hz)",
451
  )
452
  hp_q = gr.Slider(
453
  minimum=0.5,
454
  maximum=10,
455
+ value=fx[5].params.Q.item(),
456
  interactive=True,
457
  label="Q",
458
  )
 
463
  ls_freq = gr.Slider(
464
  minimum=30,
465
  maximum=200,
466
+ value=fx[2].params.freq.item(),
467
  interactive=True,
468
  label="Frequency (Hz)",
469
  )
470
  ls_gain = gr.Slider(
471
  minimum=-12,
472
  maximum=12,
473
+ value=fx[2].params.gain.item(),
474
  interactive=True,
475
  label="Gain (dB)",
476
  )
 
481
  pk1_freq = gr.Slider(
482
  minimum=33,
483
  maximum=5400,
484
+ value=fx[0].params.freq.item(),
485
  interactive=True,
486
  label="Frequency (Hz)",
487
  )
488
  pk1_gain = gr.Slider(
489
  minimum=-12,
490
  maximum=12,
491
+ value=fx[0].params.gain.item(),
492
  interactive=True,
493
  label="Gain (dB)",
494
  )
495
  pk1_q = gr.Slider(
496
  minimum=0.2,
497
  maximum=20,
498
+ value=fx[0].params.Q.item(),
499
  interactive=True,
500
  label="Q",
501
  )
 
505
  pk2_freq = gr.Slider(
506
  minimum=200,
507
  maximum=17500,
508
+ value=fx[1].params.freq.item(),
509
  interactive=True,
510
  label="Frequency (Hz)",
511
  )
512
  pk2_gain = gr.Slider(
513
  minimum=-12,
514
  maximum=12,
515
+ value=fx[1].params.gain.item(),
516
  interactive=True,
517
  label="Gain (dB)",
518
  )
519
  pk2_q = gr.Slider(
520
  minimum=0.2,
521
  maximum=20,
522
+ value=fx[1].params.Q.item(),
523
  interactive=True,
524
  label="Q",
525
  )
 
530
  hs_freq = gr.Slider(
531
  minimum=750,
532
  maximum=8300,
533
+ value=fx[3].params.freq.item(),
534
  interactive=True,
535
  label="Frequency (Hz)",
536
  )
537
  hs_gain = gr.Slider(
538
  minimum=-12,
539
  maximum=12,
540
+ value=fx[3].params.gain.item(),
541
  interactive=True,
542
  label="Gain (dB)",
543
  )
 
547
  lp_freq = gr.Slider(
548
  minimum=200,
549
  maximum=18000,
550
+ value=fx[4].params.freq.item(),
551
  interactive=True,
552
  label="Frequency (Hz)",
553
  )
554
  lp_q = gr.Slider(
555
  minimum=0.5,
556
  maximum=10,
557
+ value=fx[4].params.Q.item(),
558
  interactive=True,
559
  label="Q",
560
  )
 
566
  cmp_th = gr.Slider(
567
  minimum=-60,
568
  maximum=0,
569
+ value=fx[6].params.cmp_th.item(),
570
  interactive=True,
571
+ label="fx[6]. Threshold (dB)",
572
  )
573
  cmp_ratio = gr.Slider(
574
  minimum=1,
575
  maximum=20,
576
+ value=fx[6].params.cmp_ratio.item(),
577
  interactive=True,
578
+ label="fx[6]. Ratio",
579
  )
580
  make_up = gr.Slider(
581
  minimum=-12,
582
  maximum=12,
583
+ value=fx[6].params.make_up.item(),
584
  interactive=True,
585
  label="Make Up (dB)",
586
  )
587
  attack_time = gr.Slider(
588
  minimum=0.1,
589
  maximum=100,
590
+ value=coef2ms(fx[6].params.at, 44100).item(),
591
  interactive=True,
592
  label="Attack Time (ms)",
593
  )
594
  release_time = gr.Slider(
595
  minimum=50,
596
  maximum=1000,
597
+ value=coef2ms(fx[6].params.rt, 44100).item(),
598
  interactive=True,
599
  label="Release Time (ms)",
600
  )
601
  exp_ratio = gr.Slider(
602
  minimum=0,
603
  maximum=1,
604
+ value=fx[6].params.exp_ratio.item(),
605
  interactive=True,
606
  label="Exp. Ratio",
607
  )
608
  exp_th = gr.Slider(
609
  minimum=-80,
610
  maximum=0,
611
+ value=fx[6].params.exp_th.item(),
612
  interactive=True,
613
  label="Exp. Threshold (dB)",
614
  )
615
  with gr.Column():
616
  comp_plot = gr.Plot(
617
+ plot_comp(fx), label="Compressor Curve", elem_id="comp-plot"
618
  )
619
 
620
  _ = gr.Markdown("## Ping-Pong Delay")
 
624
  delay_time = gr.Slider(
625
  minimum=100,
626
  maximum=1000,
627
+ value=fx[7].effects[0].params.delay.item(),
628
  interactive=True,
629
  label="Delay Time (ms)",
630
  )
631
  feedback = gr.Slider(
632
  minimum=0,
633
  maximum=1,
634
+ value=fx[7].effects[0].params.feedback.item(),
635
  interactive=True,
636
  label="Feedback",
637
  )
638
  delay_gain = gr.Slider(
639
  minimum=-80,
640
  maximum=0,
641
+ value=fx[7].effects[0].params.gain.log10().item() * 20,
642
  interactive=True,
643
  label="Gain (dB)",
644
  )
645
  odd_pan = gr.Slider(
646
  minimum=-100,
647
  maximum=100,
648
+ value=fx[7].effects[0].odd_pan.params.pan.item() * 200 - 100,
649
  interactive=True,
650
  label="Odd Delay Pan",
651
  )
652
  even_pan = gr.Slider(
653
  minimum=-100,
654
  maximum=100,
655
+ value=fx[7].effects[0].even_pan.params.pan.item() * 200 - 100,
656
  interactive=True,
657
  label="Even Delay Pan",
658
  )
659
  delay_lp_freq = gr.Slider(
660
  minimum=200,
661
  maximum=16000,
662
+ value=fx[7].effects[0].eq.params.freq.item(),
663
  interactive=True,
664
  label="Low Pass Frequency (Hz)",
665
  )
666
  with gr.Column():
667
  delay_plot = gr.Plot(
668
+ plot_delay(fx), label="Delay Frequency Response", elem_id="delay-plot"
669
  )
670
 
671
  with gr.Row():
672
  reverb_plot = gr.Plot(
673
+ plot_reverb(fx),
674
  label="Reverb Tone Correction PEQ",
675
  elem_id="reverb-plot",
676
  min_width=160,
677
  )
678
  t60_plot = gr.Plot(
679
+ plot_t60(fx), label="Reverb T60", elem_id="t60-plot", min_width=160
680
  )
681
 
682
  with gr.Row():
683
  json_output = gr.JSON(
684
+ model2json(fx), label="Effect Settings", max_height=800, open=True
685
  )
686
 
687
+ update_pc = lambda z, i: z[:NUMBER_OF_PCS].tolist() + [z[i - 1].item()]
688
  update_pc_outputs = sliders + [extra_slider]
689
 
690
+ peq_sliders = [
691
+ pk1_freq,
692
+ pk1_gain,
693
+ pk1_q,
694
+ pk2_freq,
695
+ pk2_gain,
696
+ pk2_q,
697
+ ls_freq,
698
+ ls_gain,
699
+ hs_freq,
700
+ hs_gain,
701
+ lp_freq,
702
+ lp_q,
703
+ hp_freq,
704
+ hp_q,
705
+ ]
706
+ peq_attr_names = (
707
+ ["freq", "gain", "Q"] * 2 + ["freq", "gain"] * 2 + ["freq", "Q"] * 2
708
+ )
709
+ peq_indices = [0] * 3 + [1] * 3 + [2] * 2 + [3] * 2 + [4] * 2 + [5] * 2
710
+
711
+ cmp_sliders = [
712
+ cmp_th,
713
+ cmp_ratio,
714
+ make_up,
715
+ exp_ratio,
716
+ exp_th,
717
+ attack_time,
718
+ release_time,
719
+ ]
720
+ cmp_update_funcs = [update_param] * 5 + [update_atrt] * 2
721
+ cmp_attr_names = [
722
+ "cmp_th",
723
+ "cmp_ratio",
724
+ "make_up",
725
+ "exp_ratio",
726
+ "exp_th",
727
+ "at",
728
+ "rt",
729
+ ]
730
+
731
+ delay_sliders = [delay_time, feedback, delay_lp_freq, delay_gain, odd_pan, even_pan]
732
+ delay_update_funcs = [update_param] * 3 + [
733
+ lambda m, a, v: update_param(m, a, 10 ** (v / 20)),
734
+ lambda m, a, v: update_param(m, a, (v + 100) / 200),
735
+ lambda m, a, v: update_param(m, a, (v + 100) / 200),
736
+ ]
737
+ delay_attr_names = [
738
+ "params.delay",
739
+ "params.feedback",
740
+ "eq.params.freq",
741
+ "params.gain",
742
+ "odd_pan.params.pan",
743
+ "even_pan.params.pan",
744
+ ]
745
+ delay_update_plot_flag = [True] * 4 + [False] * 2
746
+
747
+ all_effect_sliders = peq_sliders + cmp_sliders + delay_sliders
748
+ split_sizes = [len(peq_sliders), len(cmp_sliders), len(delay_sliders)]
749
+
750
+ def assign_fx_params(fx, *args):
751
+ peq_sliders, cmp_sliders, delay_sliders = (
752
+ args[: split_sizes[0]],
753
+ args[split_sizes[0] : sum(split_sizes[:2])],
754
+ args[sum(split_sizes[:2]) :],
755
+ )
756
+ for idx, s, attr_name in zip(peq_indices, peq_sliders, peq_attr_names):
757
+ update_param(fx[idx].params, attr_name, s)
758
+
759
+ for f, s, attr_name in zip(cmp_update_funcs, cmp_sliders, cmp_attr_names):
760
+ f(fx[6].params, attr_name, s)
761
+
762
+ for f, s, attr_name in zip(delay_update_funcs, delay_sliders, delay_attr_names):
763
+ m, name = get_last_attribute(fx[7].effects[0], attr_name)
764
+ f(m, name, s)
765
+
766
+ return fx
767
+
768
+ for idx, s, attr_name in zip(peq_indices, peq_sliders, peq_attr_names):
769
  s.input(
770
+ lambda *args, idx=idx, attr_name=attr_name: chain_functions( # chain_functions(
771
+ lambda args: (assign_fx_params(vec2fx(args[0]), *args[3:]), *args[1:3]),
772
+ lambda args: (
773
+ update_param(args[0][idx].params, attr_name, args[1]),
774
+ args[0],
775
+ args[2],
776
+ ),
777
+ lambda args: (fx2x(args[1]), *args[1:]),
778
+ lambda args: [x2z(args[0]), *args],
779
+ lambda args: args[:2]
780
+ + [model2json(args[2]), plot_eq(args[2])]
781
+ + update_pc(args[0], args[3]),
782
  )(
783
  args
784
  ),
785
+ inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
786
+ outputs=[z, fx_params, json_output, peq_plot] + update_pc_outputs,
787
  )
788
 
789
+ for f, s, attr_name in zip(cmp_update_funcs, cmp_sliders, cmp_attr_names):
 
 
 
 
 
 
 
 
 
 
 
 
790
  s.input(
791
  lambda *args, attr_name=attr_name, f=f: chain_functions(
792
+ lambda args: (assign_fx_params(vec2fx(args[0]), *args[3:]), *args[1:3]),
793
+ lambda args: (
794
+ f(args[0][6].params, attr_name, args[1]),
795
+ args[0],
796
+ args[2],
797
+ ),
798
+ lambda args: (fx2x(args[1]), *args[1:]),
799
+ lambda args: [x2z(args[0]), *args],
800
+ lambda args: args[:2]
801
+ + [model2json(args[2]), plot_comp(args[2])]
802
+ + update_pc(args[0], args[3]),
803
  )(args),
804
+ inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
805
+ outputs=[z, fx_params, json_output, comp_plot] + update_pc_outputs,
806
  )
807
 
808
+ for f, s, attr_name, update_plot in zip(
809
+ delay_update_funcs, delay_sliders, delay_attr_names, delay_update_plot_flag
 
 
 
 
 
 
 
 
 
810
  ):
811
  s.input(
812
+ lambda *args, f=f, attr_name=attr_name, update_plot=update_plot: chain_functions(
813
+ lambda args: (assign_fx_params(vec2fx(args[0]), *args[3:]), *args[1:3]),
814
+ lambda args: (
815
+ # f(args[0][7].effects[0], attr_name, args[1]),
816
+ f(*get_last_attribute(args[0][7].effects[0], attr_name), args[1]),
817
+ args[0],
818
+ args[2],
819
+ ),
820
+ lambda args: (fx2x(args[1]), *args[1:]),
821
+ lambda args: [x2z(args[0]), *args],
822
+ lambda args: (
823
+ args[:2]
824
+ + [model2json(args[2])]
825
+ + ([plot_delay(args[2])] if update_plot else [])
826
+ + update_pc(args[0], args[3])
827
  ),
828
  )(
829
  args
830
  ),
831
+ inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
832
+ outputs=[z, fx_params]
833
  + [json_output]
834
+ + ([delay_plot] if update_plot else [])
835
+ + update_pc_outputs,
836
  )
837
 
838
  render_button.click(
 
842
  # model2json(),
843
  # )
844
  # )(inference(*args)),
845
+ # inference,
846
+ # lambda audio, x: inference(audio, vec2fx(x)),
847
+ lambda audio, *args: chain_functions(
848
+ lambda args: assign_fx_params(vec2fx(args[0]), *args[1:]),
849
+ partial(inference, audio),
850
+ )(args),
851
  inputs=[
852
  audio_input,
853
+ fx_params,
854
+ ]
855
+ + all_effect_sliders,
856
  outputs=[
857
  audio_output,
858
  direct_output,
 
860
  ],
861
  )
862
 
863
+ update_fx = lambda fx: [
864
+ fx[0].params.freq.item(),
865
+ fx[0].params.gain.item(),
866
+ fx[0].params.Q.item(),
867
+ fx[1].params.freq.item(),
868
+ fx[1].params.gain.item(),
869
+ fx[1].params.Q.item(),
870
+ fx[2].params.freq.item(),
871
+ fx[2].params.gain.item(),
872
+ fx[3].params.freq.item(),
873
+ fx[3].params.gain.item(),
874
+ fx[4].params.freq.item(),
875
+ fx[4].params.Q.item(),
876
+ fx[5].params.freq.item(),
877
+ fx[5].params.Q.item(),
878
+ fx[6].params.cmp_th.item(),
879
+ fx[6].params.cmp_ratio.item(),
880
+ fx[6].params.make_up.item(),
881
+ fx[6].params.exp_th.item(),
882
+ fx[6].params.exp_ratio.item(),
883
+ coef2ms(fx[6].params.at, 44100).item(),
884
+ coef2ms(fx[6].params.rt, 44100).item(),
885
+ fx[7].effects[0].params.delay.item(),
886
+ fx[7].effects[0].params.feedback.item(),
887
+ fx[7].effects[0].params.gain.log10().item() * 20,
888
+ fx[7].effects[0].eq.params.freq.item(),
889
+ fx[7].effects[0].odd_pan.params.pan.item() * 200 - 100,
890
+ fx[7].effects[0].even_pan.params.pan.item() * 200 - 100,
891
  ]
892
  update_fx_outputs = [
893
  pk1_freq,
 
918
  odd_pan,
919
  even_pan,
920
  ]
921
+ update_plots = lambda fx: [
922
+ plot_eq(fx),
923
+ plot_comp(fx),
924
+ plot_delay(fx),
925
+ plot_reverb(fx),
926
+ plot_t60(fx),
927
  ]
928
  update_plots_outputs = [
929
  peq_plot,
 
933
  t60_plot,
934
  ]
935
 
936
+ update_all = lambda z, fx, i: update_pc(z, i) + update_fx(fx) + update_plots(fx)
937
  update_all_outputs = update_pc_outputs + update_fx_outputs + update_plots_outputs
938
 
939
  random_button.click(
940
  chain_functions(
941
+ lambda i: (torch.randn_like(mean).clip(SLIDER_MIN, SLIDER_MAX), i),
942
+ lambda args: (args[0], vec2fx(z2x(args[0])), args[1]),
943
+ lambda args: update_all(*args) + [args[0]],
 
944
  ),
945
  inputs=extra_pc_dropdown,
946
+ outputs=update_all_outputs + [z],
947
  )
948
  reset_button.click(
949
  # lambda: (lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)])(z.zero_()),
950
  lambda: chain_functions(
951
+ lambda _: torch.zeros_like(mean),
952
+ lambda z: (z, vec2fx(z2x(z))),
953
+ lambda args: update_all(args[0], args[1], NUMBER_OF_PCS) + [args[0]],
954
  )(None),
955
+ outputs=update_all_outputs + [z],
956
  )
957
 
958
+ def update_z(z, s, i):
959
  z[i] = s
960
+ return z
961
 
962
  for i, slider in enumerate(sliders):
963
  slider.input(
964
+ lambda *args, i=i: chain_functions(
965
+ lambda args: update_z(args[0], args[1], i),
966
+ lambda z: (z, vec2fx(z2x(z))),
967
+ lambda args: [args[0]]
968
+ + update_fx(args[1])
969
+ + update_plots(args[1])
970
+ + [model2json(args[1])],
971
+ )(args),
972
+ inputs=[z, slider],
973
+ outputs=[z] + update_fx_outputs + update_plots_outputs + [json_output],
974
  )
975
  extra_slider.input(
976
  lambda *xs: chain_functions(
977
+ lambda args: update_z(args[0], args[1], args[2]),
978
+ lambda z: (z, vec2fx(z2x(z))),
979
+ lambda args: [args[0]]
980
+ + update_fx(args[1])
981
+ + update_plots(args[1])
982
+ + [model2json(args[1])],
983
  )(xs),
984
+ inputs=[z, extra_slider, extra_pc_dropdown],
985
+ outputs=[z] + update_fx_outputs + update_plots_outputs + [json_output],
986
  )
987
 
988
  extra_pc_dropdown.input(
989
+ lambda z, i: z[i - 1].item(),
990
+ inputs=[z, extra_pc_dropdown],
991
  outputs=extra_slider,
992
  )
993