yoyolicoris committed
Commit 3044e63 · 1 Parent(s): eb92285

Implement initial version of demo website

Files changed (1)
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+ import yaml
+ import json
+ import pyloudnorm as pyln
+ from hydra.utils import instantiate
+ from random import normalvariate
+ from soxr import resample
+ from functools import partial
+
+ from src.modules.utils import chain_functions, vec2statedict, get_chunks
+ from src.modules.fx import clip_delay_eq_Q
+
+ SLIDER_MAX = 3
+ SLIDER_MIN = -3
+ NUMBER_OF_PCS = 10
+ TEMPERATURE = 0.7
+ CONFIG_PATH = "src/presets/rt_config.yaml"
+ PCA_PARAM_FILE = "src/presets/internal/gaussian.npz"
+ INFO_PATH = "src/presets/internal/info.json"
+
+
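+ # Build the effects chain ("fx") from the Hydra/YAML config, rewriting
+ # "modules.*" targets so they resolve under the "src." package here.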
+ with open(CONFIG_PATH) as fp:
+     fx_config = yaml.safe_load(fp)["model"]
+ # append "src." to the module name
+ appendsrc = lambda d: (
+     {
+         k: (
+             f"src.{v}"
+             if (k == "_target_" and v.startswith("modules."))
+             else appendsrc(v)
+         )
+         for k, v in d.items()
+     }
+     if isinstance(d, dict)
+     else (list(map(appendsrc, d)) if isinstance(d, list) else d)
+ )
+ fx_config = appendsrc(fx_config)  # type: ignore
+
+ fx = instantiate(fx_config)
+ fx.eval()
+
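+ # Load the Gaussian fit over the flattened effect parameters and keep the top
+ # 75 principal directions, scaled by sqrt(eigenvalue), so a standard-normal
+ # latent z maps to parameter space via U @ z + mean.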
+ pca_params = np.load(PCA_PARAM_FILE)
+ mean = pca_params["mean"]
+ cov = pca_params["cov"]
+ eigvals, eigvecs = np.linalg.eigh(cov)
+ eigvals = np.flip(eigvals, axis=0)[:75]
+ eigvecs = np.flip(eigvecs, axis=1)[:, :75]
+ U = eigvecs * np.sqrt(eigvals)
+ U = torch.from_numpy(U).float()
+ mean = torch.from_numpy(mean).float()
+
+
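+ # info.json records the state-dict keys and original tensor shapes of the
+ # effect parameters; vec2dict turns a flat parameter vector back into a
+ # state dict that the effects chain can load.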
+ with open(INFO_PATH) as f:
+     info = json.load(f)
+
+ param_keys = info["params_keys"]
+ original_shapes = list(
+     map(lambda lst: lst if len(lst) else [1], info["params_original_shapes"])
+ )
+
+ *vec2dict_args, _ = get_chunks(param_keys, original_shapes)
+ vec2dict_args = [param_keys, original_shapes] + vec2dict_args
+ vec2dict = partial(
+     vec2statedict,
+     **dict(
+         zip(
+             [
+                 "keys",
+                 "original_shapes",
+                 "selected_chunks",
+                 "position",
+                 "U_matrix_shape",
+             ],
+             vec2dict_args,
+         )
+     ),
+ )
+
+
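+ # BS.1770 loudness meter (pyloudnorm) used to normalise the input level.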
+ meter = pyln.Meter(44100)
+
+
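+ # Turn the slider values (plus, optionally, random samples for the remaining
+ # PCs) into a full parameter vector, load it into the effects chain, and
+ # render the uploaded audio at 44.1 kHz.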
+ @torch.no_grad()
+ def inference(audio, randomise_rest, *pcs):
+     sr, y = audio
+     if sr != 44100:
+         y = resample(y, sr, 44100)
+     if y.dtype.kind != "f":
+         y = y / 32768.0
+
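+     # Ensure a 2-D (samples, channels) array and normalise to -18 LUFS.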
+     if y.ndim == 1:
+         y = y[:, None]
+     loudness = meter.integrated_loudness(y)
+     y = pyln.normalize.loudness(y, loudness, -18.0)
+
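+     # Shape to (1, channels, samples) for the effects chain; downmix to mono.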
+     y = torch.from_numpy(y).float().T.unsqueeze(0)
+     if y.shape[1] != 1:
+         y = y.mean(dim=1, keepdim=True)
+
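+     # The first len(pcs) latent dimensions come from the sliders; the rest are
+     # either sampled with standard deviation TEMPERATURE or set to zero.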
+     M = eigvals.shape[0]
+     z = torch.cat(
+         [
+             torch.tensor([float(x) for x in pcs]),
+             (
+                 torch.randn(M - len(pcs)) * TEMPERATURE
+                 if randomise_rest
+                 else torch.zeros(M - len(pcs))
+             ),
+         ]
+     )
+     x = U @ z + mean
+
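+     # Load the decoded parameters into the effects chain, then apply
+     # clip_delay_eq_Q to constrain the delay EQ's Q factor.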
+     fx.load_state_dict(vec2dict(x), strict=False)
+     fx.apply(partial(clip_delay_eq_Q, Q=0.707))
+
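+     # Render, peak-normalise if the output clips, and return 16-bit PCM.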
+     rendered = fx(y).squeeze(0).T.numpy()
+     if np.max(np.abs(rendered)) > 1:
+         rendered = rendered / np.max(np.abs(rendered))
+     return (44100, (rendered * 32768).astype(np.int16))
+
+
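+ # One slider per exposed principal component.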
+ def get_important_pcs(n=10, **kwargs):
+     sliders = [
+         gr.Slider(minimum=SLIDER_MIN, maximum=SLIDER_MAX, label=f"PC {i}", **kwargs)
+         for i in range(1, n + 1)
+     ]
+     return sliders
+
+
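+ # Gradio UI: upload audio on the left, adjust the first NUMBER_OF_PCS
+ # principal components, and listen to the rendered result on the right.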
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+ # Hadamard Transform
+ This is a demo of the Hadamard transform.
+ """
+     )
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(type="numpy", sources="upload", label="Input Audio")
+             with gr.Row():
+                 random_button = gr.Button(
+                     f"Randomise the first {NUMBER_OF_PCS} PCs",
+                     elem_id="randomise-button",
+                 )
+                 reset_button = gr.Button(
+                     "Reset",
+                     elem_id="reset-button",
+                 )
+                 render_button = gr.Button(
+                     "Run", elem_id="render-button", variant="primary"
+                 )
+             random_rest_checkbox = gr.Checkbox(
+                 label=f"Randomise PCs > {NUMBER_OF_PCS} (default to zeros)",
+                 value=False,
+                 elem_id="randomise-checkbox",
+             )
+             sliders = get_important_pcs(NUMBER_OF_PCS, value=0)
+         with gr.Column():
+             audio_output = gr.Audio(
+                 type="numpy", label="Output Audio", interactive=False
+             )
+
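+     # Run button: render the uploaded audio with the current slider settings.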
+     render_button.click(
+         inference,
+         inputs=[
+             audio_input,
+             random_rest_checkbox,
+         ]
+         + sliders,
+         outputs=audio_output,
+     )
+
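+     # Randomise draws N(0, 1) samples clamped to [SLIDER_MIN, SLIDER_MAX];
+     # Reset sets every slider back to zero.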
+     random_button.click(
+         lambda *xs: [
+             chain_functions(
+                 partial(max, SLIDER_MIN),
+                 partial(min, SLIDER_MAX),
+             )(normalvariate(0, 1))
+             for _ in range(len(xs))
+         ],
+         inputs=sliders,
+         outputs=sliders,
+     )
+     reset_button.click(
+         lambda *xs: [0 for _ in range(len(xs))],
+         inputs=sliders,
+         outputs=sliders,
+     )
+
+ demo.launch()