fffiloni commited on
Commit
51aa8a2
1 Parent(s): 32bcd9d

added duration option

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -11,8 +11,12 @@ model_id = "riffusion/riffusion-model-v1"
11
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
12
  pipe = pipe.to("cuda")
13
 
14
- def predict(prompt):
15
- spec = pipe(prompt, height=512, width=768).images[0]
 
 
 
 
16
  print(spec)
17
  wav = wav_bytes_from_spectrogram_image(spec)
18
  with open("output.wav", "wb") as f:
@@ -132,6 +136,7 @@ with gr.Blocks(css=css) as demo:
132
  gr.HTML(title)
133
 
134
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
 
135
  send_btn = gr.Button(value="Get a new spectrogram ! ", elem_id="submit-btn")
136
 
137
  with gr.Column(elem_id="col-container-2"):
@@ -146,7 +151,7 @@ with gr.Blocks(css=css) as demo:
146
 
147
  gr.HTML(article)
148
 
149
- send_btn.click(predict, inputs=[prompt_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
150
  share_button.click(None, [], [], _js=share_js)
151
 
152
  demo.queue(max_size=250).launch(debug=True)
 
11
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
12
  pipe = pipe.to("cuda")
13
 
14
+ def predict(prompt, duration):
15
+ if duration == 5:
16
+ width_duration=512
17
+ else :
18
+ width_duration = int(512 + ((duration-5) * 128))
19
+ spec = pipe(prompt, height=512, width=width_duration).images[0]
20
  print(spec)
21
  wav = wav_bytes_from_spectrogram_image(spec)
22
  with open("output.wav", "wb") as f:
 
136
  gr.HTML(title)
137
 
138
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
139
+ duration_input = gr.Slider(label="Duration in seconds", minimum=5, maximum=20, step=1, value=8, elem_id="duration-slider")
140
  send_btn = gr.Button(value="Get a new spectrogram ! ", elem_id="submit-btn")
141
 
142
  with gr.Column(elem_id="col-container-2"):
 
151
 
152
  gr.HTML(article)
153
 
154
+ send_btn.click(predict, inputs=[prompt_input, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
155
  share_button.click(None, [], [], _js=share_js)
156
 
157
  demo.queue(max_size=250).launch(debug=True)