yoyolicoris commited on
Commit
1291f86
·
1 Parent(s): 96bd919

feat: add plotting functions for compressor, delay, reverb, and T60 responses

Browse files
Files changed (1) hide show
  1. app.py +201 -14
app.py CHANGED
@@ -103,6 +103,7 @@ def z2fx():
103
  x = U @ z + mean
104
  # print(z)
105
  fx.load_state_dict(vec2dict(x), strict=False)
 
106
  return
107
 
108
 
@@ -136,8 +137,6 @@ def inference(audio):
136
  if y.shape[1] != 1:
137
  y = y.mean(dim=1, keepdim=True)
138
 
139
- fx.apply(partial(clip_delay_eq_Q, Q=0.707))
140
-
141
  rendered = fx(y).squeeze(0).T.numpy()
142
  if np.max(np.abs(rendered)) > 1:
143
  rendered = rendered / np.max(np.abs(rendered))
@@ -177,7 +176,7 @@ def model2json():
177
 
178
  @torch.no_grad()
179
  def plot_eq():
180
- fig, ax = plt.subplots(figsize=(8, 4))
181
  w, eq_log_mags = get_log_mags_from_eq(fx[:6])
182
  ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
183
  for i, eq_log_mag in enumerate(eq_log_mags):
@@ -192,6 +191,109 @@ def plot_eq():
192
  return fig
193
 
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  with gr.Blocks() as demo:
196
  gr.Markdown(
197
  title_md,
@@ -249,19 +351,49 @@ with gr.Blocks() as demo:
249
  peq_plot = gr.Plot(
250
  plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot"
251
  )
 
 
 
 
 
 
 
 
 
 
252
 
253
  with gr.Row():
254
- json_output = gr.JSON(label="Effect Settings", max_height=800, open=True)
 
 
255
 
256
  render_button.click(
257
- lambda *args: (lambda x: (x, model2json(), plot_eq()))(inference(*args)),
 
 
 
 
 
 
 
 
 
 
258
  inputs=[
259
  audio_input,
260
  # random_rest_checkbox,
261
  ]
262
  # + sliders,
263
  ,
264
- outputs=[audio_output, json_output, peq_plot],
 
 
 
 
 
 
 
 
265
  )
266
 
267
  random_button.click(
@@ -279,20 +411,54 @@ with gr.Blocks() as demo:
279
  lambda i: (z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX), i),
280
  lambda args: args + (z2fx(),),
281
  lambda args: args[0][:NUMBER_OF_PCS].tolist()
282
- + [args[0][args[1] - 1].item(), plot_eq()],
 
 
 
 
 
 
 
 
283
  ),
284
  inputs=extra_pc_dropdown,
285
- outputs=sliders + [extra_slider, peq_plot],
 
 
 
 
 
 
 
 
 
286
  )
287
  reset_button.click(
288
  # lambda: (lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)])(z.zero_()),
289
  lambda: chain_functions(
290
  lambda _: z.zero_(),
291
  lambda _: z2fx(),
292
- lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)] + [plot_eq()],
 
 
 
 
 
 
 
 
293
  )(None),
294
  # inputs=sliders + [extra_slider],
295
- outputs=sliders + [extra_slider, peq_plot],
 
 
 
 
 
 
 
 
 
296
  )
297
 
298
  def update_z(s, i):
@@ -304,19 +470,40 @@ with gr.Blocks() as demo:
304
  chain_functions(
305
  partial(update_z, i=i),
306
  lambda _: z2fx(),
307
- lambda _: plot_eq(),
 
 
 
 
 
 
 
308
  ),
309
  inputs=slider,
310
- outputs=peq_plot,
 
 
 
 
 
 
 
311
  )
312
  extra_slider.input(
313
  lambda *xs: chain_functions(
314
  lambda args: update_z(args[0], args[1] - 1),
315
  lambda _: z2fx(),
316
- lambda _: plot_eq(),
 
 
 
 
 
 
 
317
  )(xs),
318
  inputs=[extra_slider, extra_pc_dropdown],
319
- outputs=peq_plot,
320
  )
321
 
322
  extra_pc_dropdown.input(
 
103
  x = U @ z + mean
104
  # print(z)
105
  fx.load_state_dict(vec2dict(x), strict=False)
106
+ fx.apply(partial(clip_delay_eq_Q, Q=0.707))
107
  return
108
 
109
 
 
137
  if y.shape[1] != 1:
138
  y = y.mean(dim=1, keepdim=True)
139
 
 
 
140
  rendered = fx(y).squeeze(0).T.numpy()
141
  if np.max(np.abs(rendered)) > 1:
142
  rendered = rendered / np.max(np.abs(rendered))
 
176
 
177
  @torch.no_grad()
178
  def plot_eq():
179
+ fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
180
  w, eq_log_mags = get_log_mags_from_eq(fx[:6])
181
  ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
182
  for i, eq_log_mag in enumerate(eq_log_mags):
 
191
  return fig
192
 
193
 
194
+ @torch.no_grad()
195
+ def plot_comp():
196
+ fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
197
+ comp = fx[6]
198
+ cmp_th = comp.params.cmp_th.item()
199
+ exp_th = comp.params.exp_th.item()
200
+ cmp_ratio = comp.params.cmp_ratio.item()
201
+ exp_ratio = comp.params.exp_ratio.item()
202
+ make_up = comp.params.make_up.item()
203
+ # print(cmp_ratio, cmp_th, exp_ratio, exp_th, make_up)
204
+
205
+ comp_in = np.linspace(-80, 0, 100)
206
+ comp_curve = np.where(
207
+ comp_in > cmp_th,
208
+ comp_in - (comp_in - cmp_th) * (cmp_ratio - 1) / cmp_ratio,
209
+ comp_in,
210
+ )
211
+ comp_out = (
212
+ np.where(
213
+ comp_curve < exp_th,
214
+ comp_curve - (exp_th - comp_curve) / exp_ratio,
215
+ comp_curve,
216
+ )
217
+ + make_up
218
+ )
219
+ ax.plot(comp_in, comp_out, c="black", linestyle="-")
220
+ ax.plot(comp_in, comp_in, c="r", alpha=0.5)
221
+ ax.set_xlabel("Input Level (dB)")
222
+ ax.set_ylabel("Output Level (dB)")
223
+ ax.set_xlim(-80, 0)
224
+ ax.set_ylim(-80, 0)
225
+ ax.grid()
226
+ return fig
227
+
228
+
229
+ @torch.no_grad()
230
+ def plot_delay():
231
+ fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
232
+ delay = fx[7].effects[0]
233
+ w, eq_log_mags = get_log_mags_from_eq([delay.eq])
234
+ log_gain = delay.params.gain.log10().item() * 20
235
+ d = delay.params.delay.item() / 1000
236
+ log_mag = sum(eq_log_mags)
237
+ ax.plot(w, log_mag + log_gain, color="black", linestyle="-")
238
+
239
+ log_feedback = delay.params.feedback.log10().item() * 20
240
+ for i in range(1, 10):
241
+ feedback_log_mag = log_mag * (i + 1) + log_feedback * i + log_gain
242
+ ax.plot(
243
+ w,
244
+ feedback_log_mag,
245
+ c="black",
246
+ alpha=max(0, (10 - i * d * 4) / 10),
247
+ linestyle="-",
248
+ )
249
+
250
+ ax.set_xscale("log")
251
+ ax.set_xlim(20, 20000)
252
+ ax.set_ylim(-80, 0)
253
+ ax.set_xlabel("Frequency (Hz)")
254
+ ax.set_ylabel("Magnitude (dB)")
255
+ ax.grid()
256
+ return fig
257
+
258
+
259
+ @torch.no_grad()
260
+ def plot_reverb():
261
+ fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
262
+ fdn = fx[7].effects[1]
263
+ w, eq_log_mags = get_log_mags_from_eq(fdn.eq)
264
+
265
+ bc = fdn.params.c.norm() * fdn.params.b.norm()
266
+ log_bc = torch.log10(bc).item() * 20
267
+ eq_log_mags = [x + log_bc / len(eq_log_mags) for x in eq_log_mags]
268
+ ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
269
+
270
+ ax.set_xlabel("Frequency (Hz)")
271
+ ax.set_ylabel("Magnitude (dB)")
272
+ ax.set_xlim(20, 20000)
273
+ ax.set_ylim(-40, 6)
274
+ ax.set_xscale("log")
275
+ ax.grid()
276
+ return fig
277
+
278
+
279
+ @torch.no_grad()
280
+ def plot_t60():
281
+ fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
282
+ fdn = fx[7].effects[1]
283
+ gamma = fdn.params.gamma.squeeze().numpy()
284
+ delays = fdn.delays.numpy()
285
+ w = np.linspace(0, 22050, gamma.size)
286
+ t60 = -60 / (20 * np.log10(gamma) / np.min(delays)) / 44100
287
+ ax.plot(w, t60, color="black", linestyle="-")
288
+ ax.set_xlabel("Frequency (Hz)")
289
+ ax.set_ylabel("T60 (s)")
290
+ ax.set_xlim(20, 20000)
291
+ ax.set_ylim(0, 9)
292
+ ax.set_xscale("log")
293
+ ax.grid()
294
+ return fig
295
+
296
+
297
  with gr.Blocks() as demo:
298
  gr.Markdown(
299
  title_md,
 
351
  peq_plot = gr.Plot(
352
  plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot"
353
  )
354
+ comp_plot = gr.Plot(
355
+ plot_comp(), label="Compressor Curve", elem_id="comp-plot"
356
+ )
357
+ delay_plot = gr.Plot(
358
+ plot_delay(), label="Delay Frequency Response", elem_id="delay-plot"
359
+ )
360
+ reverb_plot = gr.Plot(
361
+ plot_reverb(), label="Reverb Tone Correction PEQ", elem_id="reverb-plot"
362
+ )
363
+ t60_plot = gr.Plot(plot_t60(), label="Reverb T60", elem_id="t60-plot")
364
 
365
  with gr.Row():
366
+ json_output = gr.JSON(
367
+ model2json(), label="Effect Settings", max_height=800, open=True
368
+ )
369
 
370
  render_button.click(
371
+ lambda *args: (
372
+ lambda x: (
373
+ x,
374
+ model2json(),
375
+ plot_eq(),
376
+ plot_comp(),
377
+ plot_delay(),
378
+ plot_reverb(),
379
+ plot_t60(),
380
+ )
381
+ )(inference(*args)),
382
  inputs=[
383
  audio_input,
384
  # random_rest_checkbox,
385
  ]
386
  # + sliders,
387
  ,
388
+ outputs=[
389
+ audio_output,
390
+ json_output,
391
+ peq_plot,
392
+ comp_plot,
393
+ delay_plot,
394
+ reverb_plot,
395
+ t60_plot,
396
+ ],
397
  )
398
 
399
  random_button.click(
 
411
  lambda i: (z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX), i),
412
  lambda args: args + (z2fx(),),
413
  lambda args: args[0][:NUMBER_OF_PCS].tolist()
414
+ + [
415
+ args[0][args[1] - 1].item(),
416
+ model2json(),
417
+ plot_eq(),
418
+ plot_comp(),
419
+ plot_delay(),
420
+ plot_reverb(),
421
+ plot_t60(),
422
+ ],
423
  ),
424
  inputs=extra_pc_dropdown,
425
+ outputs=sliders
426
+ + [
427
+ extra_slider,
428
+ json_output,
429
+ peq_plot,
430
+ comp_plot,
431
+ delay_plot,
432
+ reverb_plot,
433
+ t60_plot,
434
+ ],
435
  )
436
  reset_button.click(
437
  # lambda: (lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)])(z.zero_()),
438
  lambda: chain_functions(
439
  lambda _: z.zero_(),
440
  lambda _: z2fx(),
441
+ lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)]
442
+ + [
443
+ model2json(),
444
+ plot_eq(),
445
+ plot_comp(),
446
+ plot_delay(),
447
+ plot_reverb(),
448
+ plot_t60(),
449
+ ],
450
  )(None),
451
  # inputs=sliders + [extra_slider],
452
+ outputs=sliders
453
+ + [
454
+ extra_slider,
455
+ json_output,
456
+ peq_plot,
457
+ comp_plot,
458
+ delay_plot,
459
+ reverb_plot,
460
+ t60_plot,
461
+ ],
462
  )
463
 
464
  def update_z(s, i):
 
470
  chain_functions(
471
  partial(update_z, i=i),
472
  lambda _: z2fx(),
473
+ lambda _: (
474
+ model2json(),
475
+ plot_eq(),
476
+ plot_comp(),
477
+ plot_delay(),
478
+ plot_reverb(),
479
+ plot_t60(),
480
+ ),
481
  ),
482
  inputs=slider,
483
+ outputs=[
484
+ json_output,
485
+ peq_plot,
486
+ comp_plot,
487
+ delay_plot,
488
+ reverb_plot,
489
+ t60_plot,
490
+ ],
491
  )
492
  extra_slider.input(
493
  lambda *xs: chain_functions(
494
  lambda args: update_z(args[0], args[1] - 1),
495
  lambda _: z2fx(),
496
+ lambda _: (
497
+ model2json(),
498
+ plot_eq(),
499
+ plot_comp(),
500
+ plot_delay(),
501
+ plot_reverb(),
502
+ plot_t60(),
503
+ ),
504
  )(xs),
505
  inputs=[extra_slider, extra_pc_dropdown],
506
+ outputs=[json_output, peq_plot, comp_plot, delay_plot, reverb_plot, t60_plot],
507
  )
508
 
509
  extra_pc_dropdown.input(