Hugo Flores Garcia committed on
Commit
bf35d45
1 Parent(s): 8c3b3e7

lora interface

Browse files
Files changed (2) hide show
  1. app.py +50 -5
  2. vampnet/interface.py +2 -3
app.py CHANGED
@@ -18,10 +18,45 @@ Interface = argbind.bind(Interface)
18
 
19
  conf = argbind.parse_args()
20
 
21
- with argbind.scope(conf):
22
- interface = Interface()
23
- # loader = AudioLoader()
24
- print(f"interface device is {interface.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # dataset = at.data.datasets.AudioDataset(
27
  # loader,
@@ -55,6 +90,8 @@ def load_example_audio():
55
 
56
 
57
  def _vamp(data, return_mask=False):
 
 
58
  out_dir = OUT_DIR / str(uuid.uuid4())
59
  out_dir.mkdir()
60
  sig = at.AudioSignal(data[input_audio])
@@ -173,6 +210,7 @@ def save_vamp(data):
173
  "use_coarse2fine": data[use_coarse2fine],
174
  "stretch_factor": data[stretch_factor],
175
  "seed": data[seed],
 
176
  }
177
 
178
  # save with yaml
@@ -472,6 +510,13 @@ with gr.Blocks() as demo:
472
 
473
  # mask settings
474
  with gr.Column():
 
 
 
 
 
 
 
475
  vamp_button = gr.Button("generate (vamp)!!!")
476
  output_audio = gr.Audio(
477
  label="output audio",
@@ -514,7 +559,7 @@ with gr.Blocks() as demo:
514
  beat_mask_width,
515
  beat_mask_downbeats,
516
  seed,
517
- seed
518
  }
519
 
520
  # connect widgets
 
18
 
19
  conf = argbind.parse_args()
20
 
21
def load_interface():
    """Construct the VampNet ``Interface`` using the parsed argbind config.

    ``Interface`` is created inside ``argbind.scope(conf)`` so that the
    argbind-bound class is instantiated under the module-level ``conf``.
    The device reported by the new instance is printed for the operator.

    Returns:
        The newly created ``Interface`` instance.
    """
    with argbind.scope(conf):
        fresh_interface = Interface()
        # Report the ``device`` attribute of the instance we just built.
        print(f"interface device is {fresh_interface.device}")
    return fresh_interface
27
+
28
+
29
# Dictionary key / dropdown value that stands for "no LoRA selected".
LORA_NONE = "None"


def load_loras():
    """Collect the available LoRA configurations from ``conf/generated``.

    Every ``interface.yml`` found anywhere below ``conf/generated`` is parsed
    with ``yaml.safe_load`` and stored under the name of the directory that
    contains it. Files are visited in sorted path order so that the key order
    of the returned mapping (and anything built from it, such as a list of
    choices) is the same on every run; ``Path.glob`` on its own makes no
    ordering promise.

    Returns:
        dict: LoRA name -> parsed YAML content, plus a ``LORA_NONE`` entry
        mapped to ``None``. If ``conf/generated`` has no matching files the
        result contains only the ``LORA_NONE`` entry.
    """
    loras = {}
    # find confs under conf/generated
    for conf_file in sorted(Path("conf/generated").glob("**/interface.yml")):
        # Explicit encoding so parsing does not depend on the platform default.
        with open(conf_file, encoding="utf-8") as f:
            loras[conf_file.parent.name] = yaml.safe_load(f)
    # Assigned after the scan. NOTE(review): a directory literally named
    # "None" would have its config replaced by this entry.
    loras[LORA_NONE] = None
    return loras
39
+
40
# Module-level state shared by the UI callbacks: the live model interface,
# the table of known LoRA configs, and the name of the LoRA currently applied.
interface = load_interface()
loras = load_loras()
cur_lora = LORA_NONE


def load_lora(name):
    """Make ``name`` the active LoRA on the global ``interface``.

    Does nothing when ``name`` is already the current selection. Selecting
    ``LORA_NONE`` replaces the global ``interface`` with a new one from
    ``load_interface()``. Any other name looks up its config in ``loras`` and
    passes the coarse and coarse-to-fine checkpoint paths to
    ``interface.lora_load`` with ``full_ckpts=False``.

    NOTE(review): switching directly from one LoRA to another loads the new
    checkpoints onto the existing interface without rebuilding it first.

    Args:
        name: a key of ``loras`` (``LORA_NONE`` included).
    """
    global interface, cur_lora

    # Already active: nothing to (re)load.
    if name == cur_lora:
        return

    if name == LORA_NONE:
        # Back to the base model: rebuild the interface from scratch.
        interface = load_interface()
    else:
        interface.lora_load(
            coarse_ckpt=loras[name]["Interface.coarse_lora_ckpt"],
            c2f_ckpt=loras[name]["Interface.coarse2fine_lora_ckpt"],
            full_ckpts=False,
        )

    # Only reached when the branch above completed without raising.
    cur_lora = name
60
 
61
  # dataset = at.data.datasets.AudioDataset(
62
  # loader,
 
90
 
91
 
92
  def _vamp(data, return_mask=False):
93
+ load_lora(data[lora_choice])
94
+
95
  out_dir = OUT_DIR / str(uuid.uuid4())
96
  out_dir.mkdir()
97
  sig = at.AudioSignal(data[input_audio])
 
210
  "use_coarse2fine": data[use_coarse2fine],
211
  "stretch_factor": data[stretch_factor],
212
  "seed": data[seed],
213
+ "lora": data[lora_choice],
214
  }
215
 
216
  # save with yaml
 
510
 
511
  # mask settings
512
  with gr.Column():
513
+
514
+ lora_choice = gr.Dropdown(
515
+ label="lora choice",
516
+ choices=list(loras.keys()),
517
+ value=LORA_NONE,
518
+ )
519
+
520
  vamp_button = gr.Button("generate (vamp)!!!")
521
  output_audio = gr.Audio(
522
  label="output audio",
 
559
  beat_mask_width,
560
  beat_mask_downbeats,
561
  seed,
562
+ lora_choice,
563
  }
564
 
565
  # connect widgets
vampnet/interface.py CHANGED
@@ -120,17 +120,16 @@ class Interface(torch.nn.Module):
120
  if coarse_ckpt is not None:
121
  self.coarse.to("cpu")
122
  state_dict = torch.load(coarse_ckpt, map_location="cpu")
123
-
124
  self.coarse.load_state_dict(state_dict, strict=False)
125
  self.coarse.to(self.device)
126
  if c2f_ckpt is not None:
127
  self.c2f.to("cpu")
128
  state_dict = torch.load(c2f_ckpt, map_location="cpu")
129
-
130
  self.c2f.load_state_dict(state_dict, strict=False)
131
  self.c2f.to(self.device)
132
 
133
-
134
  def s2t(self, seconds: float):
135
  """seconds to tokens"""
136
  if isinstance(seconds, np.ndarray):
 
120
  if coarse_ckpt is not None:
121
  self.coarse.to("cpu")
122
  state_dict = torch.load(coarse_ckpt, map_location="cpu")
123
+ print(f"loading coarse from {coarse_ckpt}")
124
  self.coarse.load_state_dict(state_dict, strict=False)
125
  self.coarse.to(self.device)
126
  if c2f_ckpt is not None:
127
  self.c2f.to("cpu")
128
  state_dict = torch.load(c2f_ckpt, map_location="cpu")
129
+ print(f"loading c2f from {c2f_ckpt}")
130
  self.c2f.load_state_dict(state_dict, strict=False)
131
  self.c2f.to(self.device)
132
 
 
133
  def s2t(self, seconds: float):
134
  """seconds to tokens"""
135
  if isinstance(seconds, np.ndarray):