sd-riffusion / app.py
juancopi81's picture
Add file download option
35afeb7
raw
history blame contribute delete
No virus
5.38 kB
import random
from PIL import Image
from diffusers import StableDiffusionPipeline
import gradio as gr
import torch
from spectro import wav_bytes_from_spectrogram_image
# Run on the GPU in half precision when CUDA is available; otherwise fall
# back to the CPU in single precision.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# Pipeline called by get_bg_image() to render the background picture.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
    revision="fp16",
).to(device)

# Pipeline called by get_music() to render the spectrogram image.
model_id2 = "riffusion/riffusion-model-v1"
pipe2 = StableDiffusionPipeline.from_pretrained(
    model_id2,
    torch_dtype=dtype,
).to(device)
# Two-colour pairs for the waveform bars; infer() picks one pair at random
# and passes it as ``bars_color``.
_RED = "#ff0000"
_GREEN = "#00ff00"
_BLUE = "#0000ff"

COLORS = [
    [_RED, _GREEN],
    [_GREEN, _BLUE],
    [_BLUE, _RED],
]
# HTML header rendered at the top of the page via gr.HTML(title): the page
# heading, a "Duplicate Space" badge link and a "Buy me a coffee" badge link.
title = """
<div style="text-align: center; max-width: 650px; margin: 0 auto 10px;">
<div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
<h1 style="font-weight: 950; margin-bottom: 7px; color: #000; font-weight: bold;">Riffusion and Stable Diffusion</h1>
</div>
<p style="text-align: center;font-size: 94%">
Duplicate this Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training:
<span style="display: flex;align-items: center;justify-content: center;height: 30px;">
<a href="https://huggingface.co/spaces/juancopi81/sd-riffusion?duplicate=true">
<img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
</span>
</p>
<p style="text-align: center;font-size: 94%">
You can buy me a coffee to support this space:
<span style="display: flex;align-items: center;justify-content: center;height: 30px;">
<a href="https://www.buymeacoffee.com/juancopi81j">
<img src="https://badgen.net/badge/icon/Buy%20Me%20A%20Coffee?icon=buymeacoffee&label" alt="Buy me a coffee"></a>. Depending on the support, I'll keep this space running and add more features!
</span>
</p>
</div>
"""
def get_bg_image(prompt):
    """Render a background image for *prompt* and return the saved file path.

    Calls the module-level ``pipe`` pipeline with the prompt, saves the first
    image of the result as ``img.png`` (relative to the current working
    directory) and returns that path.
    """
    result = pipe(prompt)
    print("Image generated!")

    # NOTE(review): the output path is fixed, so every call overwrites the
    # file written by the previous call.
    output_path = "img.png"
    result.images[0].save(output_path)
    return output_path
def get_music(prompt, duration=10):
    """Generate an audio clip for *prompt* and return the saved WAV file path.

    A spectrogram image is rendered with the module-level ``pipe2`` pipeline,
    converted to WAV data with ``wav_bytes_from_spectrogram_image`` and
    written to ``output.wav`` (relative to the current working directory).

    Args:
        prompt: Text prompt passed to ``pipe2``.
        duration: Clip length (presumably seconds -- confirm against the
            spectrogram helper). A value of 5 maps to a 512-pixel-wide
            spectrogram and every additional unit adds 128 pixels. Defaults
            to 10, the value that used to be hard-coded in the body.

    Returns:
        The string ``"output.wav"``.

    Raises:
        ValueError: If ``duration`` produces a non-positive spectrogram width.
    """
    # One formula covers every duration: at duration == 5 it yields exactly
    # 512, so no special case is needed for that value.
    width_duration = 512 + (int(duration) - 5) * 128
    if width_duration <= 0:
        raise ValueError(
            f"duration={duration!r} gives a non-positive spectrogram width "
            f"({width_duration})"
        )

    spec = pipe2(prompt, height=512, width=width_duration).images[0]

    # Only the first item of the helper's result is used; it must expose
    # getbuffer() (e.g. an in-memory bytes buffer).
    wav = wav_bytes_from_spectrogram_image(spec)

    # NOTE(review): the output path is fixed, so every call overwrites the
    # file written by the previous call.
    with open("output.wav", "wb") as f:
        f.write(wav[0].getbuffer())
    return "output.wav"
def infer(prompt, style):
    """Build a waveform video for *prompt* with a styled background image.

    Args:
        prompt: Music prompt. It is used unchanged for the audio and as the
            base of the image prompt.
        style: Optional style text for the background image. ``None``, an
            empty string or whitespace-only text means "no style".

    Returns:
        A 2-tuple holding the same video path twice, one entry per output
        component wired to this handler.
    """
    style = (style or "").strip()

    # Join with ", " so the style reads as its own phrase. Plain
    # concatenation fused the last word of the prompt with the first word of
    # the style (e.g. "...queenIn the style of...").
    style_prompt = f"{prompt}, {style}" if style else prompt

    image = get_bg_image(style_prompt)
    audio = get_music(prompt)
    video = gr.make_waveform(
        audio,
        bg_image=image,
        bars_color=random.choice(COLORS),
    )
    return video, video
# Custom stylesheet passed to gr.Blocks(css=css). The selectors target the
# elem_id values given to the layout components: col-container, prompt-in,
# prompt-style, btn-container and submit-btn.
css = """
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
#prompt-in {
border: 2px solid #666;
border-radius: 2px;
padding: 8px;
}
#prompt-style {
border: 2px solid #666;
border-radius: 2px;
padding: 8px;
}
#btn-container {
display: flex;
align-items: center;
justify-content: center;
width: calc(15% - 16px);
height: calc(15% - 16px);
}
/* Style the submit button */
#submit-btn {
background-color: #382a1d;
color: #fff;
border: 1px solid #000;
border-radius: 4px;
padding: 8px;
font-size: 16px;
cursor: pointer;
}
"""
# Page layout: header, the two prompt boxes, the send button, the video/file
# outputs, and a footer with social badges.
# NOTE(review): the nesting below is reconstructed; confirm that the outputs
# belong inside the column and the footer at the top level of the page.
with gr.Blocks(css=css) as demo:
    gr.HTML(title)

    with gr.Column(elem_id="col-container"):
        prompt_input = gr.Textbox(
            placeholder="The Beatles playing for the queen",
            elem_id="prompt-in",
            label="Enter your music prompt.",
        )
        style_input = gr.Textbox(
            placeholder="In the style of Vincent van Gogh",
            elem_id="prompt-style",
            label="(Optional) Add styles to your background image.",
            value="",
        )

        with gr.Row(elem_id="btn-container"):
            send_btn = gr.Button(value="Send", elem_id="submit-btn")

        # infer() returns the same video path twice: once for the player and
        # once for the downloadable file.
        video_output = gr.Video()
        file_output = gr.File()
        send_btn.click(
            infer,
            inputs=[prompt_input, style_input],
            outputs=[video_output, file_output],
        )

    gr.Markdown("""
[![Twitter Follow](https://img.shields.io/twitter/follow/juancopi81?style=social)](https://twitter.com/juancopi81)
![visitors](https://visitor-badge.glitch.me/badge?page_id=Juancopi81.sd-riffusion)
""")

demo.launch(debug=True)