asigalov61 commited on
Commit
9c77576
1 Parent(s): b44558c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -13
app.py CHANGED
@@ -21,13 +21,31 @@ from midi_to_colab_audio import midi_to_colab_audio
21
  # =================================================================================================
22
 
23
  @spaces.GPU
24
- def Generate_Rock_Song(input_midi, input_melody_seed_number):
 
 
 
 
 
25
 
26
  print('=' * 70)
27
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
28
  start_time = reqtime.time()
29
  print('=' * 70)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  print('Loading model...')
32
 
33
  SEQ_LEN = 4096
@@ -65,16 +83,6 @@ def Generate_Rock_Song(input_midi, input_melody_seed_number):
65
 
66
  print('Done!')
67
  print('=' * 70)
68
-
69
- #==================================================================
70
-
71
- fn = os.path.basename(input_midi)
72
- fn1 = fn.split('.')[0]
73
-
74
- print('=' * 70)
75
- print('Requested settings:')
76
- print('=' * 70)
77
- print('Input MIDI file name:', fn)
78
 
79
  #===============================================================================
80
  # Raw single-track ms score
@@ -174,6 +182,25 @@ def Generate_Rock_Song(input_midi, input_melody_seed_number):
174
 
175
  #==================================================================
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  def generate_tokens(seq, max_num_ptcs=10):
178
 
179
  input = copy.deepcopy(seq)
@@ -185,7 +212,7 @@ def Generate_Rock_Song(input_midi, input_melody_seed_number):
185
 
186
  while pcount < max_num_ptcs and y > 255:
187
 
188
- x = torch.tensor(input, dtype=torch.long, device='cuda')
189
 
190
  with ctx:
191
  out = model.generate(x,
@@ -371,7 +398,11 @@ if __name__ == "__main__":
371
  output_plot = gr.Plot(label="Output MIDI score plot")
372
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
373
 
374
- run_event = run_btn.click(Generate_Rock_Song, [input_midi, input_melody_seed_number],
 
 
 
 
375
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
376
 
377
  gr.Examples(
 
21
  # =================================================================================================
22
 
23
  @spaces.GPU
24
+ def Generate_Rock_Song(input_midi,
25
+ input_freestyle_continuation,
26
+ input_number_prime_chords,
27
+ input_use_original_durations,
28
+ input_match_original_pitches_counts
29
+ ):
30
 
31
  print('=' * 70)
32
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
33
  start_time = reqtime.time()
34
  print('=' * 70)
35
 
36
+ fn = os.path.basename(input_midi)
37
+ fn1 = fn.split('.')[0]
38
+
39
+ print('=' * 70)
40
+ print('Requested settings:')
41
+ print('=' * 70)
42
+ print('Input MIDI file name:', fn)
43
+ print('Freestyle continuation:', input_freestyle_continuation)
44
+ print('Number of prime chords:', input_number_prime_chords)
45
+ print('Use original durations:', input_use_original_durations)
46
+ print('Match original pitches counts:', input_match_original_pitches_counts)
47
+ print('=' * 70)
48
+
49
  print('Loading model...')
50
 
51
  SEQ_LEN = 4096
 
83
 
84
  print('Done!')
85
  print('=' * 70)
 
 
 
 
 
 
 
 
 
 
86
 
87
  #===============================================================================
88
  # Raw single-track ms score
 
182
 
183
  #==================================================================
184
 
185
+ def generate_continuation(num_prime_tokens, num_gen_tokens):
186
+
187
+ x = torch.tensor(prime_toks[:num_prime_tokens], dtype=torch.long, device=DEVICE)
188
+
189
+ with ctx:
190
+ out = model.generate(x,
191
+ num_gen_tokens,
192
+ #filter_logits_fn=top_k,
193
+ #filter_kwargs={'k': 5},
194
+ temperature=0.9,
195
+ return_prime=True,
196
+ verbose=True)
197
+
198
+ y = out.tolist()[0]
199
+
200
+ return y
201
+
202
+ #==================================================================
203
+
204
  def generate_tokens(seq, max_num_ptcs=10):
205
 
206
  input = copy.deepcopy(seq)
 
212
 
213
  while pcount < max_num_ptcs and y > 255:
214
 
215
+ x = torch.tensor(input, dtype=torch.long, device=DEVICE)
216
 
217
  with ctx:
218
  out = model.generate(x,
 
398
  output_plot = gr.Plot(label="Output MIDI score plot")
399
  output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
400
 
401
+ run_event = run_btn.click(Generate_Rock_Song, [input_freestyle_continuation,
402
+ input_number_prime_chords,
403
+ input_use_original_durations,
404
+ input_match_original_pitches_counts
405
+ ],
406
  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
407
 
408
  gr.Examples(