Hugo Flores Garcia commited on
Commit
03f09ee
·
1 Parent(s): f3f4634

better sampling defaults

Browse files
Files changed (2) hide show
  1. demo.py +45 -53
  2. vampnet/modules/base.py +2 -2
demo.py CHANGED
@@ -115,7 +115,7 @@ def vamp(
115
  sig,
116
  temperature=(init_temp, final_temp),
117
  prefix_dur_s=prefix_s,
118
- suffix_dur_s=suffix_s,
119
  num_loops=num_vamps,
120
  downsample_factor=mask_periodic_amt,
121
  intensity=rand_mask_intensity,
@@ -199,9 +199,6 @@ with gr.Blocks() as demo:
199
  5. Listen to the generated audio
200
  6. If you noticed something you liked, write some notes, click the "save vamp" button, and copy the save code
201
 
202
-
203
-
204
-
205
  """)
206
  gr.Markdown("## Input Audio")
207
  with gr.Column():
@@ -211,12 +208,6 @@ with gr.Blocks() as demo:
211
  - mask hints are used to guide vampnet to generate audio that sounds like the original
212
  - the more hints you give, the more the generated audio will sound like the original
213
 
214
-
215
-
216
-
217
-
218
-
219
-
220
  """)
221
  with gr.Column():
222
  gr.Markdown("""
@@ -228,6 +219,7 @@ with gr.Blocks() as demo:
228
  - if you want a more "random" generation:
229
  - uncheck the beat sync button (or reduce the beat unmask duration)
230
  - increase the periodic unmasking to 16 or more
 
231
 
232
  """)
233
 
@@ -281,11 +273,11 @@ with gr.Blocks() as demo:
281
  with gr.Column():
282
 
283
  mask_periodic_amt = gr.Slider(
284
- label="periodic unmasking factor (provides a rhythmic, periodic hint). 0.0 means no hint, 2 means one hint every 2 timesteps, etc, 4 means one hint every 4 timesteps, etc.",
285
  minimum=0,
286
- maximum=32,
287
  step=1,
288
- value=16,
289
  )
290
 
291
 
@@ -296,32 +288,33 @@ with gr.Blocks() as demo:
296
  value=1.0
297
  )
298
 
299
- prefix_s = gr.Slider(
300
- label="prefix hint length (seconds)",
301
- minimum=0.0,
302
- maximum=10.0,
303
- value=0.0
304
- )
305
- suffix_s = gr.Slider(
306
- label="suffix hint length (seconds)",
307
- minimum=0.0,
308
- maximum=10.0,
309
- value=0.0
310
- )
311
-
312
 
313
- init_temp = gr.Slider(
314
- label="initial temperature (should probably stay between 0.6 and 1)",
315
- minimum=0.0,
316
- maximum=1.5,
317
- value=0.8
318
- )
319
- final_temp = gr.Slider(
320
- label="final temperature (should probably stay between 0.7 and 2)",
321
- minimum=0.0,
322
- maximum=2.0,
323
- value=0.9
324
- )
 
325
 
326
  use_beats = gr.Checkbox(
327
  label="use beat hints",
@@ -333,10 +326,9 @@ with gr.Blocks() as demo:
333
  minimum=4,
334
  maximum=128,
335
  step=1,
336
- value=24
337
  )
338
 
339
-
340
  vamp_button = gr.Button("vamp!!!")
341
 
342
  output_audio = gr.Audio(
@@ -365,7 +357,7 @@ with gr.Blocks() as demo:
365
  label="duration",
366
  minimum=0.0,
367
  maximum=3.0,
368
- value=0.1
369
  )
370
  with gr.Accordion("downbeat settings", open=False):
371
  mask_dwn_chk = gr.Checkbox(
@@ -392,19 +384,19 @@ with gr.Blocks() as demo:
392
  step=1
393
  )
394
 
395
- notes_text = gr.Textbox(
396
- label="type any notes about the generated audio here",
397
- value="",
398
- interactive=True
399
- )
400
- save_button = gr.Button("download vamp")
401
- download_file = gr.File(
402
- label="vamp to download will appear here",
403
- interactive=False
404
- )
405
-
406
 
407
- thank_you = gr.Markdown("")
408
 
409
 
410
  # connect widgets
 
115
  sig,
116
  temperature=(init_temp, final_temp),
117
  prefix_dur_s=prefix_s,
118
+ suffix_dur_s=prefix_s, # suffix should be same length as prefix
119
  num_loops=num_vamps,
120
  downsample_factor=mask_periodic_amt,
121
  intensity=rand_mask_intensity,
 
199
  5. Listen to the generated audio
200
  6. If you noticed something you liked, write some notes, click the "save vamp" button, and copy the save code
201
 
 
 
 
202
  """)
203
  gr.Markdown("## Input Audio")
204
  with gr.Column():
 
208
  - mask hints are used to guide vampnet to generate audio that sounds like the original
209
  - the more hints you give, the more the generated audio will sound like the original
210
 
 
 
 
 
 
 
211
  """)
212
  with gr.Column():
213
  gr.Markdown("""
 
219
  - if you want a more "random" generation:
220
  - uncheck the beat sync button (or reduce the beat unmask duration)
221
  - increase the periodic unmasking to 16 or more
222
+ - increase the temperatures!
223
 
224
  """)
225
 
 
273
  with gr.Column():
274
 
275
  mask_periodic_amt = gr.Slider(
276
+ label="periodic hint (0.0 means no hint, 2 means one hint every 2 timesteps, etc, 4 means one hint every 4 timesteps, etc)",
277
  minimum=0,
278
+ maximum=64,
279
  step=1,
280
+ value=19,
281
  )
282
 
283
 
 
288
  value=1.0
289
  )
290
 
291
+ with gr.Accordion("prefix/suffix hints", open=False):
292
+ prefix_s = gr.Slider(
293
+ label="prefix hint length (seconds)",
294
+ minimum=0.0,
295
+ maximum=10.0,
296
+ value=0.0
297
+ )
298
+ suffix_s = gr.Slider(
299
+ label="suffix hint length (seconds)",
300
+ minimum=0.0,
301
+ maximum=10.0,
302
+ value=0.0
303
+ )
304
 
305
+ with gr.Accordion("temperature settings", open=False):
306
+ init_temp = gr.Slider(
307
+ label="initial temperature (should probably stay between 0.6 and 1)",
308
+ minimum=0.0,
309
+ maximum=1.5,
310
+ value=0.8
311
+ )
312
+ final_temp = gr.Slider(
313
+ label="final temperature (should probably stay between 0.7 and 2)",
314
+ minimum=0.0,
315
+ maximum=2.0,
316
+ value=1.0
317
+ )
318
 
319
  use_beats = gr.Checkbox(
320
  label="use beat hints",
 
326
  minimum=4,
327
  maximum=128,
328
  step=1,
329
+ value=36
330
  )
331
 
 
332
  vamp_button = gr.Button("vamp!!!")
333
 
334
  output_audio = gr.Audio(
 
357
  label="duration",
358
  minimum=0.0,
359
  maximum=3.0,
360
+ value=0.07
361
  )
362
  with gr.Accordion("downbeat settings", open=False):
363
  mask_dwn_chk = gr.Checkbox(
 
384
  step=1
385
  )
386
 
387
+ notes_text = gr.Textbox(
388
+ label="type any notes about the generated audio here",
389
+ value="",
390
+ interactive=True
391
+ )
392
+ save_button = gr.Button("save vamp")
393
+ download_file = gr.File(
394
+ label="vamp to download will appear here",
395
+ interactive=False
396
+ )
397
+
398
 
399
+ thank_you = gr.Markdown("")
400
 
401
 
402
  # connect widgets
vampnet/modules/base.py CHANGED
@@ -181,7 +181,7 @@ class VampBase(at.ml.BaseModel):
181
  self,
182
  codec,
183
  time_steps: int = 400,
184
- sampling_steps: int = 12,
185
  start_tokens: Optional[torch.Tensor] = None,
186
  mask: Optional[torch.Tensor] = None,
187
  temperature: Union[float, Tuple[float, float]] = 0.8,
@@ -290,7 +290,7 @@ class VampBase(at.ml.BaseModel):
290
  self,
291
  codec,
292
  time_steps: int = 300,
293
- sampling_steps: int = 12,
294
  start_tokens: Optional[torch.Tensor] = None,
295
  mask: Optional[torch.Tensor] = None,
296
  temperature: Union[float, Tuple[float, float]] = 0.8,
 
181
  self,
182
  codec,
183
  time_steps: int = 400,
184
+ sampling_steps: int = 36,
185
  start_tokens: Optional[torch.Tensor] = None,
186
  mask: Optional[torch.Tensor] = None,
187
  temperature: Union[float, Tuple[float, float]] = 0.8,
 
290
  self,
291
  codec,
292
  time_steps: int = 300,
293
+ sampling_steps: int = 36,
294
  start_tokens: Optional[torch.Tensor] = None,
295
  mask: Optional[torch.Tensor] = None,
296
  temperature: Union[float, Tuple[float, float]] = 0.8,