skytnt commited on
Commit
7406325
·
1 Parent(s): e9642f9
Files changed (2) hide show
  1. app.py +3 -4
  2. midi_model.py +1 -1
app.py CHANGED
@@ -1,5 +1,3 @@
1
- from concurrent.futures import ThreadPoolExecutor
2
-
3
  import spaces
4
  import random
5
  import argparse
@@ -7,6 +5,7 @@ import glob
7
  import json
8
  import os
9
  import time
 
10
 
11
  import gradio as gr
12
  import numpy as np
@@ -122,7 +121,7 @@ def send_msgs(msgs):
122
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
123
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
124
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
125
- t = 1e-4*gen_events**2 + 15
126
  if "large" in model_name:
127
  t *= 2
128
  return t
@@ -383,7 +382,7 @@ if __name__ == "__main__":
383
  with app:
384
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
385
  gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
386
- "Midi event transformer for music generation\n\n"
387
  "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
388
  "[Open In Colab]"
389
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
 
 
 
1
  import spaces
2
  import random
3
  import argparse
 
5
  import json
6
  import os
7
  import time
8
+ from concurrent.futures import ThreadPoolExecutor
9
 
10
  import gradio as gr
11
  import numpy as np
 
121
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
122
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
123
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
124
+ t = 1e-4*gen_events**2 + 25
125
  if "large" in model_name:
126
  t *= 2
127
  return t
 
382
  with app:
383
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
384
  gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
385
+ "Midi event transformer for symbolic music generation\n\n"
386
  "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
387
  "[Open In Colab]"
388
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
midi_model.py CHANGED
@@ -125,7 +125,7 @@ class MIDIModel(nn.Module):
125
  prompt = np.repeat(prompt, repeats=batch_size, axis=0)
126
  elif prompt.shape[0] == 1:
127
  prompt = np.repeat(prompt, repeats=batch_size, axis=0)
128
- else:
129
  raise ValueError(f"invalid shape for prompt, {prompt.shape}")
130
  prompt = prompt[..., :max_token_seq]
131
  if prompt.shape[-1] < max_token_seq:
 
125
  prompt = np.repeat(prompt, repeats=batch_size, axis=0)
126
  elif prompt.shape[0] == 1:
127
  prompt = np.repeat(prompt, repeats=batch_size, axis=0)
128
+ elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
129
  raise ValueError(f"invalid shape for prompt, {prompt.shape}")
130
  prompt = prompt[..., :max_token_seq]
131
  if prompt.shape[-1] < max_token_seq: