yoyolicoris commited on
Commit
e8b53cc
·
1 Parent(s): 523c68d

feat: enhance chain_functions for improved function composition and readability

Browse files
Files changed (1) hide show
  1. app.py +101 -103
app.py CHANGED
@@ -7,15 +7,23 @@ import json
7
  import pyloudnorm as pyln
8
  from hydra.utils import instantiate
9
  from soxr import resample
10
- from functools import partial
11
  from torchcomp import coef2ms, ms2coef
12
  from copy import deepcopy
13
 
14
- from modules.utils import chain_functions, vec2statedict, get_chunks
15
  from modules.fx import clip_delay_eq_Q
16
  from plot_utils import get_log_mags_from_eq
17
 
18
 
 
 
 
 
 
 
 
 
19
  title_md = "# Vocal Effects Generator"
20
  description_md = """
21
  This is a demo of the paper [DiffVox: A Differentiable Model for Capturing and Analysing Professional Effects Distributions](https://arxiv.org/abs/2504.14735), accepted at DAFx 2025.
@@ -211,11 +219,11 @@ def plot_eq(fx):
211
  def plot_comp(fx):
212
  fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
213
  comp = fx[6]
214
- cmp_th = fx[6].params.cmp_th.item()
215
- exp_th = fx[6].params.exp_th.item()
216
- cmp_ratio = fx[6].params.cmp_ratio.item()
217
- exp_ratio = fx[6].params.exp_ratio.item()
218
- make_up = fx[6].params.make_up.item()
219
  # print(cmp_ratio, cmp_th, exp_ratio, exp_th, make_up)
220
 
221
  comp_in = np.linspace(-80, 0, 100)
@@ -246,13 +254,13 @@ def plot_comp(fx):
246
  def plot_delay(fx):
247
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
248
  delay = fx[7].effects[0]
249
- w, eq_log_mags = get_log_mags_from_eq([fx[7].effects[0].eq])
250
- log_gain = fx[7].effects[0].params.gain.log10().item() * 20
251
- d = fx[7].effects[0].params.delay.item() / 1000
252
  log_mag = sum(eq_log_mags)
253
  ax.plot(w, log_mag + log_gain, color="black", linestyle="-")
254
 
255
- log_feedback = fx[7].effects[0].params.feedback.log10().item() * 20
256
  for i in range(1, 10):
257
  feedback_log_mag = log_mag * (i + 1) + log_feedback * i + log_gain
258
  ax.plot(
@@ -631,11 +639,11 @@ with gr.Blocks() as demo:
631
  ]
632
 
633
  delay_sliders = [delay_time, feedback, delay_lp_freq, delay_gain, odd_pan, even_pan]
634
- delay_update_funcs = [update_param] * 3 + [
635
- lambda m, a, v: update_param(m, a, 10 ** (v / 20)),
636
- lambda m, a, v: update_param(m, a, (v + 100) / 200),
637
- lambda m, a, v: update_param(m, a, (v + 100) / 200),
638
- ]
639
  delay_attr_names = [
640
  "params.delay",
641
  "params.feedback",
@@ -669,20 +677,22 @@ with gr.Blocks() as demo:
669
 
670
  for idx, s, attr_name in zip(peq_indices, peq_sliders, peq_attr_names):
671
  s.input(
672
- lambda *args, idx=idx, attr_name=attr_name: chain_functions( # chain_functions(
673
- lambda args: (assign_fx_params(vec2fx(args[0]), *args[3:]), *args[1:3]),
674
- lambda args: (
675
- update_param(args[0][idx].params, attr_name, args[1]),
676
- args[0],
677
- args[2],
678
  ),
679
- lambda args: (fx2x(args[1]), *args[1:]),
680
- lambda args: [x2z(args[0]), *args],
681
- lambda args: args[:2]
682
- + [model2json(args[2]), plot_eq(args[2])]
683
- + update_pc(args[0], args[3]),
684
- )(
685
- args
 
 
 
686
  ),
687
  inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
688
  outputs=[z, fx_params, json_output, peq_plot] + update_pc_outputs,
@@ -690,19 +700,23 @@ with gr.Blocks() as demo:
690
 
691
  for f, s, attr_name in zip(cmp_update_funcs, cmp_sliders, cmp_attr_names):
692
  s.input(
693
- lambda *args, attr_name=attr_name, f=f: chain_functions(
694
- lambda args: (assign_fx_params(vec2fx(args[0]), *args[3:]), *args[1:3]),
695
- lambda args: (
696
- f(args[0][6].params, attr_name, args[1]),
697
- args[0],
698
- args[2],
 
 
 
 
699
  ),
700
- lambda args: (fx2x(args[1]), *args[1:]),
701
- lambda args: [x2z(args[0]), *args],
702
- lambda args: args[:2]
703
- + [model2json(args[2]), plot_comp(args[2])]
704
- + update_pc(args[0], args[3]),
705
- )(args),
706
  inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
707
  outputs=[z, fx_params, json_output, comp_plot] + update_pc_outputs,
708
  )
@@ -711,24 +725,25 @@ with gr.Blocks() as demo:
711
  delay_update_funcs, delay_sliders, delay_attr_names, delay_update_plot_flag
712
  ):
713
  s.input(
714
- lambda *args, f=f, attr_name=attr_name, update_plot=update_plot: chain_functions(
715
- lambda args: (assign_fx_params(vec2fx(args[0]), *args[3:]), *args[1:3]),
716
- lambda args: (
717
- # f(args[0][7].effects[0], attr_name, args[1]),
718
- f(*get_last_attribute(args[0][7].effects[0], attr_name), args[1]),
719
- args[0],
720
- args[2],
 
 
 
721
  ),
722
- lambda args: (fx2x(args[1]), *args[1:]),
723
- lambda args: [x2z(args[0]), *args],
724
- lambda args: (
725
- args[:2]
726
- + [model2json(args[2])]
727
- + ([plot_delay(args[2])] if update_plot else [])
728
- + update_pc(args[0], args[3])
729
  ),
730
- )(
731
- args
732
  ),
733
  inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
734
  outputs=[z, fx_params]
@@ -738,18 +753,10 @@ with gr.Blocks() as demo:
738
  )
739
 
740
  render_button.click(
741
- # lambda *args: (
742
- # lambda x: (
743
- # x,
744
- # model2json(),
745
- # )
746
- # )(inference(*args)),
747
- # inference,
748
- # lambda audio, x: inference(audio, vec2fx(x)),
749
- lambda audio, *args: chain_functions(
750
- lambda args: assign_fx_params(vec2fx(args[0]), *args[1:]),
751
- partial(inference, audio),
752
- )(args),
753
  inputs=[
754
  audio_input,
755
  fx_params,
@@ -841,22 +848,21 @@ with gr.Blocks() as demo:
841
  random_button.click(
842
  chain_functions(
843
  lambda i: (torch.randn_like(mean).clip(SLIDER_MIN, SLIDER_MAX), i),
844
- lambda args: (args[0], z2x(args[0]), args[1]),
845
- lambda args: [args[0], args[1], vec2fx(args[1]), args[2]],
846
- lambda args: update_all(args[0], args[2], args[3]) + args[:2],
847
  ),
848
  inputs=extra_pc_dropdown,
849
- outputs=update_all_outputs + [z, fx_params],
850
  )
851
  reset_button.click(
852
- # lambda: (lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)])(z.zero_()),
853
- lambda: chain_functions(
854
- lambda _: torch.zeros_like(mean),
855
  lambda z: (z, z2x(z)),
856
- lambda args: [*args[:2], vec2fx(args[1])],
857
- lambda args: update_all(args[0], args[2], NUMBER_OF_PCS) + args[:2],
858
- )(None),
859
- outputs=update_all_outputs + [z, fx_params],
860
  )
861
 
862
  def update_z(z, s, i):
@@ -865,36 +871,28 @@ with gr.Blocks() as demo:
865
 
866
  for i, slider in enumerate(sliders):
867
  slider.input(
868
- lambda *args, i=i: chain_functions(
869
- lambda args: update_z(args[0], args[1], i),
870
  lambda z: (z, z2x(z)),
871
- lambda args: [args[0], args[1], vec2fx(args[1])],
872
- lambda args: args[:2]
873
- + update_fx(args[2])
874
- + update_plots(args[2])
875
- + [model2json(args[2])],
876
- )(args),
877
  inputs=[z, slider],
878
- outputs=[z, fx_params]
879
  + update_fx_outputs
880
- + update_plots_outputs
881
- + [json_output],
882
  )
883
  extra_slider.input(
884
- lambda *xs: chain_functions(
885
- lambda args: update_z(args[0], args[1], args[2]),
886
  lambda z: (z, z2x(z)),
887
- lambda args: [args[0], args[1], vec2fx(args[1])],
888
- lambda args: args[:2]
889
- + update_fx(args[2])
890
- + update_plots(args[2])
891
- + [model2json(args[2])],
892
- )(xs),
893
  inputs=[z, extra_slider, extra_pc_dropdown],
894
- outputs=[z, fx_params]
895
- + update_fx_outputs
896
- + update_plots_outputs
897
- + [json_output],
898
  )
899
 
900
  extra_pc_dropdown.input(
 
7
  import pyloudnorm as pyln
8
  from hydra.utils import instantiate
9
  from soxr import resample
10
+ from functools import partial, reduce
11
  from torchcomp import coef2ms, ms2coef
12
  from copy import deepcopy
13
 
14
+ from modules.utils import vec2statedict, get_chunks
15
  from modules.fx import clip_delay_eq_Q
16
  from plot_utils import get_log_mags_from_eq
17
 
18
 
19
+ def chain_functions(*functions):
20
+ return lambda *initial_args: reduce(
21
+ lambda xs, f: f(*xs) if isinstance(xs, tuple) else f(xs),
22
+ functions,
23
+ initial_args,
24
+ )
25
+
26
+
27
  title_md = "# Vocal Effects Generator"
28
  description_md = """
29
  This is a demo of the paper [DiffVox: A Differentiable Model for Capturing and Analysing Professional Effects Distributions](https://arxiv.org/abs/2504.14735), accepted at DAFx 2025.
 
219
  def plot_comp(fx):
220
  fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
221
  comp = fx[6]
222
+ cmp_th = comp.params.cmp_th.item()
223
+ exp_th = comp.params.exp_th.item()
224
+ cmp_ratio = comp.params.cmp_ratio.item()
225
+ exp_ratio = comp.params.exp_ratio.item()
226
+ make_up = comp.params.make_up.item()
227
  # print(cmp_ratio, cmp_th, exp_ratio, exp_th, make_up)
228
 
229
  comp_in = np.linspace(-80, 0, 100)
 
254
  def plot_delay(fx):
255
  fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
256
  delay = fx[7].effects[0]
257
+ w, eq_log_mags = get_log_mags_from_eq([delay.eq])
258
+ log_gain = delay.params.gain.log10().item() * 20
259
+ d = delay.params.delay.item() / 1000
260
  log_mag = sum(eq_log_mags)
261
  ax.plot(w, log_mag + log_gain, color="black", linestyle="-")
262
 
263
+ log_feedback = delay.params.feedback.log10().item() * 20
264
  for i in range(1, 10):
265
  feedback_log_mag = log_mag * (i + 1) + log_feedback * i + log_gain
266
  ax.plot(
 
639
  ]
640
 
641
  delay_sliders = [delay_time, feedback, delay_lp_freq, delay_gain, odd_pan, even_pan]
642
+ delay_update_funcs = (
643
+ [update_param] * 3
644
+ + [lambda m, a, v: update_param(m, a, 10 ** (v / 20))]
645
+ + [lambda m, a, v: update_param(m, a, (v + 100) / 200)] * 2
646
+ )
647
  delay_attr_names = [
648
  "params.delay",
649
  "params.feedback",
 
677
 
678
  for idx, s, attr_name in zip(peq_indices, peq_sliders, peq_attr_names):
679
  s.input(
680
+ chain_functions(
681
+ lambda x, s, extra_pc_idx, *all_s: (
682
+ assign_fx_params(vec2fx(x), *all_s),
683
+ s,
684
+ extra_pc_idx,
 
685
  ),
686
+ lambda fx, s, extra_pc_idx, idx=idx, attr_name=attr_name: (
687
+ update_param(fx[idx].params, attr_name, s),
688
+ fx,
689
+ extra_pc_idx,
690
+ ),
691
+ lambda _, fx, extra_pc_idx: (fx2x(fx), fx, extra_pc_idx),
692
+ lambda x, fx, extra_pc_idx: (x2z(x), x, fx, extra_pc_idx),
693
+ lambda z, x, fx, extra_pc_idx: [z, x]
694
+ + [model2json(fx), plot_eq(fx)]
695
+ + update_pc(z, extra_pc_idx),
696
  ),
697
  inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
698
  outputs=[z, fx_params, json_output, peq_plot] + update_pc_outputs,
 
700
 
701
  for f, s, attr_name in zip(cmp_update_funcs, cmp_sliders, cmp_attr_names):
702
  s.input(
703
+ chain_functions(
704
+ lambda x, s, e_pc_i, *all_s: (
705
+ assign_fx_params(vec2fx(x), *all_s),
706
+ s,
707
+ e_pc_i,
708
+ ),
709
+ lambda fx, s, e_pc_i, attr_name=attr_name, f=f: (
710
+ f(fx[6].params, attr_name, s),
711
+ fx,
712
+ e_pc_i,
713
  ),
714
+ lambda _, fx, e_pc_i: (fx2x(fx), fx, e_pc_i),
715
+ lambda x, fx, e_pc_i: (x2z(x), x, fx, e_pc_i),
716
+ lambda z, x, fx, e_pc_i: [z, x]
717
+ + [model2json(fx), plot_comp(fx)]
718
+ + update_pc(z, e_pc_i),
719
+ ),
720
  inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
721
  outputs=[z, fx_params, json_output, comp_plot] + update_pc_outputs,
722
  )
 
725
  delay_update_funcs, delay_sliders, delay_attr_names, delay_update_plot_flag
726
  ):
727
  s.input(
728
+ chain_functions(
729
+ lambda x, s, e_pc_i, *all_s: (
730
+ assign_fx_params(vec2fx(x), *all_s),
731
+ s,
732
+ e_pc_i,
733
+ ),
734
+ lambda fx, s, e_pc_i, f=f, attr_name=attr_name: (
735
+ f(*get_last_attribute(fx[7].effects[0], attr_name), s),
736
+ fx,
737
+ e_pc_i,
738
  ),
739
+ lambda _, fx, e_pc_i: (fx2x(fx), fx, e_pc_i),
740
+ lambda x, fx, e_pc_i: (x2z(x), x, fx, e_pc_i),
741
+ lambda z, x, fx, e_pc_i, update_plot=update_plot: (
742
+ [z, x]
743
+ + [model2json(fx)]
744
+ + ([plot_delay(fx)] if update_plot else [])
745
+ + update_pc(z, e_pc_i)
746
  ),
 
 
747
  ),
748
  inputs=[fx_params, s, extra_pc_dropdown] + all_effect_sliders,
749
  outputs=[z, fx_params]
 
753
  )
754
 
755
  render_button.click(
756
+ chain_functions(
757
+ lambda audio, x, *all_s: (audio, assign_fx_params(vec2fx(x), *all_s)),
758
+ inference,
759
+ ),
 
 
 
 
 
 
 
 
760
  inputs=[
761
  audio_input,
762
  fx_params,
 
848
  random_button.click(
849
  chain_functions(
850
  lambda i: (torch.randn_like(mean).clip(SLIDER_MIN, SLIDER_MAX), i),
851
+ lambda z, i: (z, z2x(z), i),
852
+ lambda z, x, i: (z, x, vec2fx(x), i),
853
+ lambda z, x, fx, i: [z, x] + update_all(z, fx, i),
854
  ),
855
  inputs=extra_pc_dropdown,
856
+ outputs=[z, fx_params] + update_all_outputs,
857
  )
858
  reset_button.click(
859
+ chain_functions(
860
+ lambda: torch.zeros_like(mean),
 
861
  lambda z: (z, z2x(z)),
862
+ lambda z, x: (z, x, vec2fx(x)),
863
+ lambda z, x, fx: [z, x] + update_all(z, fx, NUMBER_OF_PCS),
864
+ ),
865
+ outputs=[z, fx_params] + update_all_outputs,
866
  )
867
 
868
  def update_z(z, s, i):
 
871
 
872
  for i, slider in enumerate(sliders):
873
  slider.input(
874
+ chain_functions(
875
+ lambda z, s, i=i: update_z(z, s, i),
876
  lambda z: (z, z2x(z)),
877
+ lambda z, x: (z, x, vec2fx(x)),
878
+ lambda z, x, fx: [z, x, model2json(fx)]
879
+ + update_fx(fx)
880
+ + update_plots(fx),
881
+ ),
 
882
  inputs=[z, slider],
883
+ outputs=[z, fx_params, json_output]
884
  + update_fx_outputs
885
+ + update_plots_outputs,
 
886
  )
887
  extra_slider.input(
888
+ chain_functions(
889
+ lambda z, s, i: update_z(z, s, i - 1),
890
  lambda z: (z, z2x(z)),
891
+ lambda z, x: (z, x, vec2fx(x)),
892
+ lambda z, x, fx: [z, x, model2json(fx)] + update_fx(fx) + update_plots(fx),
893
+ ),
 
 
 
894
  inputs=[z, extra_slider, extra_pc_dropdown],
895
+ outputs=[z, fx_params, json_output] + update_fx_outputs + update_plots_outputs,
 
 
 
896
  )
897
 
898
  extra_pc_dropdown.input(