Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
commited on
Commit
•
6f6fd13
1
Parent(s):
4687dd9
exps
Browse files- scripts/utils/process_folder-c2f.py +63 -73
- scripts/utils/vamp_folder.py +122 -0
- vampnet/interface.py +1 -0
- vampnet/modules/base.py +2 -2
scripts/utils/process_folder-c2f.py
CHANGED
@@ -15,57 +15,47 @@ def coarse2fine_infer(
|
|
15 |
model,
|
16 |
vqvae,
|
17 |
device,
|
18 |
-
signal_window=3,
|
19 |
-
signal_hop=1.5,
|
20 |
-
max_excerpts=20,
|
21 |
):
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
mask
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
top_k=None,
|
63 |
-
sample="argmax",
|
64 |
-
typical_filtering=True,
|
65 |
-
return_signal=True
|
66 |
-
).cpu())
|
67 |
-
|
68 |
-
return output
|
69 |
|
70 |
|
71 |
@argbind.bind(without_prefix=True)
|
@@ -73,11 +63,10 @@ def main(
|
|
73 |
sources=[
|
74 |
"/data/spotdl/audio/val", "/data/spotdl/audio/test"
|
75 |
],
|
76 |
-
audio_ext="mp3",
|
77 |
exp_name="noise_mode",
|
78 |
model_paths=[
|
79 |
-
"runs/c2f-exp-03.22.23/ckpt/mask/
|
80 |
-
"runs/c2f-exp-03.22.23/ckpt/random/
|
81 |
],
|
82 |
model_keys=[
|
83 |
"mask",
|
@@ -86,10 +75,11 @@ def main(
|
|
86 |
vqvae_path: str = "runs/codec-ckpt/codec.pth",
|
87 |
device: str = "cuda",
|
88 |
output_dir: str = ".",
|
|
|
|
|
89 |
):
|
90 |
from vampnet.modules.transformer import VampNet
|
91 |
from lac.model.lac import LAC
|
92 |
-
from audiotools.post import audio_zip
|
93 |
|
94 |
models = {
|
95 |
k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
|
@@ -105,26 +95,26 @@ def main(
|
|
105 |
|
106 |
output_dir = Path(output_dir) / f"{exp_name}-samples"
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
out_dir.mkdir(parents=True, exist_ok=True)
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
for model_key, model in models.items():
|
123 |
-
out = coarse2fine_infer(sig, model, vqvae, device)
|
124 |
-
for k, sig_list in out.items():
|
125 |
-
for i, s in enumerate(sig_list):
|
126 |
-
s.write(out_dir / f"{model_key}-{k}-{i}.wav")
|
127 |
-
|
128 |
|
129 |
if __name__ == "__main__":
|
130 |
args = argbind.parse_args()
|
|
|
15 |
model,
|
16 |
vqvae,
|
17 |
device,
|
|
|
|
|
|
|
18 |
):
|
19 |
+
output = {}
|
20 |
+
w = signal
|
21 |
+
w = w.to(device)
|
22 |
+
z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]
|
23 |
+
|
24 |
+
model.to(device)
|
25 |
+
output["reconstructed"] = model.to_signal(z, vqvae).cpu()
|
26 |
+
|
27 |
+
# make a full mask
|
28 |
+
mask = torch.ones_like(z)
|
29 |
+
mask[:, :model.n_conditioning_codebooks, :] = 0
|
30 |
+
|
31 |
+
output["sampled"] = model.sample(
|
32 |
+
codec=vqvae,
|
33 |
+
time_steps=z.shape[-1],
|
34 |
+
sampling_steps=12,
|
35 |
+
start_tokens=z,
|
36 |
+
mask=mask,
|
37 |
+
temperature=0.85,
|
38 |
+
top_k=None,
|
39 |
+
sample="gumbel",
|
40 |
+
typical_filtering=True,
|
41 |
+
return_signal=True
|
42 |
+
).cpu()
|
43 |
+
|
44 |
+
output["argmax"] = model.sample(
|
45 |
+
codec=vqvae,
|
46 |
+
time_steps=z.shape[-1],
|
47 |
+
sampling_steps=1,
|
48 |
+
start_tokens=z,
|
49 |
+
mask=mask,
|
50 |
+
temperature=1.0,
|
51 |
+
top_k=None,
|
52 |
+
sample="argmax",
|
53 |
+
typical_filtering=True,
|
54 |
+
return_signal=True
|
55 |
+
).cpu()
|
56 |
+
|
57 |
+
return output
|
58 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
|
61 |
@argbind.bind(without_prefix=True)
|
|
|
63 |
sources=[
|
64 |
"/data/spotdl/audio/val", "/data/spotdl/audio/test"
|
65 |
],
|
|
|
66 |
exp_name="noise_mode",
|
67 |
model_paths=[
|
68 |
+
"runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth",
|
69 |
+
"runs/c2f-exp-03.22.23/ckpt/random/epoch=400/vampnet/weights.pth",
|
70 |
],
|
71 |
model_keys=[
|
72 |
"mask",
|
|
|
75 |
vqvae_path: str = "runs/codec-ckpt/codec.pth",
|
76 |
device: str = "cuda",
|
77 |
output_dir: str = ".",
|
78 |
+
max_excerpts: int = 5000,
|
79 |
+
duration: float = 3.0,
|
80 |
):
|
81 |
from vampnet.modules.transformer import VampNet
|
82 |
from lac.model.lac import LAC
|
|
|
83 |
|
84 |
models = {
|
85 |
k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
|
|
|
95 |
|
96 |
output_dir = Path(output_dir) / f"{exp_name}-samples"
|
97 |
|
98 |
+
from audiotools.data.datasets import AudioLoader, AudioDataset
|
99 |
+
|
100 |
+
loader = AudioLoader(sources=sources)
|
101 |
+
dataset = AudioDataset(loader,
|
102 |
+
sample_rate=vqvae.sample_rate,
|
103 |
+
duration=duration,
|
104 |
+
n_examples=max_excerpts,
|
105 |
+
without_replacement=True,
|
106 |
+
)
|
107 |
+
for i in tqdm(range(max_excerpts)):
|
108 |
+
sig = dataset[i]["signal"]
|
109 |
+
sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)
|
110 |
+
|
111 |
+
for model_key, model in models.items():
|
112 |
+
out = coarse2fine_infer(sig, model, vqvae, device)
|
113 |
+
out_dir = output_dir / model_key / Path(sig.path_to_file).stem
|
114 |
out_dir.mkdir(parents=True, exist_ok=True)
|
115 |
+
for k, s in out.items():
|
116 |
+
s.write(out_dir / f"{k}.wav")
|
117 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
if __name__ == "__main__":
|
120 |
args = argbind.parse_args()
|
scripts/utils/vamp_folder.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import argbind
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from vampnet.interface import Interface
|
8 |
+
|
9 |
+
Interface = argbind.bind(Interface, positional=True)
|
10 |
+
|
11 |
+
def baseline(sig, interface):
|
12 |
+
return sig
|
13 |
+
|
14 |
+
def reconstructed(sig, interface):
|
15 |
+
return interface.to_signal(
|
16 |
+
interface.encode(sig)
|
17 |
+
)
|
18 |
+
|
19 |
+
def coarse2fine(sig, interface):
|
20 |
+
z = interface.encode(sig)
|
21 |
+
z = z[:, :interface.c2f.n_conditioning_codebooks, :]
|
22 |
+
|
23 |
+
z = interface.coarse_to_fine(z)
|
24 |
+
return interface.to_signal(z)
|
25 |
+
|
26 |
+
def one_codebook(sig, interface):
|
27 |
+
z = interface.encode(sig)
|
28 |
+
|
29 |
+
mask = torch.zeros_like(z)
|
30 |
+
mask[:, 1:, :] = 1
|
31 |
+
|
32 |
+
zv = interface.coarse_vamp_v2(
|
33 |
+
sig, ext_mask=mask,
|
34 |
+
)
|
35 |
+
zv = interface.coarse_to_fine(zv)
|
36 |
+
|
37 |
+
return interface.to_signal(zv)
|
38 |
+
|
39 |
+
def four_codebooks_downsampled_4x(sig, interface):
|
40 |
+
zv = interface.coarse_vamp_v2(
|
41 |
+
sig, downsample_factor=4
|
42 |
+
)
|
43 |
+
zv = interface.coarse_to_fine(zv)
|
44 |
+
return interface.to_signal(zv)
|
45 |
+
|
46 |
+
def two_codebooks_downsampled_4x(sig, interface):
|
47 |
+
z = interface.encode(sig)
|
48 |
+
|
49 |
+
mask = torch.zeros_like(z)
|
50 |
+
mask[:, 2:, :] = 1
|
51 |
+
|
52 |
+
zv = interface.coarse_vamp_v2(
|
53 |
+
sig, ext_mask=mask, downsample_factor=4
|
54 |
+
)
|
55 |
+
zv = interface.coarse_to_fine(zv)
|
56 |
+
|
57 |
+
return interface.to_signal(zv)
|
58 |
+
|
59 |
+
def four_codebooks_downsampled_8x(sig, interface):
|
60 |
+
zv = interface.coarse_vamp_v2(
|
61 |
+
sig, downsample_factor=8
|
62 |
+
)
|
63 |
+
zv = interface.coarse_to_fine(zv)
|
64 |
+
return interface.to_signal(zv)
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
SAMPLE_CONDS ={
|
71 |
+
"baseline": baseline,
|
72 |
+
"reconstructed": reconstructed,
|
73 |
+
"coarse2fine": coarse2fine,
|
74 |
+
"one_codebook": one_codebook,
|
75 |
+
"four_codebooks_downsampled_4x": four_codebooks_downsampled_4x,
|
76 |
+
"two_codebooks_downsampled_4x": two_codebooks_downsampled_4x,
|
77 |
+
"four_codebooks_downsampled_8x": four_codebooks_downsampled_8x,
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
@argbind.bind(without_prefix=True)
|
82 |
+
def main(
|
83 |
+
sources=[
|
84 |
+
"/data/spotdl/audio/val", "/data/spotdl/audio/test"
|
85 |
+
],
|
86 |
+
output_dir: str = "./samples",
|
87 |
+
max_excerpts: int = 5000,
|
88 |
+
):
|
89 |
+
interface = Interface()
|
90 |
+
|
91 |
+
output_dir = Path(output_dir)
|
92 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
93 |
+
|
94 |
+
from audiotools.data.datasets import AudioLoader, AudioDataset
|
95 |
+
|
96 |
+
loader = AudioLoader(sources=sources)
|
97 |
+
dataset = AudioDataset(loader,
|
98 |
+
sample_rate=interface.codec.sample_rate,
|
99 |
+
duration=interface.coarse.chunk_size_s,
|
100 |
+
n_examples=max_excerpts,
|
101 |
+
without_replacement=True,
|
102 |
+
)
|
103 |
+
|
104 |
+
for i in tqdm(range(max_excerpts)):
|
105 |
+
sig = dataset[i]["signal"]
|
106 |
+
|
107 |
+
results = {
|
108 |
+
name: cond(sig, interface)
|
109 |
+
for name, cond in SAMPLE_CONDS.items()
|
110 |
+
}
|
111 |
+
|
112 |
+
for name, sig in results.items():
|
113 |
+
output_dir = Path(output_dir) / name
|
114 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
115 |
+
|
116 |
+
sig.write(output_dir / f"{i}.wav")
|
117 |
+
|
118 |
+
if __name__ == "__main__":
|
119 |
+
args = argbind.parse_args()
|
120 |
+
|
121 |
+
with argbind.scope(args):
|
122 |
+
main()
|
vampnet/interface.py
CHANGED
@@ -196,6 +196,7 @@ class Interface:
|
|
196 |
time_steps=chunk_len,
|
197 |
start_tokens=chunk,
|
198 |
return_signal=False,
|
|
|
199 |
)
|
200 |
fine_z.append(chunk)
|
201 |
|
|
|
196 |
time_steps=chunk_len,
|
197 |
start_tokens=chunk,
|
198 |
return_signal=False,
|
199 |
+
**kwargs
|
200 |
)
|
201 |
fine_z.append(chunk)
|
202 |
|
vampnet/modules/base.py
CHANGED
@@ -288,12 +288,12 @@ class VampBase(at.ml.BaseModel):
|
|
288 |
self,
|
289 |
codec,
|
290 |
time_steps: int = 300,
|
291 |
-
sampling_steps: int =
|
292 |
start_tokens: Optional[torch.Tensor] = None,
|
293 |
mask: Optional[torch.Tensor] = None,
|
294 |
temperature: Union[float, Tuple[float, float]] = 0.8,
|
295 |
top_k: int = None,
|
296 |
-
sample: str = "
|
297 |
typical_filtering=False,
|
298 |
typical_mass=0.2,
|
299 |
typical_min_tokens=1,
|
|
|
288 |
self,
|
289 |
codec,
|
290 |
time_steps: int = 300,
|
291 |
+
sampling_steps: int = 12,
|
292 |
start_tokens: Optional[torch.Tensor] = None,
|
293 |
mask: Optional[torch.Tensor] = None,
|
294 |
temperature: Union[float, Tuple[float, float]] = 0.8,
|
295 |
top_k: int = None,
|
296 |
+
sample: str = "gumbel",
|
297 |
typical_filtering=False,
|
298 |
typical_mass=0.2,
|
299 |
typical_min_tokens=1,
|