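"""Compare coarse-to-fine VampNet checkpoints by sampling audio excerpts.

Loads one VampNet model per key alongside a LAC codec, draws fixed-duration
excerpts from the configured audio sources, and writes three renditions of
each excerpt per model: a straight codec reconstruction, a multi-step
gumbel-sampled output, and a single-step argmax output.
"""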
from pathlib import Path

import argbind
import torch
from audiotools import AudioSignal
from tqdm import tqdm

@torch.no_grad()
def coarse2fine_infer(
    signal: AudioSignal,
    model,
    vqvae,
    device: str,
):
    """Run coarse-to-fine inference on a signal; returns output signals keyed by method."""
    output = {}
    w = signal.to(device)
    z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]

    model.to(device)

    # Straight reconstruction from the unmodified codes.
    output["reconstructed"] = model.to_signal(z, vqvae).cpu()

    # Mask everything except the conditioning (coarse) codebooks,
    # so the model regenerates only the fine codebooks.
    mask = torch.ones_like(z)
    mask[:, :model.n_conditioning_codebooks, :] = 0

    # Iterative (12-step) gumbel sampling.
    output["sampled"] = model.sample(
        codec=vqvae,
        time_steps=z.shape[-1],
        sampling_steps=12,
        start_tokens=z,
        mask=mask,
        temperature=0.85,
        top_k=None,
        sample="gumbel",
        typical_filtering=True,
        return_signal=True,
    ).cpu()

    # Single-step greedy (argmax) decoding, for comparison.
    output["argmax"] = model.sample(
        codec=vqvae,
        time_steps=z.shape[-1],
        sampling_steps=1,
        start_tokens=z,
        mask=mask,
        temperature=1.0,
        top_k=None,
        sample="argmax",
        typical_filtering=True,
        return_signal=True,
    ).cpu()

    return output



@argbind.bind(without_prefix=True)
def main(
    sources=[
        "/data/spotdl/audio/val", "/data/spotdl/audio/test"
    ],
    exp_name="noise_mode",
    model_paths=[
        "runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth",
        "runs/c2f-exp-03.22.23/ckpt/random/epoch=400/vampnet/weights.pth",
    ],
    model_keys=[
        "mask",
        "random",
    ],
    vqvae_path: str = "runs/codec-ckpt/codec.pth",
    device: str = "cuda",
    output_dir: str = ".",
    max_excerpts: int = 5000,
    duration: float = 3.0,
):
    # Heavy model/codec modules are imported lazily, inside main.
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC
    from audiotools.data.datasets import AudioLoader, AudioDataset

    # Load one VampNet checkpoint per key, in eval mode.
    models = {
        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")

    output_dir = Path(output_dir) / f"{exp_name}-samples"

    # Draw fixed-duration excerpts from the source folders without replacement.
    loader = AudioLoader(sources=sources)
    dataset = AudioDataset(
        loader,
        sample_rate=vqvae.sample_rate,
        duration=duration,
        n_examples=max_excerpts,
        without_replacement=True,
    )
    for i in tqdm(range(max_excerpts)):
        sig = dataset[i]["signal"]
        # In-place preprocessing: resample to the codec rate, normalize to
        # -24 dB, and clamp peaks to avoid clipping.
        sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)

        # Run every model on the same excerpt and write one wav per method.
        for model_key, model in models.items():
            out = coarse2fine_infer(sig, model, vqvae, device)
            out_dir = output_dir / model_key / Path(sig.path_to_file).stem
            out_dir.mkdir(parents=True, exist_ok=True)
            for k, s in out.items():
                s.write(out_dir / f"{k}.wav")

if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        main()
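
# Example invocation (a sketch: the script filename is illustrative, and the
# flags assume argbind's convention of exposing main()'s parameters as CLI
# arguments when bound without a prefix):
#
#   python coarse2fine_infer.py \
#       --exp_name noise_mode \
#       --vqvae_path runs/codec-ckpt/codec.pth \
#       --max_excerpts 100 \
#       --duration 3.0
#
# Results are written to <output_dir>/<exp_name>-samples/<model_key>/<stem>/
# as reconstructed.wav, sampled.wav, and argmax.wav.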