Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
committed on
Commit
·
bf35d45
1
Parent(s):
8c3b3e7
lora interface
Browse files
- app.py +50 -5
- vampnet/interface.py +2 -3
app.py
CHANGED
@@ -18,10 +18,45 @@ Interface = argbind.bind(Interface)
|
|
18 |
|
19 |
conf = argbind.parse_args()
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
# dataset = at.data.datasets.AudioDataset(
|
27 |
# loader,
|
@@ -55,6 +90,8 @@ def load_example_audio():
|
|
55 |
|
56 |
|
57 |
def _vamp(data, return_mask=False):
|
|
|
|
|
58 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
59 |
out_dir.mkdir()
|
60 |
sig = at.AudioSignal(data[input_audio])
|
@@ -173,6 +210,7 @@ def save_vamp(data):
|
|
173 |
"use_coarse2fine": data[use_coarse2fine],
|
174 |
"stretch_factor": data[stretch_factor],
|
175 |
"seed": data[seed],
|
|
|
176 |
}
|
177 |
|
178 |
# save with yaml
|
@@ -472,6 +510,13 @@ with gr.Blocks() as demo:
|
|
472 |
|
473 |
# mask settings
|
474 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
475 |
vamp_button = gr.Button("generate (vamp)!!!")
|
476 |
output_audio = gr.Audio(
|
477 |
label="output audio",
|
@@ -514,7 +559,7 @@ with gr.Blocks() as demo:
|
|
514 |
beat_mask_width,
|
515 |
beat_mask_downbeats,
|
516 |
seed,
|
517 |
-
|
518 |
}
|
519 |
|
520 |
# connect widgets
|
|
|
18 |
|
19 |
conf = argbind.parse_args()
|
20 |
|
21 |
+
def load_interface():
    """Construct and return a new ``Interface`` instance.

    The instance is created while ``argbind.scope(conf)`` is active, its
    ``device`` attribute is printed, and the instance is handed back to the
    caller.
    """
    with argbind.scope(conf):
        model_interface = Interface()
        # loader = AudioLoader()
        device = model_interface.device
        print(f"interface device is {device}")
        return model_interface
|
27 |
+
|
28 |
+
|
29 |
+
LORA_NONE = "None"
|
30 |
+
def load_loras():
    """Collect the LoRA configurations found under ``conf/generated``.

    Recursively searches ``conf/generated`` for files named
    ``interface.yml`` and parses each one with ``yaml.safe_load``. Every
    entry is keyed by the name of the directory that directly contains the
    file, so two files whose parent directories share a name collapse into a
    single entry (the one visited last wins).

    Returns:
        dict: maps directory name -> parsed YAML content, plus one
        ``LORA_NONE`` entry mapped to ``None``. That entry is written last,
        so it also replaces any directory that happens to be called
        ``LORA_NONE``.
    """
    loras = {}
    # Path.glob yields matches in no guaranteed order; sorting keeps the
    # resulting dict (and anything built from its keys) stable across runs.
    for conf_file in sorted(Path("conf/generated").glob("**/interface.yml")):
        # Be explicit about the encoding: the default used by open() depends
        # on the platform locale.
        with open(conf_file, encoding="utf-8") as f:
            loras[conf_file.parent.name] = yaml.safe_load(f)
    loras[LORA_NONE] = None
    return loras
|
39 |
+
|
40 |
+
interface = load_interface()
|
41 |
+
loras = load_loras()
|
42 |
+
cur_lora = LORA_NONE
|
43 |
+
|
44 |
+
def load_lora(name):
    """Make ``name`` the active LoRA of the module-level ``interface``.

    Does nothing when ``name`` already equals ``cur_lora``. Selecting
    ``LORA_NONE`` replaces ``interface`` with a freshly built one from
    ``load_interface()``. Any other name looks up its entry in ``loras`` and
    passes the two checkpoint paths stored there to ``interface.lora_load``.
    In both cases ``cur_lora`` is updated afterwards.

    Args:
        name: a key of ``loras`` (``LORA_NONE`` included).

    Raises:
        KeyError: if ``name`` is not in ``loras`` or its entry lacks one of
            the two checkpoint keys.
    """
    global interface, cur_lora

    # already active: nothing to do
    if name == cur_lora:
        return

    # "no LoRA" is restored by rebuilding the interface from scratch
    if name == LORA_NONE:
        interface = load_interface()
        cur_lora = LORA_NONE
        return

    # NOTE(review): the checkpoints are applied to the current interface
    # without rebuilding it first, so switching directly from one LoRA to
    # another loads on top of the previous one - confirm that is intended.
    lora_conf = loras[name]
    interface.lora_load(
        coarse_ckpt=lora_conf["Interface.coarse_lora_ckpt"],
        c2f_ckpt=lora_conf["Interface.coarse2fine_lora_ckpt"],
        full_ckpts=False,
    )
    cur_lora = name
|
60 |
|
61 |
# dataset = at.data.datasets.AudioDataset(
|
62 |
# loader,
|
|
|
90 |
|
91 |
|
92 |
def _vamp(data, return_mask=False):
|
93 |
+
load_lora(data[lora_choice])
|
94 |
+
|
95 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
96 |
out_dir.mkdir()
|
97 |
sig = at.AudioSignal(data[input_audio])
|
|
|
210 |
"use_coarse2fine": data[use_coarse2fine],
|
211 |
"stretch_factor": data[stretch_factor],
|
212 |
"seed": data[seed],
|
213 |
+
"lora": data[lora_choice],
|
214 |
}
|
215 |
|
216 |
# save with yaml
|
|
|
510 |
|
511 |
# mask settings
|
512 |
with gr.Column():
|
513 |
+
|
514 |
+
lora_choice = gr.Dropdown(
|
515 |
+
label="lora choice",
|
516 |
+
choices=list(loras.keys()),
|
517 |
+
value=LORA_NONE,
|
518 |
+
)
|
519 |
+
|
520 |
vamp_button = gr.Button("generate (vamp)!!!")
|
521 |
output_audio = gr.Audio(
|
522 |
label="output audio",
|
|
|
559 |
beat_mask_width,
|
560 |
beat_mask_downbeats,
|
561 |
seed,
|
562 |
+
lora_choice,
|
563 |
}
|
564 |
|
565 |
# connect widgets
|
vampnet/interface.py
CHANGED
@@ -120,17 +120,16 @@ class Interface(torch.nn.Module):
|
|
120 |
if coarse_ckpt is not None:
|
121 |
self.coarse.to("cpu")
|
122 |
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
123 |
-
|
124 |
self.coarse.load_state_dict(state_dict, strict=False)
|
125 |
self.coarse.to(self.device)
|
126 |
if c2f_ckpt is not None:
|
127 |
self.c2f.to("cpu")
|
128 |
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
129 |
-
|
130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
131 |
self.c2f.to(self.device)
|
132 |
|
133 |
-
|
134 |
def s2t(self, seconds: float):
|
135 |
"""seconds to tokens"""
|
136 |
if isinstance(seconds, np.ndarray):
|
|
|
120 |
if coarse_ckpt is not None:
|
121 |
self.coarse.to("cpu")
|
122 |
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
123 |
+
print(f"loading coarse from {coarse_ckpt}")
|
124 |
self.coarse.load_state_dict(state_dict, strict=False)
|
125 |
self.coarse.to(self.device)
|
126 |
if c2f_ckpt is not None:
|
127 |
self.c2f.to("cpu")
|
128 |
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
129 |
+
print(f"loading c2f from {c2f_ckpt}")
|
130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
131 |
self.c2f.to(self.device)
|
132 |
|
|
|
133 |
def s2t(self, seconds: float):
|
134 |
"""seconds to tokens"""
|
135 |
if isinstance(seconds, np.ndarray):
|