yoyolicoris commited on
Commit
6bd893f
·
1 Parent(s): 644e3c2

feat: separate direct and wet audio outputs and enable compressor control

Browse files
Files changed (1) hide show
  1. app.py +134 -13
app.py CHANGED
@@ -8,6 +8,7 @@ import pyloudnorm as pyln
8
  from hydra.utils import instantiate
9
  from soxr import resample
10
  from functools import partial
 
11
 
12
  from modules.utils import chain_functions, vec2statedict, get_chunks
13
  from modules.fx import clip_delay_eq_Q
@@ -109,6 +110,7 @@ def z2fx():
109
 
110
  @torch.no_grad()
111
  def fx2z():
 
112
  state_dict = fx.state_dict()
113
  flattened = torch.cat([state_dict[k].flatten() for k in param_keys])
114
  x = flattened[feature_mask]
@@ -133,10 +135,24 @@ def inference(audio):
133
  if y.shape[1] != 1:
134
  y = y.mean(dim=1, keepdim=True)
135
 
136
- rendered = fx(y).squeeze(0).T.numpy()
 
 
 
 
137
  if np.max(np.abs(rendered)) > 1:
138
- rendered = rendered / np.max(np.abs(rendered))
139
- return (44100, (rendered * 32768).astype(np.int16))
 
 
 
 
 
 
 
 
 
 
140
 
141
 
142
  def get_important_pcs(n=10, **kwargs):
@@ -294,12 +310,17 @@ def plot_t60():
294
 
295
 
296
  @torch.no_grad()
297
- def upatePEQ(eq, attr_name, value):
298
- match type(getattr(eq.params, attr_name)):
299
  case torch.nn.Parameter:
300
- getattr(eq.params, attr_name).data.copy_(value)
301
  case _:
302
- setattr(eq.params, attr_name, torch.tensor(value))
 
 
 
 
 
303
 
304
 
305
  with gr.Blocks() as demo:
@@ -388,11 +409,15 @@ with gr.Blocks() as demo:
388
  audio_output = gr.Audio(
389
  type="numpy", label="Output Audio", interactive=False, loop=True
390
  )
391
-
392
- _ = gr.Markdown("## Parametric EQ")
393
- peq_plot = gr.Plot(
394
- plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot"
 
395
  )
 
 
 
396
  with gr.Row():
397
  with gr.Column(min_width=160):
398
  _ = gr.Markdown("High Pass")
@@ -514,7 +539,63 @@ with gr.Blocks() as demo:
514
  label="Q",
515
  )
516
 
517
- comp_plot = gr.Plot(plot_comp(), label="Compressor Curve", elem_id="comp-plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  delay_plot = gr.Plot(
519
  plot_delay(), label="Delay Frequency Response", elem_id="delay-plot"
520
  )
@@ -558,7 +639,7 @@ with gr.Blocks() as demo:
558
  ):
559
  s.input(
560
  lambda *args, eq=eq, attr_name=attr_name: chain_functions( # chain_functions(
561
- lambda args: (upatePEQ(eq, attr_name, args[0]), args[1]),
562
  lambda args: (fx2z(), args[1]),
563
  lambda args: args[1],
564
  lambda i: update_pc(i) + [model2json(), plot_eq()],
@@ -569,6 +650,30 @@ with gr.Blocks() as demo:
569
  outputs=update_pc_outputs + [json_output, peq_plot],
570
  )
571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
  render_button.click(
573
  # lambda *args: (
574
  # lambda x: (
@@ -582,6 +687,8 @@ with gr.Blocks() as demo:
582
  ],
583
  outputs=[
584
  audio_output,
 
 
585
  ],
586
  )
587
 
@@ -600,6 +707,13 @@ with gr.Blocks() as demo:
600
  lp.params.Q.item(),
601
  hp.params.freq.item(),
602
  hp.params.Q.item(),
 
 
 
 
 
 
 
603
  ]
604
  update_fx_outputs = [
605
  pk1_freq,
@@ -616,6 +730,13 @@ with gr.Blocks() as demo:
616
  lp_q,
617
  hp_freq,
618
  hp_q,
 
 
 
 
 
 
 
619
  ]
620
  update_plots = lambda: [
621
  plot_eq(),
 
8
  from hydra.utils import instantiate
9
  from soxr import resample
10
  from functools import partial
11
+ from torchcomp import coef2ms, ms2coef
12
 
13
  from modules.utils import chain_functions, vec2statedict, get_chunks
14
  from modules.fx import clip_delay_eq_Q
 
110
 
111
  @torch.no_grad()
112
  def fx2z():
113
+ plt.close("all")
114
  state_dict = fx.state_dict()
115
  flattened = torch.cat([state_dict[k].flatten() for k in param_keys])
116
  x = flattened[feature_mask]
 
135
  if y.shape[1] != 1:
136
  y = y.mean(dim=1, keepdim=True)
137
 
138
+ direct, wet = fx(y)
139
+ direct = direct.squeeze(0).T.numpy()
140
+ wet = wet.squeeze(0).T.numpy()
141
+ rendered = direct + wet
142
+ # rendered = fx(y).squeeze(0).T.numpy()
143
  if np.max(np.abs(rendered)) > 1:
144
+ scaler = np.max(np.abs(rendered))
145
+ rendered = rendered / scaler
146
+ direct = direct / scaler
147
+ wet = wet / scaler
148
+ return (
149
+ (44100, (rendered * 32768).astype(np.int16)),
150
+ (44100, (direct * 32768).astype(np.int16)),
151
+ (
152
+ 44100,
153
+ (wet * 32768).astype(np.int16),
154
+ ),
155
+ )
156
 
157
 
158
  def get_important_pcs(n=10, **kwargs):
 
310
 
311
 
312
  @torch.no_grad()
313
+ def update_param(m, attr_name, value):
314
+ match type(getattr(m.params, attr_name)):
315
  case torch.nn.Parameter:
316
+ getattr(m.params, attr_name).data.copy_(value)
317
  case _:
318
+ setattr(m.params, attr_name, torch.tensor(value))
319
+
320
+
321
+ @torch.no_grad()
322
+ def update_atrt(comp, attr_name, value):
323
+ setattr(comp.params, attr_name, ms2coef(torch.tensor(value), 44100))
324
 
325
 
326
  with gr.Blocks() as demo:
 
409
  audio_output = gr.Audio(
410
  type="numpy", label="Output Audio", interactive=False, loop=True
411
  )
412
+ direct_output = gr.Audio(
413
+ type="numpy", label="Direct Audio", interactive=False, loop=True
414
+ )
415
+ wet_output = gr.Audio(
416
+ type="numpy", label="Wet Audio", interactive=False, loop=True
417
  )
418
+
419
+ _ = gr.Markdown("## Parametric EQ")
420
+ peq_plot = gr.Plot(plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot")
421
  with gr.Row():
422
  with gr.Column(min_width=160):
423
  _ = gr.Markdown("High Pass")
 
539
  label="Q",
540
  )
541
 
542
+ _ = gr.Markdown("## Compressor and Expander")
543
+ with gr.Row():
544
+ with gr.Column():
545
+ comp = fx[6]
546
+ cmp_th = gr.Slider(
547
+ minimum=-60,
548
+ maximum=0,
549
+ value=comp.params.cmp_th.item(),
550
+ interactive=True,
551
+ label="Comp. Threshold (dB)",
552
+ )
553
+ cmp_ratio = gr.Slider(
554
+ minimum=1,
555
+ maximum=20,
556
+ value=comp.params.cmp_ratio.item(),
557
+ interactive=True,
558
+ label="Comp. Ratio",
559
+ )
560
+ make_up = gr.Slider(
561
+ minimum=-12,
562
+ maximum=12,
563
+ value=comp.params.make_up.item(),
564
+ interactive=True,
565
+ label="Make Up (dB)",
566
+ )
567
+ attack_time = gr.Slider(
568
+ minimum=0.1,
569
+ maximum=100,
570
+ value=coef2ms(comp.params.at, 44100).item(),
571
+ interactive=True,
572
+ label="Attack Time (ms)",
573
+ )
574
+ release_time = gr.Slider(
575
+ minimum=50,
576
+ maximum=1000,
577
+ value=coef2ms(comp.params.rt, 44100).item(),
578
+ interactive=True,
579
+ label="Release Time (ms)",
580
+ )
581
+ exp_ratio = gr.Slider(
582
+ minimum=0,
583
+ maximum=1,
584
+ value=comp.params.exp_ratio.item(),
585
+ interactive=True,
586
+ label="Exp. Ratio",
587
+ )
588
+ exp_th = gr.Slider(
589
+ minimum=-80,
590
+ maximum=0,
591
+ value=comp.params.exp_th.item(),
592
+ interactive=True,
593
+ label="Exp. Threshold (dB)",
594
+ )
595
+ with gr.Column():
596
+ comp_plot = gr.Plot(
597
+ plot_comp(), label="Compressor Curve", elem_id="comp-plot"
598
+ )
599
  delay_plot = gr.Plot(
600
  plot_delay(), label="Delay Frequency Response", elem_id="delay-plot"
601
  )
 
639
  ):
640
  s.input(
641
  lambda *args, eq=eq, attr_name=attr_name: chain_functions( # chain_functions(
642
+ lambda args: (update_param(eq, attr_name, args[0]), args[1]),
643
  lambda args: (fx2z(), args[1]),
644
  lambda args: args[1],
645
  lambda i: update_pc(i) + [model2json(), plot_eq()],
 
650
  outputs=update_pc_outputs + [json_output, peq_plot],
651
  )
652
 
653
+ for f, s, attr_name in zip(
654
+ [update_param] * 5 + [update_atrt] * 2,
655
+ [
656
+ cmp_th,
657
+ cmp_ratio,
658
+ make_up,
659
+ exp_ratio,
660
+ exp_th,
661
+ attack_time,
662
+ release_time,
663
+ ],
664
+ ["cmp_th", "cmp_ratio", "make_up", "exp_ratio", "exp_th", "at", "rt"],
665
+ ):
666
+ s.input(
667
+ lambda *args, attr_name=attr_name, f=f: chain_functions(
668
+ lambda args: (f(comp, attr_name, args[0]), args[1]),
669
+ lambda args: (fx2z(), args[1]),
670
+ lambda args: args[1],
671
+ lambda i: update_pc(i) + [model2json(), plot_comp()],
672
+ )(args),
673
+ inputs=[s, extra_pc_dropdown],
674
+ outputs=update_pc_outputs + [json_output, comp_plot],
675
+ )
676
+
677
  render_button.click(
678
  # lambda *args: (
679
  # lambda x: (
 
687
  ],
688
  outputs=[
689
  audio_output,
690
+ direct_output,
691
+ wet_output,
692
  ],
693
  )
694
 
 
707
  lp.params.Q.item(),
708
  hp.params.freq.item(),
709
  hp.params.Q.item(),
710
+ comp.params.cmp_th.item(),
711
+ comp.params.cmp_ratio.item(),
712
+ comp.params.make_up.item(),
713
+ comp.params.exp_th.item(),
714
+ comp.params.exp_ratio.item(),
715
+ coef2ms(comp.params.at, 44100).item(),
716
+ coef2ms(comp.params.rt, 44100).item(),
717
  ]
718
  update_fx_outputs = [
719
  pk1_freq,
 
730
  lp_q,
731
  hp_freq,
732
  hp_q,
733
+ cmp_th,
734
+ cmp_ratio,
735
+ make_up,
736
+ exp_th,
737
+ exp_ratio,
738
+ attack_time,
739
+ release_time,
740
  ]
741
  update_plots = lambda: [
742
  plot_eq(),