juancopi81 commited on
Commit
877e056
1 Parent(s): f18cdd2
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -5,13 +5,13 @@ from diffusers import StableDiffusionPipeline
5
  import gradio as gr
6
  import torch
7
 
8
- from spectro import wav_bytes_from_spectrogram_image
9
-
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  dtype = torch.float16 if device == "cuda" else torch.float32
12
 
13
  model_id = "runwayml/stable-diffusion-v1-5"
14
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
 
 
15
  pipe = pipe.to(device)
16
 
17
  model_id2 = "riffusion/riffusion-model-v1"
@@ -29,7 +29,6 @@ title = """
29
  <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
30
  <h1 style="font-weight: 950; margin-bottom: 7px; color: #000; font-weight: bold;">Riffusion and Stable Diffusion</h1>
31
  </div>
32
- <p style="margin-bottom: 10px; font-size: 98%; color: #666;">Text to music player.</p>
33
  </div>
34
  """
35
  def get_bg_image(prompt):
@@ -39,18 +38,18 @@ def get_bg_image(prompt):
39
  return image_output
40
 
41
  def get_music(prompt):
42
- spec = pipe2(prompt).images[0]
43
  print(spec)
44
  wav = wav_bytes_from_spectrogram_image(spec)
45
  with open("output.wav", "wb") as f:
46
- f.write(wav[0].getbuffer())
47
- return 'output.wav'
48
 
49
  def infer(prompt):
50
  image = get_bg_image(prompt)
51
  audio = get_music(prompt)
52
  return (
53
- gr.make_waveform(audio, bg_image=image, bars_color=random.choice(COLORS)),
54
  )
55
 
56
  css = """
@@ -81,12 +80,13 @@ css = """
81
  with gr.Blocks(css=css) as demo:
82
  gr.HTML(title)
83
  with gr.Column(elem_id="col-container"):
84
- prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club",
85
  elem_id="prompt-in",
86
- show_label=False)
87
  with gr.Row(elem_id="btn-container"):
88
  send_btn = gr.Button(value="Send", elem_id="submit-btn")
89
- video_output = gr.Video()
90
- send_btn.click(infer, inputs=[prompt_input], outputs=[video_output])
 
91
 
92
- demo.queue().launch(debug=True)
 
import gradio as gr
import torch

# Run on GPU when available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

model_id = "runwayml/stable-diffusion-v1-5"
# Only request the "fp16" weight revision when we can actually run in half
# precision. Loading fp16 weights while forcing torch_dtype=float32 on CPU
# downloads the wrong variant for nothing and can fail on repos that do not
# publish that revision branch.
if device == "cuda":
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        revision="fp16",
    )
else:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)

model_id2 = "riffusion/riffusion-model-v1"
 
29
  <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
30
  <h1 style="font-weight: 950; margin-bottom: 7px; color: #000; font-weight: bold;">Riffusion and Stable Diffusion</h1>
31
  </div>
 
32
  </div>
33
  """
34
  def get_bg_image(prompt):
 
38
  return image_output
39
 
40
def get_music(prompt, output_path="output.wav"):
    """Generate audio for *prompt* with the riffusion pipeline.

    Args:
        prompt: Text prompt describing the desired music.
        output_path: File to write the decoded WAV to (default "output.wav",
            matching the previous hard-coded behavior).

    Returns:
        The path of the written WAV file.
    """
    # Riffusion produces 512x512 spectrogram images.
    spec = pipe2(prompt, height=512, width=512).images[0]
    # wav_bytes_from_spectrogram_image returns a tuple whose first element
    # is the WAV byte buffer (only that element is used here).
    wav = wav_bytes_from_spectrogram_image(spec)
    with open(output_path, "wb") as f:
        f.write(wav[0].getbuffer())
    return output_path
47
 
48
def infer(prompt):
    """Generate a background image and music for *prompt* and combine them.

    Returns:
        A 1-tuple containing a gradio waveform video rendered from the
        generated audio, with the generated image as its background.
    """
    image = get_bg_image(prompt)
    audio = get_music(prompt)
    return (
        # Use the path returned by get_music instead of re-hardcoding
        # "output.wav" (the previous code computed `audio` and ignored it).
        gr.make_waveform(audio, bg_image=image, bars_color=random.choice(COLORS)),
    )
54
 
55
  css = """
 
80
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col-container"):
        prompt_input = gr.Textbox(
            placeholder="The Beatles playing for the queen",
            elem_id="prompt-in",
            label="Enter your music prompt",
        )
        with gr.Row(elem_id="btn-container"):
            send_btn = gr.Button(value="Send", elem_id="submit-btn")
        # Place the output component in the layout; creating gr.Video()
        # inline inside .click() leaves it unpositioned in the UI.
        video_output = gr.Video()
        send_btn.click(infer, inputs=[prompt_input], outputs=[video_output])

# `demo.launch().debug(True)` is broken — launch() does not return an object
# with a .debug method. Enable queuing and pass debug as a launch kwarg.
demo.queue().launch(debug=True)