skytnt commited on
Commit
a021cba
1 Parent(s): ff0299c

fix get_duration

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -121,7 +121,15 @@ def send_msgs(msgs):
121
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
122
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
123
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
124
- t = 1e-4*gen_events**2 + 25
 
 
 
 
 
 
 
 
125
  if "large" in model_name:
126
  t *= 2
127
  return t
 
121
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
122
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
123
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
124
+ if tab == 0:
125
+ start_events = 1 + len(instruments)
126
+ elif tab == 1 and mid is not None:
127
+ start_events = midi_events
128
+ elif tab == 2 and mid_seq is not None:
129
+ start_events = len(mid_seq[0])
130
+ else:
131
+ start_events = 1
132
+ t = 8.5e-5 * (gen_events+start_events) ** 2 - 8.5e-5 * start_events ** 2 + 23
133
  if "large" in model_name:
134
  t *= 2
135
  return t