asigalov61 commited on
Commit
ce5443a
·
1 Parent(s): e091417

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -20,9 +20,18 @@ in_space = os.getenv("SYSTEM") == "spaces"
20
  #=================================================================================================
21
 
22
  @torch.no_grad()
23
- def GenerateMIDI(progress=gr.Progress()):
 
 
 
 
 
 
 
 
 
 
24
 
25
- start_tokens = [3087, 3073+1, 3075+1]
26
  seq_len = 512
27
  max_seq_len = 2048
28
  temperature = 1.0
@@ -200,13 +209,14 @@ if __name__ == "__main__":
200
  "(https://colab.research.google.com/github/asigalov61/Allegro-Music-Transformer/blob/main/Allegro_Music_Transformer_Composer.ipynb)"
201
  " for faster execution and endless generation"
202
  )
203
-
 
204
  run_btn = gr.Button("generate", variant="primary")
205
 
206
  output_midi_seq = gr.Variable()
207
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
208
  output_plot = gr.Plot(label="output plot")
209
  output_midi = gr.File(label="output midi", file_types=[".mid"])
210
- run_event = run_btn.click(GenerateMIDI, [], [output_midi_seq, output_plot, output_midi, output_audio])
211
 
212
  app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
20
  #=================================================================================================
21
 
22
  @torch.no_grad()
23
+ def GenerateMIDI(idrums, iinstr, progress=gr.Progress()):
24
+
25
+ if idrums:
26
+ drums = 3074
27
+ else:
28
+ drums = 3073
29
+
30
+ instruments_list = ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", 'Drums', "Choir", "Organ"]
31
+ first_note_instrument_number = instruments_list.index(iinstr)
32
+
33
+ start_tokens = [3087, drums, 3075+first_note_instrument_number]
34
 
 
35
  seq_len = 512
36
  max_seq_len = 2048
37
  temperature = 1.0
 
209
  "(https://colab.research.google.com/github/asigalov61/Allegro-Music-Transformer/blob/main/Allegro_Music_Transformer_Composer.ipynb)"
210
  " for faster execution and endless generation"
211
  )
212
+ input_drums = gr.Checkbox(label="input drums", info="Drums present or not")
213
+ input_instrument = gr.Radio(["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", "Choir", "Organ"], label="input instrument", info="Desired lead instrument")
214
  run_btn = gr.Button("generate", variant="primary")
215
 
216
  output_midi_seq = gr.Variable()
217
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
218
  output_plot = gr.Plot(label="output plot")
219
  output_midi = gr.File(label="output midi", file_types=[".mid"])
220
+ run_event = run_btn.click(GenerateMIDI, [input_drums, input_instrument], [output_midi_seq, output_plot, output_midi, output_audio])
221
 
222
  app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)