Spaces:
Running
on
Zero
Running
on
Zero
ft model
Browse files
app.py
CHANGED
@@ -100,9 +100,9 @@ def get_duration(model_name, tab, mid_seq, continuation_state, instruments, drum
|
|
100 |
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
101 |
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
102 |
if "large" in model_name:
|
103 |
-
return gen_events // 10 +
|
104 |
else:
|
105 |
-
return gen_events // 20 +
|
106 |
|
107 |
|
108 |
@spaces.GPU(duration=get_duration)
|
@@ -110,7 +110,7 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
|
|
110 |
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
111 |
gen_events, temp, top_p, top_k, allow_cc):
|
112 |
model = models[model_name]
|
113 |
-
model.to(device=opt.device
|
114 |
tokenizer = model.tokenizer
|
115 |
bpm = int(bpm)
|
116 |
if time_sig == "auto":
|
@@ -302,8 +302,8 @@ if __name__ == "__main__":
|
|
302 |
"generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
|
303 |
"generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
|
304 |
"generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
|
305 |
-
"j-pop finetune model (
|
306 |
-
"touhou finetune model (
|
307 |
}
|
308 |
models = {}
|
309 |
if opt.device == "cuda":
|
@@ -315,6 +315,7 @@ if __name__ == "__main__":
|
|
315 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
316 |
state_dict = ckpt.get("state_dict", ckpt)
|
317 |
model.load_state_dict(state_dict, strict=False)
|
|
|
318 |
models[name] = model
|
319 |
|
320 |
load_javascript()
|
|
|
100 |
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
101 |
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
102 |
if "large" in model_name:
|
103 |
+
return gen_events // 10 + 15
|
104 |
else:
|
105 |
+
return gen_events // 20 + 15
|
106 |
|
107 |
|
108 |
@spaces.GPU(duration=get_duration)
|
|
|
110 |
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
111 |
gen_events, temp, top_p, top_k, allow_cc):
|
112 |
model = models[model_name]
|
113 |
+
model.to(device=opt.device)
|
114 |
tokenizer = model.tokenizer
|
115 |
bpm = int(bpm)
|
116 |
if time_sig == "auto":
|
|
|
302 |
"generic pretrain model (tv2o-large) by asigalov61": ["asigalov61/Music-Llama", "", "tv2o-large"],
|
303 |
"generic pretrain model (tv2o-medium) by asigalov61": ["asigalov61/Music-Llama-Medium", "", "tv2o-medium"],
|
304 |
"generic pretrain model (tv1-medium) by skytnt": ["skytnt/midi-model", "", "tv1-medium"],
|
305 |
+
"j-pop finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "jpop-tv2o-medium/", "tv2o-medium"],
|
306 |
+
"touhou finetune model (tv2o-medium) by skytnt": ["skytnt/midi-model-ft", "touhou-tv2o-medium/", "tv2o-medium"],
|
307 |
}
|
308 |
models = {}
|
309 |
if opt.device == "cuda":
|
|
|
315 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
316 |
state_dict = ckpt.get("state_dict", ckpt)
|
317 |
model.load_state_dict(state_dict, strict=False)
|
318 |
+
model.to(device="cpu", dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32)
|
319 |
models[name] = model
|
320 |
|
321 |
load_javascript()
|