# (Hugging Face Spaces page chrome — "Spaces: Running" — captured by the
#  copy/paste; kept as a comment so the file parses as Python.)
import json
from functools import partial
from random import normalvariate

import gradio as gr
import numpy as np
import pyloudnorm as pyln
import torch
import yaml
from hydra.utils import instantiate
from soxr import resample

from src.modules.fx import clip_delay_eq_Q
from src.modules.utils import chain_functions, get_chunks, vec2statedict
# UI / sampling configuration.
SLIDER_MAX = 3  # upper bound of each principal-component slider
SLIDER_MIN = -3  # lower bound of each principal-component slider
NUMBER_OF_PCS = 10  # number of PCs exposed as sliders in the UI
TEMPERATURE = 0.7  # std-dev scale used when randomising the remaining PCs
CONFIG_PATH = "src/presets/rt_config.yaml"  # effect-chain (model) config
PCA_PARAM_FILE = "src/presets/internal/gaussian.npz"  # fitted mean/covariance
INFO_PATH = "src/presets/internal/info.json"  # state-dict key/shape metadata
# Load only the effect-chain ("model") section of the run-time config.
with open(CONFIG_PATH) as fp:
    fx_config = yaml.safe_load(fp)["model"]
def appendsrc(d):
    """Recursively prefix hydra ``_target_`` paths with ``src.``.

    The preset config was written with module paths relative to ``src/``
    (e.g. ``modules.fx.EQ``); rewrite every ``_target_`` value that starts
    with ``"modules."`` to ``"src.modules..."`` so hydra can import it.

    Dicts and lists are traversed recursively; any other value is returned
    unchanged.
    """
    if isinstance(d, dict):
        return {
            k: (
                f"src.{v}"
                if (k == "_target_" and v.startswith("modules."))
                else appendsrc(v)
            )
            for k, v in d.items()
        }
    if isinstance(d, list):
        return [appendsrc(item) for item in d]
    return d
fx_config = appendsrc(fx_config)  # type: ignore
fx = instantiate(fx_config)  # build the effect chain from the config
fx.eval()  # inference mode (disables dropout/batch-norm training behaviour)
# Load the Gaussian fitted over the flattened effect parameters and keep the
# top 75 principal directions.  `eigh` returns eigenvalues in ascending
# order, hence the flips before truncation.
pca_params = np.load(PCA_PARAM_FILE)
mean = pca_params["mean"]
cov = pca_params["cov"]
eigvals, eigvecs = np.linalg.eigh(cov)
eigvals = np.flip(eigvals, axis=0)[:75]
eigvecs = np.flip(eigvecs, axis=1)[:, :75]
U = eigvecs * np.sqrt(eigvals)  # scale each direction by its std-dev
U = torch.from_numpy(U).float()
mean = torch.from_numpy(mean).float()
# Metadata needed to unflatten a parameter vector back into an fx state dict.
with open(INFO_PATH) as f:
    info = json.load(f)
param_keys = info["params_keys"]
# Scalar parameters are stored with an empty shape list; treat them as [1].
original_shapes = list(
    map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
)
# get_chunks returns the chunking info vec2statedict needs (its last element
# is discarded here).
*vec2dict_args, _ = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args
vec2dict = partial(
    vec2statedict,
    **dict(
        zip(
            [
                "keys",
                "original_shapes",
                "selected_chunks",
                "position",
                "U_matrix_shape",
            ],
            vec2dict_args,
        )
    ),
)
# EBU R128 loudness meter at the fixed processing rate of 44.1 kHz.
meter = pyln.Meter(44100)
def inference(audio, randomise_rest, *pcs):
    """Render `audio` through the effect chain at a point in PCA space.

    Args:
        audio: `(sample_rate, samples)` tuple as produced by `gr.Audio`.
        randomise_rest: if True, PCs beyond the sliders are drawn from
            N(0, TEMPERATURE^2) instead of being zeroed.
        *pcs: slider values for the first `len(pcs)` principal components.

    Returns:
        `(44100, int16_samples)` tuple for the output `gr.Audio`.
    """
    sr, y = audio
    if sr != 44100:
        y = resample(y, sr, 44100)
    if y.dtype.kind != "f":
        # Scale integer PCM to [-1, 1) using the actual sample width.
        # (The original hard-coded 32768, which is only correct for int16.)
        y = y / (np.iinfo(y.dtype).max + 1)
    if y.ndim == 1:
        y = y[:, None]

    # Normalise programme loudness to -18 LUFS before rendering.
    loudness = meter.integrated_loudness(y)
    y = pyln.normalize.loudness(y, loudness, -18.0)

    y = torch.from_numpy(y).float().T.unsqueeze(0)  # -> (1, channels, samples)
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)  # downmix to mono

    # First len(pcs) coordinates come from the sliders; the remaining ones are
    # sampled or zeroed, then mapped to parameter space via x = U @ z + mean.
    M = eigvals.shape[0]
    z = torch.cat(
        [
            torch.tensor([float(x) for x in pcs]),
            (
                torch.randn(M - len(pcs)) * TEMPERATURE
                if randomise_rest
                else torch.zeros(M - len(pcs))
            ),
        ]
    )
    x = U @ z + mean
    fx.load_state_dict(vec2dict(x), strict=False)
    fx.apply(partial(clip_delay_eq_Q, Q=0.707))  # keep the delay EQ stable

    # no_grad: parameters loaded via load_state_dict require grad by default,
    # and calling .numpy() on a grad-tracking tensor raises.
    with torch.no_grad():
        rendered = fx(y).squeeze(0).T.numpy()

    # Peak-normalise only when clipping, then convert to int16 without
    # overflow (int16 max is 32767, not 32768).
    peak = np.max(np.abs(rendered))
    if peak > 1:
        rendered = rendered / peak
    return (44100, (rendered * 32767).astype(np.int16))
def get_important_pcs(n=10, **kwargs):
    """Create `n` sliders labelled "PC 1" … "PC n" bounded by SLIDER_MIN/MAX.

    Extra keyword arguments are forwarded to every `gr.Slider`.
    """
    return [
        gr.Slider(minimum=SLIDER_MIN, maximum=SLIDER_MAX, label=f"PC {i}", **kwargs)
        for i in range(1, n + 1)
    ]
# ---------------------------------------------------------------------------
# Gradio UI: audio in, PC sliders, rendered audio out.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Hadamard Transform
        This is a demo of the Hadamard transform.
        """
    )
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="numpy", sources="upload", label="Input Audio")
            with gr.Row():
                random_button = gr.Button(
                    f"Randomise the first {NUMBER_OF_PCS} PCs",
                    elem_id="randomise-button",
                )
                reset_button = gr.Button(
                    "Reset",
                    elem_id="reset-button",
                )
                render_button = gr.Button(
                    "Run", elem_id="render-button", variant="primary"
                )
            random_rest_checkbox = gr.Checkbox(
                label=f"Randomise PCs > {NUMBER_OF_PCS} (default to zeros)",
                value=False,
                elem_id="randomise-checkbox",
            )
            sliders = get_important_pcs(NUMBER_OF_PCS, value=0)
        with gr.Column():
            audio_output = gr.Audio(
                type="numpy", label="Output Audio", interactive=False
            )

    # Render: the input audio, the randomise flag, then all slider values
    # (matching inference's *pcs signature).
    render_button.click(
        inference,
        inputs=[
            audio_input,
            random_rest_checkbox,
        ]
        + sliders,
        outputs=audio_output,
    )
    # Randomise: draw each slider from N(0, 1), clamped to the slider range.
    random_button.click(
        lambda *xs: [
            chain_functions(
                partial(max, SLIDER_MIN),
                partial(min, SLIDER_MAX),
            )(normalvariate(0, 1))
            for _ in range(len(xs))
        ],
        inputs=sliders,
        outputs=sliders,
    )
    # Reset: zero every slider.
    reset_button.click(
        lambda *xs: [0 for _ in range(len(xs))],
        inputs=sliders,
        outputs=sliders,
    )

demo.launch()