import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import torch
import yaml
import json
import pyloudnorm as pyln
from hydra.utils import instantiate
from soxr import resample
from functools import partial

from modules.utils import chain_functions, vec2statedict, get_chunks
from modules.fx import clip_delay_eq_Q
from plot_utils import get_log_mags_from_eq
title_md = "# Vocal Effects Generator"
description_md = """
This is a demo of the paper [DiffVox: A Differentiable Model for Capturing and Analysing Professional Effects Distributions](https://arxiv.org/abs/2504.14735), accepted at DAFx 2025.
In this demo, you can upload a raw vocal audio file (in mono) and apply random effects to make it sound better!
The effect chain consists of a series of EQ, compressor, delay, and reverb.
The generator is a PCA model derived from 365 vocal effects presets fitted with the same effects chain.
This interface allows you to control the principal components (PCs) of the generator, randomise them, and render the audio.
To give you some idea, we empirically found that the first PC controls the amount of reverb and the second PC controls the amount of brightness.
Note that adding these PCs together does not necessarily mean that their effects are additive in the final audio.
We found that the effects of the less important PCs are sometimes more perceptible.
Play around with the sliders and buttons and see what you can come up with!
Currently only the PCs are tweakable, but in the future we will add more controls and visualisation tools, for example:

- Directly controlling the parameters of the effects
- Visualising the PCA space
- Visualising the frequency responses/dynamic curves of the effects
"""
SLIDER_MAX = 3
SLIDER_MIN = -3
NUMBER_OF_PCS = 10
TEMPERATURE = 0.7
CONFIG_PATH = "presets/rt_config.yaml"
PCA_PARAM_FILE = "presets/internal/gaussian.npz"
INFO_PATH = "presets/internal/info.json"
MASK_PATH = "presets/internal/feature_mask.npy"
with open(CONFIG_PATH) as fp:
    fx_config = yaml.safe_load(fp)["model"]

# Global effect chain
fx = instantiate(fx_config)
fx.eval()

pca_params = np.load(PCA_PARAM_FILE)
mean = pca_params["mean"]
cov = pca_params["cov"]
# eigh returns eigenvalues in ascending order, so flip both eigenvalues and
# eigenvectors to descending order and keep the 75 most important components
eigvals, eigvecs = np.linalg.eigh(cov)
eigvals = np.flip(eigvals, axis=0)[:75]
eigvecs = np.flip(eigvecs, axis=1)[:, :75]
# Scale each eigenvector by its standard deviation so a standard-normal z
# decodes to a parameter vector distributed as N(mean, cov)
U = eigvecs * np.sqrt(eigvals)
U = torch.from_numpy(U).float()
mean = torch.from_numpy(mean).float()
feature_mask = torch.from_numpy(np.load(MASK_PATH))
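
# Since the columns of U are eigenvectors scaled by their standard deviations,
# decoding x = U @ z + mean with z ~ N(0, I) draws a parameter vector from the
# fitted Gaussian. A sketch of sampling a random preset this way (not executed
# by the app; vec2dict is defined below, and scaling by TEMPERATURE to stay
# closer to the mean is an assumption, not something the demo currently does):
#   z_sample = torch.randn(75) * TEMPERATURE
#   fx.load_state_dict(vec2dict(U @ z_sample + mean), strict=False)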
# Global latent variable
z = torch.zeros(75)

with open(INFO_PATH) as f:
    info = json.load(f)
param_keys = info["params_keys"]
original_shapes = list(
    map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
)
*vec2dict_args, _ = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args
# Map a flattened parameter vector back to an fx state dict
vec2dict = partial(
    vec2statedict,
    **dict(
        zip(
            [
                "keys",
                "original_shapes",
                "selected_chunks",
                "position",
                "U_matrix_shape",
            ],
            vec2dict_args,
        )
    ),
)
fx.load_state_dict(vec2dict(mean), strict=False)

meter = pyln.Meter(44100)
def z2fx():
    """Decode the global latent z into the parameters of the global fx chain."""
    # close all figures to avoid too many open figures
    plt.close("all")
    x = U @ z + mean
    fx.load_state_dict(vec2dict(x), strict=False)


def fx2z(func):
    """Decorator: after func edits fx directly, re-project its state onto z."""

    def wrapper(*args, **kwargs):
        ret = func(*args, **kwargs)
        state_dict = fx.state_dict()
        flattened = torch.cat([state_dict[k].flatten() for k in param_keys])
        x = flattened[feature_mask]
        z.copy_(U.T @ (x - mean))
        return ret

    return wrapper
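
# fx2z is the counterpart of z2fx: it wraps a function that edits fx directly
# and re-projects the resulting state back onto z. It is not wired up in this
# demo; a hypothetical use would look like:
#   @fx2z
#   def halve_cross_send():
#       fx[7].params.sends_0.data.mul_(0.5)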
@torch.no_grad()
def inference(audio):
    sr, y = audio
    if sr != 44100:
        y = resample(y, sr, 44100)
    # Convert integer PCM to float in [-1, 1]
    if y.dtype.kind != "f":
        y = y / 32768.0
    if y.ndim == 1:
        y = y[:, None]
    # Normalise the input to -18 LUFS before rendering
    loudness = meter.integrated_loudness(y)
    y = pyln.normalize.loudness(y, loudness, -18.0)
    y = torch.from_numpy(y).float().T.unsqueeze(0)
    # Downmix to mono if necessary
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)
    fx.apply(partial(clip_delay_eq_Q, Q=0.707))
    rendered = fx(y).squeeze(0).T.numpy()
    # Peak-normalise to avoid clipping, then convert back to 16-bit PCM
    if np.max(np.abs(rendered)) > 1:
        rendered = rendered / np.max(np.abs(rendered))
    return (44100, (rendered * 32768).astype(np.int16))
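
# Example call (sketch): gr.Audio with type="numpy" passes a
# (sample_rate, array) tuple, so given a hypothetical int16 mono
# recording `y` at 48 kHz:
#   out_sr, out = inference((48000, y))  # out_sr == 44100, out is int16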
def get_important_pcs(n=10, **kwargs):
    sliders = [
        gr.Slider(minimum=SLIDER_MIN, maximum=SLIDER_MAX, label=f"PC {i}", **kwargs)
        for i in range(1, n + 1)
    ]
    return sliders
def model2json():
    fx_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
    results = {k: v.toJSON() for k, v in zip(fx_names, fx)} | {
        "Panner": fx[7].pan.toJSON()
    }
    spatial_fx = {
        "DLY": fx[7].effects[0].toJSON() | {"LP": fx[7].effects[0].eq.toJSON()},
        "FDN": fx[7].effects[1].toJSON()
        | {
            "Tone correction PEQ": {
                k: v.toJSON() for k, v in zip(fx_names[:4], fx[7].effects[1].eq)
            }
        },
        "Cross Send (dB)": fx[7].params.sends_0.log10().mul(20).item(),
    }
    return json.dumps(
        {
            "Direct": results,
            "Sends": spatial_fx,
        }
    )
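
# The returned JSON has roughly this shape (sketch, values elided):
#   {
#       "Direct": {"PK1": {...}, ..., "DRC": {...}, "Panner": {...}},
#       "Sends": {"DLY": {...}, "FDN": {...}, "Cross Send (dB)": ...}
#   }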
def plot_eq():
    fig, ax = plt.subplots(figsize=(8, 4))
    w, eq_log_mags = get_log_mags_from_eq(fx[:6])
    # Overall response in black, individual bands in translucent grey
    ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
    for eq_log_mag in eq_log_mags:
        ax.plot(w, eq_log_mag, "k-", alpha=0.3)
        ax.fill_between(w, eq_log_mag, 0, facecolor="gray", edgecolor="none", alpha=0.1)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xscale("log")
    ax.grid()
    return fig
with gr.Blocks() as demo:
    gr.Markdown(
        title_md,
        elem_id="title",
    )
    with gr.Row():
        gr.Markdown(
            description_md,
            elem_id="description",
        )
        gr.Image("diffvox_diagram.png", elem_id="diagram")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                type="numpy", sources="upload", label="Input Audio", loop=True
            )
            with gr.Row():
                random_button = gr.Button(
                    "Randomise PCs",
                    elem_id="randomise-button",
                )
                reset_button = gr.Button(
                    "Reset",
                    elem_id="reset-button",
                )
                render_button = gr.Button(
                    "Run", elem_id="render-button", variant="primary"
                )
            sliders = get_important_pcs(NUMBER_OF_PCS, value=0)
            extra_pc_dropdown = gr.Dropdown(
                list(range(NUMBER_OF_PCS + 1, 76)),
                label=f"PC > {NUMBER_OF_PCS}",
                info="Select which extra PC to adjust",
                interactive=True,
            )
            extra_slider = gr.Slider(
                minimum=SLIDER_MIN,
                maximum=SLIDER_MAX,
                label="Extra PC",
                value=0,
            )
        with gr.Column():
            audio_output = gr.Audio(
                type="numpy", label="Output Audio", interactive=False, loop=True
            )
            peq_plot = gr.Plot(
                plot_eq(), label="PEQ Frequency Response", elem_id="peq-plot"
            )
    with gr.Row():
        json_output = gr.JSON(label="Effect Settings", max_height=800, open=True)

    render_button.click(
        # Tuple elements evaluate left to right, so the audio is rendered
        # before the settings and plot are read back from fx
        lambda audio: (inference(audio), model2json(), plot_eq()),
        inputs=[audio_input],
        outputs=[audio_output, json_output, peq_plot],
    )
    random_button.click(
        chain_functions(
            lambda i: (z.normal_(0, 1).clip_(SLIDER_MIN, SLIDER_MAX), i),
            lambda args: args + (z2fx(),),
            lambda args: args[0][:NUMBER_OF_PCS].tolist()
            + [args[0][args[1] - 1].item(), plot_eq()],
        ),
        inputs=extra_pc_dropdown,
        outputs=sliders + [extra_slider, peq_plot],
    )
    reset_button.click(
        lambda: chain_functions(
            lambda _: z.zero_(),
            lambda _: z2fx(),
            lambda _: [0 for _ in range(NUMBER_OF_PCS + 1)] + [plot_eq()],
        )(None),
        outputs=sliders + [extra_slider, peq_plot],
    )
    def update_z(s, i):
        z[i] = s

    for i, slider in enumerate(sliders):
        slider.input(
            chain_functions(
                partial(update_z, i=i),
                lambda _: z2fx(),
                lambda _: plot_eq(),
            ),
            inputs=slider,
            outputs=peq_plot,
        )

    extra_slider.input(
        lambda *xs: chain_functions(
            lambda args: update_z(args[0], args[1] - 1),
            lambda _: z2fx(),
            lambda _: plot_eq(),
        )(xs),
        inputs=[extra_slider, extra_pc_dropdown],
        outputs=peq_plot,
    )
    extra_pc_dropdown.input(
        lambda i: z[i - 1].item(),
        inputs=extra_pc_dropdown,
        outputs=extra_slider,
    )

demo.launch()