asigalov61 commited on
Commit
25f28a8
1 Parent(s): 9ce2c1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -23,7 +23,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
- def GenerateMIDI(num_tok, idrums, iinstr):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = time.time()
@@ -126,6 +126,8 @@ def GenerateMIDI(num_tok, idrums, iinstr):
126
  with torch.inference_mode():
127
  out = model.module.generate(inp,
128
  1,
 
 
129
  temperature=0.9,
130
  return_prime=False,
131
  verbose=False)
@@ -207,12 +209,13 @@ if __name__ == "__main__":
207
  value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
208
  input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
209
  input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")
 
210
 
211
  run_btn = gr.Button("generate", variant="primary")
212
-
213
- output_plot = gr.Plot(label='output plot')
214
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
 
215
  output_midi = gr.File(label="output midi", file_types=[".mid"])
216
- run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument],
217
  [output_plot, output_midi, output_audio])
218
  app.queue().launch()
 
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
+ def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_value):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = time.time()
 
126
  with torch.inference_mode():
127
  out = model.module.generate(inp,
128
  1,
129
+ filter_logits_fn=top_k,
130
+ filter_kwargs={'k': input_top_k_value},
131
  temperature=0.9,
132
  return_prime=False,
133
  verbose=False)
 
209
  value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
210
  input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
211
  input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")
212
+ input_top_k_value = gr.Slider(1, 100, value=15, label="Model sampling top_k value")
213
 
214
  run_btn = gr.Button("generate", variant="primary")
215
+
 
216
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
217
+ output_plot = gr.Plot(label='output plot')
218
  output_midi = gr.File(label="output midi", file_types=[".mid"])
219
+ run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument, input_top_k_value],
220
  [output_plot, output_midi, output_audio])
221
  app.queue().launch()