Hugo Flores Garcia commited on
Commit
b61e699
1 Parent(s): 4d0cbfe

more demo ctrls

Browse files
Files changed (1) hide show
  1. demo.py +47 -2
demo.py CHANGED
@@ -104,12 +104,18 @@ def _vamp(data, return_mask=False):
104
  # save the mask as a txt file
105
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
106
 
 
107
  zv, mask_z = interface.coarse_vamp(
108
  z,
109
  mask=mask,
110
  sampling_steps=data[num_steps],
111
  temperature=(data[init_temp], data[final_temp]),
112
- return_mask=True
 
 
 
 
 
113
  )
114
 
115
  if use_coarse2fine:
@@ -299,6 +305,38 @@ with gr.Blocks() as demo:
299
  value=1.0
300
  )
301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
  num_steps = gr.Slider(
304
  label="number of steps (should normally be between 12 and 36)",
@@ -318,6 +356,8 @@ with gr.Blocks() as demo:
318
 
319
  vamp_button = gr.Button("vamp!!!")
320
 
 
 
321
  output_audio = gr.Audio(
322
  label="output audio",
323
  interactive=False,
@@ -373,7 +413,12 @@ with gr.Blocks() as demo:
373
  use_coarse2fine,
374
  stretch_factor,
375
  onset_mask_width,
376
- input_pitch_shift
 
 
 
 
 
377
  }
378
 
379
  # connect widgets
 
104
  # save the mask as a txt file
105
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
106
 
107
+ top_k = data[topk] if data[topk] > 0 else None
108
  zv, mask_z = interface.coarse_vamp(
109
  z,
110
  mask=mask,
111
  sampling_steps=data[num_steps],
112
  temperature=(data[init_temp], data[final_temp]),
113
+ return_mask=True,
114
+ sample=data[sampling_strategy],
115
+ typical_filtering=data[typical_filtering],
116
+ typical_mass=data[typical_mass],
117
+ typical_min_tokens=data[typical_min_tokens],
118
+ top_k=top_k,
119
  )
120
 
121
  if use_coarse2fine:
 
305
  value=1.0
306
  )
307
 
308
+ with gr.Accordion("sampling settings", open=False):
309
+ sampling_strategy = gr.Radio(
310
+ label="sampling strategy",
311
+ choices=["gumbel", "multinomial"],
312
+ value="gumbel"
313
+ )
314
+ typical_filtering = gr.Checkbox(
315
+ label="typical filtering (cannot be used with topk)",
316
+ value=True
317
+ )
318
+ typical_mass = gr.Slider(
319
+ label="typical mass (should probably stay between 0.1 and 0.5)",
320
+ minimum=0.01,
321
+ maximum=0.99,
322
+ value=0.2
323
+ )
324
+ typical_min_tokens = gr.Slider(
325
+ label="typical min tokens (should probably stay between 1 and 256)",
326
+ minimum=1,
327
+ maximum=256,
328
+ step=1,
329
+ value=1
330
+ )
331
+ topk = gr.Slider(
332
+ label="topk (cannot be used with typical filtering). 0 = None",
333
+ minimum=0,
334
+ maximum=256,
335
+ step=1,
336
+ value=0
337
+ )
338
+
339
+
340
 
341
  num_steps = gr.Slider(
342
  label="number of steps (should normally be between 12 and 36)",
 
356
 
357
  vamp_button = gr.Button("vamp!!!")
358
 
359
+ # mask settings
360
+ with gr.Column():
361
  output_audio = gr.Audio(
362
  label="output audio",
363
  interactive=False,
 
413
  use_coarse2fine,
414
  stretch_factor,
415
  onset_mask_width,
416
+ input_pitch_shift,
417
+ sampling_strategy,
418
+ typical_filtering,
419
+ typical_mass,
420
+ typical_min_tokens,
421
+ topk,
422
  }
423
 
424
  # connect widgets