Spaces:
Running
on
Zero
Running
on
Zero
changes
Browse files- app.py +3 -4
- midi_model.py +1 -1
app.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 2 |
-
|
| 3 |
import spaces
|
| 4 |
import random
|
| 5 |
import argparse
|
|
@@ -7,6 +5,7 @@ import glob
|
|
| 7 |
import json
|
| 8 |
import os
|
| 9 |
import time
|
|
|
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
import numpy as np
|
|
@@ -122,7 +121,7 @@ def send_msgs(msgs):
|
|
| 122 |
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
|
| 123 |
time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
| 124 |
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
| 125 |
-
t = 1e-4*gen_events**2 +
|
| 126 |
if "large" in model_name:
|
| 127 |
t *= 2
|
| 128 |
return t
|
|
@@ -383,7 +382,7 @@ if __name__ == "__main__":
|
|
| 383 |
with app:
|
| 384 |
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
|
| 385 |
gr.Markdown("\n\n"
|
| 386 |
-
"Midi event transformer for music generation\n\n"
|
| 387 |
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
| 388 |
"[Open In Colab]"
|
| 389 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
|
|
|
|
|
|
|
|
|
| 1 |
import spaces
|
| 2 |
import random
|
| 3 |
import argparse
|
|
|
|
| 5 |
import json
|
| 6 |
import os
|
| 7 |
import time
|
| 8 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 9 |
|
| 10 |
import gradio as gr
|
| 11 |
import numpy as np
|
|
|
|
| 121 |
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
|
| 122 |
time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
| 123 |
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
| 124 |
+
t = 1e-4*gen_events**2 + 25
|
| 125 |
if "large" in model_name:
|
| 126 |
t *= 2
|
| 127 |
return t
|
|
|
|
| 382 |
with app:
|
| 383 |
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
|
| 384 |
gr.Markdown("\n\n"
|
| 385 |
+
"Midi event transformer for symbolic music generation\n\n"
|
| 386 |
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
| 387 |
"[Open In Colab]"
|
| 388 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
midi_model.py
CHANGED
|
@@ -125,7 +125,7 @@ class MIDIModel(nn.Module):
|
|
| 125 |
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
| 126 |
elif prompt.shape[0] == 1:
|
| 127 |
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
| 128 |
-
|
| 129 |
raise ValueError(f"invalid shape for prompt, {prompt.shape}")
|
| 130 |
prompt = prompt[..., :max_token_seq]
|
| 131 |
if prompt.shape[-1] < max_token_seq:
|
|
|
|
| 125 |
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
| 126 |
elif prompt.shape[0] == 1:
|
| 127 |
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
| 128 |
+
elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
|
| 129 |
raise ValueError(f"invalid shape for prompt, {prompt.shape}")
|
| 130 |
prompt = prompt[..., :max_token_seq]
|
| 131 |
if prompt.shape[-1] < max_token_seq:
|