diffvox / app.py
yoyolicoris's picture
Implement initial version of demo website
3044e63
raw
history blame
5.1 kB
import gradio as gr
import numpy as np
import torch
import yaml
import json
import pyloudnorm as pyln
from hydra.utils import instantiate
from random import normalvariate
from soxr import resample
from functools import partial
from src.modules.utils import chain_functions, vec2statedict, get_chunks
from src.modules.fx import clip_delay_eq_Q
# --- Slider / sampling settings ---------------------------------------------
# Lower and upper bound shared by every principal-component (PC) slider.
SLIDER_MIN, SLIDER_MAX = -3, 3
# Number of leading PCs that get their own slider.
NUMBER_OF_PCS = 10
# Multiplier applied to the random values drawn for PCs without a slider.
TEMPERATURE = 0.7

# --- Preset files (relative paths) ------------------------------------------
PCA_PARAM_FILE = "src/presets/internal/gaussian.npz"
INFO_PATH = "src/presets/internal/info.json"
CONFIG_PATH = "src/presets/rt_config.yaml"
# Read the model description stored under the "model" key of the YAML preset.
# YAML is read as UTF-8 explicitly so the result does not depend on the locale.
with open(CONFIG_PATH, encoding="utf-8") as fp:
    fx_config = yaml.safe_load(fp)["model"]


def appendsrc(d):
    """Return a copy of ``d`` whose ``modules.*`` targets are prefixed with ``src.``.

    Dicts and lists are walked recursively; any other value is returned as is.
    Only the value of a ``"_target_"`` key that starts with ``"modules."`` is
    rewritten (``"modules.x"`` becomes ``"src.modules.x"``), matching the
    ``src.modules`` package this file imports from.

    Args:
        d: a (possibly nested) structure of dicts, lists and leaf values.

    Returns:
        A new structure of the same shape with the targets rewritten.
    """
    if isinstance(d, dict):
        return {
            k: (
                f"src.{v}"
                if k == "_target_" and v.startswith("modules.")
                else appendsrc(v)
            )
            for k, v in d.items()
        }
    if isinstance(d, list):
        return [appendsrc(item) for item in d]
    return d


fx_config = appendsrc(fx_config)
# Build the effects model from the rewritten config and switch it to eval mode;
# it is only ever run for inference in this script.
fx = instantiate(fx_config)
fx.eval()
# Number of leading principal components kept from the eigendecomposition.
NUMBER_OF_KEPT_PCS = 75

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open;
# read the arrays inside a ``with`` block so the handle is closed afterwards.
with np.load(PCA_PARAM_FILE) as pca_params:
    mean = pca_params["mean"]
    cov = pca_params["cov"]

# eigh returns eigenvalues in ascending order: flip so the largest come first,
# then keep only the leading components.
eigvals, eigvecs = np.linalg.eigh(cov)
eigvals = np.flip(eigvals, axis=0)[:NUMBER_OF_KEPT_PCS]
eigvecs = np.flip(eigvecs, axis=1)[:, :NUMBER_OF_KEPT_PCS]

# Scale every eigenvector by the square root of its eigenvalue, so a
# coefficient of 1 moves one standard deviation along that component.
# Round-off can make near-zero eigenvalues slightly negative; clamp them to
# zero so the square root cannot put NaN into U.
U = eigvecs * np.sqrt(np.maximum(eigvals, 0.0))
U = torch.from_numpy(U).float()
mean = torch.from_numpy(mean).float()
# Load the parameter names and shapes that describe the flat parameter vector.
# JSON is read as UTF-8 explicitly so the result does not depend on the locale.
with open(INFO_PATH, encoding="utf-8") as f:
    info = json.load(f)

param_keys = info["params_keys"]
# An empty shape (presumably a scalar parameter - confirm against the preset)
# is replaced by [1] so that every entry has at least one dimension.
original_shapes = [shape or [1] for shape in info["params_original_shapes"]]

# get_chunks supplies the remaining vec2statedict arguments; its last return
# value is not needed here and is dropped.
*vec2dict_args, _ = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args

# Pre-bind everything except the parameter vector itself, so that
# ``vec2dict(x)`` can be called with just the vector.
vec2dict_kwargs = dict(
    zip(
        (
            "keys",
            "original_shapes",
            "selected_chunks",
            "position",
            "U_matrix_shape",
        ),
        vec2dict_args,
    )
)
vec2dict = partial(vec2statedict, **vec2dict_kwargs)
# Loudness meter (pyloudnorm) created for a 44100 Hz sample rate; `inference`
# resamples its input to that rate before calling `integrated_loudness` on it.
meter = pyln.Meter(44100)
@torch.no_grad()
def inference(audio, randomise_rest, *pcs):
    """Render the given audio through the effects model.

    The slider values are used as the leading principal-component
    coefficients, mapped back to a parameter vector (``U @ z + mean``) and
    loaded into ``fx`` before the audio is processed.

    Args:
        audio: ``(sample_rate, samples)`` tuple from the Gradio audio input,
            or ``None`` when nothing has been uploaded.
        randomise_rest: if true, the coefficients that have no slider are
            drawn from a standard normal scaled by ``TEMPERATURE``; otherwise
            they are zero.
        *pcs: values of the leading principal-component coefficients.

    Returns:
        ``(44100, samples)`` where ``samples`` is an int16 array.

    Raises:
        gr.Error: if ``audio`` is ``None``.
    """
    if audio is None:
        # Gradio passes None when the user presses Run without an upload;
        # report it in the UI instead of failing on the tuple unpacking.
        raise gr.Error("Please upload an audio file first.")

    target_sr = 44100
    sr, y = audio
    if sr != target_sr:
        y = resample(y, sr, target_sr)
    if y.dtype.kind != "f":
        # Integer PCM: divide by the full-scale value of the actual dtype
        # (32768 for int16) rather than assuming 16-bit samples.
        full_scale = np.iinfo(y.dtype).max + 1.0 if y.dtype.kind == "i" else 32768.0
        y = y / full_scale
    if y.ndim == 1:
        y = y[:, None]

    # Normalise the input to -18 LUFS. A silent clip has no finite loudness,
    # which would turn into a non-finite gain, so it is left untouched.
    loudness = meter.integrated_loudness(y)
    if np.isfinite(loudness):
        y = pyln.normalize.loudness(y, loudness, -18.0)

    # (samples, channels) -> (1, channels, samples), then downmix to mono.
    y = torch.from_numpy(y).float().T.unsqueeze(0)
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)

    # Coefficient vector: slider values first, then the remaining components.
    M = eigvals.shape[0]
    n_rest = M - len(pcs)
    rest = torch.randn(n_rest) * TEMPERATURE if randomise_rest else torch.zeros(n_rest)
    z = torch.cat([torch.tensor([float(x) for x in pcs]), rest])

    x = U @ z + mean
    fx.load_state_dict(vec2dict(x), strict=False)
    fx.apply(partial(clip_delay_eq_Q, Q=0.707))

    rendered = fx(y).squeeze(0).T.numpy()
    peak = np.max(np.abs(rendered))
    if peak > 1:
        rendered = rendered / peak
    # A sample of exactly +1.0 scales to 32768, which does not fit in int16;
    # clip to the representable range before casting so it cannot wrap around.
    pcm = np.clip(rendered * 32768, -32768, 32767).astype(np.int16)
    return (target_sr, pcm)
def get_important_pcs(n=10, **kwargs):
    """Create ``n`` sliders labelled ``PC 1`` to ``PC n``.

    Every slider ranges from ``SLIDER_MIN`` to ``SLIDER_MAX``; any extra
    keyword arguments are forwarded unchanged to ``gr.Slider``.

    Returns:
        The list of created sliders, in label order.
    """
    pc_sliders = []
    for index in range(n):
        slider = gr.Slider(
            minimum=SLIDER_MIN,
            maximum=SLIDER_MAX,
            label=f"PC {index + 1}",
            **kwargs,
        )
        pc_sliders.append(slider)
    return pc_sliders
# Page layout: input audio and controls on the left, rendered audio on the right.
with gr.Blocks() as demo:
    # The heading describes what this page actually does; the previous text was
    # a leftover placeholder about an unrelated topic.
    gr.Markdown(
        "# DiffVox\n"
        "Upload an audio file, set the principal-component (PC) sliders and "
        "press **Run** to render it through the effects model."
    )
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="numpy", sources="upload", label="Input Audio")
            with gr.Row():
                random_button = gr.Button(
                    f"Randomise the first {NUMBER_OF_PCS} PCs",
                    elem_id="randomise-button",
                )
                reset_button = gr.Button(
                    "Reset",
                    elem_id="reset-button",
                )
                render_button = gr.Button(
                    "Run", elem_id="render-button", variant="primary"
                )
            random_rest_checkbox = gr.Checkbox(
                label=f"Randomise PCs > {NUMBER_OF_PCS} (default to zeros)",
                value=False,
                elem_id="randomise-checkbox",
            )
            sliders = get_important_pcs(NUMBER_OF_PCS, value=0)
        with gr.Column():
            audio_output = gr.Audio(
                type="numpy", label="Output Audio", interactive=False
            )

    # Run: audio, checkbox and all slider values go to `inference`.
    render_button.click(
        inference,
        inputs=[
            audio_input,
            random_rest_checkbox,
        ]
        + sliders,
        outputs=audio_output,
    )
    # Randomise: one standard-normal draw per slider, limited to the slider
    # range. The current slider values are only used to know how many to emit.
    random_button.click(
        lambda *xs: [
            chain_functions(
                partial(max, SLIDER_MIN),
                partial(min, SLIDER_MAX),
            )(normalvariate(0, 1))
            for _ in xs
        ],
        inputs=sliders,
        outputs=sliders,
    )
    # Reset: put every slider back to zero.
    reset_button.click(
        lambda *xs: [0] * len(xs),
        inputs=sliders,
        outputs=sliders,
    )

demo.launch()