Hugo Flores Garcia commited on
Commit
bfacd00
·
1 Parent(s): e251e23

pin audiotools version!

Browse files
Files changed (2) hide show
  1. scripts/utils/process_folder-c2f.py +0 -124
  2. setup.py +1 -1
scripts/utils/process_folder-c2f.py DELETED
@@ -1,124 +0,0 @@
1
-
2
- from audiotools import AudioSignal
3
- import torch
4
- from pathlib import Path
5
- import argbind
6
- from tqdm import tqdm
7
- import random
8
-
9
- from typing import List
10
-
11
- from collections import defaultdict
12
-
13
- def coarse2fine_infer(
14
- signal,
15
- model,
16
- vqvae,
17
- device,
18
- ):
19
- output = {}
20
- w = signal
21
- w = w.to(device)
22
- z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]
23
-
24
- model.to(device)
25
- output["reconstructed"] = model.to_signal(z, vqvae).cpu()
26
-
27
- # make a full mask
28
- mask = torch.ones_like(z)
29
- mask[:, :model.n_conditioning_codebooks, :] = 0
30
-
31
- output["sampled"] = model.sample(
32
- codec=vqvae,
33
- time_steps=z.shape[-1],
34
- sampling_steps=12,
35
- start_tokens=z,
36
- mask=mask,
37
- temperature=0.85,
38
- top_k=None,
39
- sample="gumbel",
40
- typical_filtering=True,
41
- return_signal=True
42
- ).cpu()
43
-
44
- output["argmax"] = model.sample(
45
- codec=vqvae,
46
- time_steps=z.shape[-1],
47
- sampling_steps=1,
48
- start_tokens=z,
49
- mask=mask,
50
- temperature=1.0,
51
- top_k=None,
52
- sample="argmax",
53
- typical_filtering=True,
54
- return_signal=True
55
- ).cpu()
56
-
57
- return output
58
-
59
-
60
-
61
- @argbind.bind(without_prefix=True)
62
- def main(
63
- sources=[
64
- "/data/spotdl/audio/val", "/data/spotdl/audio/test"
65
- ],
66
- exp_name="noise_mode",
67
- model_paths=[
68
- "runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth",
69
- "runs/c2f-exp-03.22.23/ckpt/random/epoch=400/vampnet/weights.pth",
70
- ],
71
- model_keys=[
72
- "mask",
73
- "random",
74
- ],
75
- vqvae_path: str = "runs/codec-ckpt/codec.pth",
76
- device: str = "cuda",
77
- output_dir: str = ".",
78
- max_excerpts: int = 5000,
79
- duration: float = 3.0,
80
- ):
81
- from vampnet.modules.transformer import VampNet
82
- from lac.model.lac import LAC
83
-
84
- models = {
85
- k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
86
- }
87
- for model in models.values():
88
- model.eval()
89
- print(f"Loaded {len(models)} models.")
90
-
91
- vqvae = LAC.load(vqvae_path)
92
- vqvae.to(device)
93
- vqvae.eval()
94
- print("Loaded VQVAE.")
95
-
96
- output_dir = Path(output_dir) / f"{exp_name}-samples"
97
-
98
- from audiotools.data.datasets import AudioLoader, AudioDataset
99
-
100
- loader = AudioLoader(sources=sources)
101
- dataset = AudioDataset(loader,
102
- sample_rate=vqvae.sample_rate,
103
- duration=duration,
104
- n_examples=max_excerpts,
105
- without_replacement=True,
106
- )
107
- for i in tqdm(range(max_excerpts)):
108
- sig = dataset[i]["signal"]
109
- sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)
110
-
111
- for model_key, model in models.items():
112
- out = coarse2fine_infer(sig, model, vqvae, device)
113
- out_dir = output_dir / model_key / Path(sig.path_to_file).stem
114
- out_dir.mkdir(parents=True, exist_ok=True)
115
- for k, s in out.items():
116
- s.write(out_dir / f"{k}.wav")
117
-
118
-
119
- if __name__ == "__main__":
120
- args = argbind.parse_args()
121
-
122
- with argbind.scope(args):
123
- main()
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
setup.py CHANGED
@@ -28,7 +28,7 @@ setup(
28
  install_requires=[
29
  "torch",
30
  "argbind>=0.3.2",
31
- "audiotools @ git+https://github.com/descriptinc/audiotools.git",
32
  "dac @ git+https://github.com/descriptinc/descript-audio-codec.git",
33
  "gradio",
34
  "tensorboardX",
 
28
  install_requires=[
29
  "torch",
30
  "argbind>=0.3.2",
31
+ "audiotools @ git+https://github.com/descriptinc/audiotools.git@f35914b5b3c6f1bf589cd09481478d741538828e",
32
  "dac @ git+https://github.com/descriptinc/descript-audio-codec.git",
33
  "gradio",
34
  "tensorboardX",