File size: 4,977 Bytes
41b9d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05d43c6
41b9d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from pathlib import Path
import time
import os
from contextlib import contextmanager
import random

import numpy as np
import audiotools as at
from audiotools import AudioSignal
import argbind
import shutil
import torch
import yaml


from vampnet.interface import Interface, signal_concat
from vampnet import mask as pmask

from ttutil import log

# TODO: incorporate discord bot (if mem allows)
# in a separate thread, send audio samples for listening
# and send back the results
# as well as the params for sampling
# also a command that lets you clear the current signal 
# if you want to start over


device = "cuda" if torch.cuda.is_available() else "cpu"

VAMPNET_DIR = Path(".").resolve()

@contextmanager
def chdir(path):
    """Run the managed body with *path* as the current working directory.

    The previous working directory is restored on exit, even if the body
    raises.
    """
    previous_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous_dir)

def load_interface(model_choice="default") -> Interface:
    """Build a vampnet Interface for *model_choice*.

    Scans ``conf/generated/*/interface.yml`` (relative to VAMPNET_DIR) for
    additional model choices whose checkpoint files actually exist, then
    constructs the Interface from the selected choice's checkpoints.

    Args:
        model_choice: key into the discovered choices; "default" uses the
            bundled models/vampnet checkpoints.

    Returns:
        An Interface on the module-level ``device``, with the discovered
        choices attached as ``interface.model_choices`` for later reloads.
    """
    with chdir(VAMPNET_DIR):
        # Baseline choice shipped with the repo.
        MODEL_CHOICES = {
            "default": {
                "Interface.coarse_ckpt": "models/vampnet/coarse.pth",
                "Interface.coarse2fine_ckpt": "models/vampnet/c2f.pth",
                "Interface.codec_ckpt": "models/vampnet/codec.pth",
            }
        }

        # Populate the model choices with any interface.yml files in the
        # generated confs.
        generated_confs = Path("conf/generated")
        for conf_file in generated_confs.glob("*/interface.yml"):
            with open(conf_file) as f:
                _conf = yaml.safe_load(f)

            # Skip choices whose coarse, c2f, or codec checkpoints are
            # missing on disk.
            if not (
                Path(_conf["Interface.coarse_ckpt"]).exists() and
                Path(_conf["Interface.coarse2fine_ckpt"]).exists() and
                Path(_conf["Interface.codec_ckpt"]).exists()
            ):
                continue

            MODEL_CHOICES[conf_file.parent.name] = _conf

        # Construct the interface while still inside VAMPNET_DIR: the
        # checkpoint paths above are relative, so building the Interface
        # outside the chdir (as the code previously did) resolved them
        # against whatever cwd the caller happened to be in.
        interface = Interface(
            device=device,
            coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
            coarse2fine_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
            codec_ckpt=MODEL_CHOICES[model_choice]["Interface.codec_ckpt"],
        )

    interface.model_choices = MODEL_CHOICES
    # Reuse the module-level device instead of re-probing CUDA availability.
    interface.to(device)
    return interface

def load_model(interface: Interface, model_choice: str):
    """Reload the interface's coarse and coarse2fine weights for *model_choice*."""
    choice = interface.model_choices[model_choice]
    coarse_ckpt = choice["Interface.coarse_ckpt"]
    c2f_ckpt = choice["Interface.coarse2fine_ckpt"]
    interface.reload(coarse_ckpt, c2f_ckpt)

def ez_variation(
        interface,
        sig: AudioSignal,
        seed: int = None, 
        model_choice: str = None,  
    ):
    """Produce one variation of *sig* by masking and re-sampling its codes.

    Args:
        interface: a vampnet Interface (see load_interface).
        sig: input audio to vary.
        seed: RNG seed for reproducibility; a fresh random seed is drawn
            when None.
        model_choice: when given, reload the interface's weights for this
            choice (a key of ``interface.model_choices``) before sampling.

    Returns:
        The vamped AudioSignal.

    NOTE(review): the order of RNG consumption below (seed, then three
    ``random.choice`` calls) is part of the reproducibility contract for a
    given seed — don't reorder or remove the choice calls, even though
    their candidate lists are currently degenerate.
    """
    t0 = time.time()
    
    if seed is None:
        # Draw a seed from torch's RNG so it can be logged/reused.
        seed = int(torch.randint(0, 2**32, (1,)).item())
    at.util.seed(seed)

    # reload the model if necessary
    if model_choice is not None:
        load_model(interface, model_choice)

    # SAMPLING MASK PARAMS, hard code for now, we'll prob want a more preset-ey thing for the actual thin
    # we probably honestly just want to oscillate between the same 4 presets
    # in a predictable order such that they have a predictable outcome
    # (each choice list currently has one effective value — placeholders
    # for future presets)
    periodic_p = random.choice([3])
    n_mask_codebooks = 3
    sampletemp = random.choice([1.0,])
    dropout = random.choice([0.0, 0.0])

    top_p = None # NOTE: top p may be the culprit behind the collapse into single pitches. 

    # parameters for the build_mask function
    build_mask_kwargs = dict(
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_prompt=int(periodic_p),
        periodic_prompt2=int(periodic_p),
        periodic_prompt_width=1,
        _dropout=dropout,
        upper_codebook_mask=int(n_mask_codebooks), 
        upper_codebook_mask_2=int(n_mask_codebooks),
    )

    # parameters for the vamp function
    vamp_kwargs = dict(
        temperature=sampletemp,
        typical_filtering=True, 
        typical_mass=0.15, 
        typical_min_tokens=64, 
        top_p=top_p,
        seed=seed,
        sample_cutoff=1.0,
    )

    # save the mask as a txt file
    interface.set_chunk_size(10.0)
    sig, mask, codes = interface.vamp(
        sig, 
        batch_size=1,
        feedback_steps=1,
        time_stretch_factor=1,
        build_mask_kwargs=build_mask_kwargs,
        vamp_kwargs=vamp_kwargs,
        return_mask=True,
    )

    log(f"vamp took {time.time() - t0} seconds")
    return sig



def main():
    """Repeatedly vamp an example clip, writing each iteration to ttout/."""
    import tqdm

    interface = load_interface()

    # Ensure the output directory exists before the first write; previously
    # sig.write('ttout/in.wav') failed on a fresh checkout.
    Path("ttout").mkdir(parents=True, exist_ok=True)

    sig = AudioSignal.excerpt("assets/example.wav", duration=7.0)
    sig = interface.preprocess(sig)
    sig.write('ttout/in.wav')

    # Feed each output back in as the next input.
    for i in tqdm.tqdm(range(1000)): 
        sig = ez_variation(interface, sig, model_choice="orchestral")
        sig.write(f'ttout/out{i}.wav')
    

if __name__ == "__main__":
    main()