riffusion-demo / app.py
anzorq's picture
Update app.py
484ad84
from diffusers import StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline
import torch
from PIL import Image, ImageDraw
import os
import numpy as np
from scipy.io.wavfile import read
from share_btn import community_icon_html, loading_icon_html, share_js
os.system('pip install gradio==3.15.0')
import gradio as gr
os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion')
from riffusion.riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.riffusion.datatypes import PromptInput, InferenceInput
from riffusion.riffusion.audio import wav_bytes_from_spectrogram_image
from PIL import Image
import struct
import random
repo_id = "riffusion/riffusion-model-v1"

# float16 halves GPU memory and speeds up inference, but many float16 kernels
# are not implemented on CPU. The app advertises a CPU mode, so only use half
# precision when CUDA is actually available.
_use_cuda = torch.cuda.is_available()
_dtype = torch.float16 if _use_cuda else torch.float32


def _no_safety_check(images, **kwargs):
    """Stand-in safety checker: return the images untouched and ``False`` as the NSFW flag."""
    return images, False


# Pipeline used for prompt interpolation / img2img (``model.riffuse``).
model = RiffusionPipeline.from_pretrained(
    repo_id,
    revision="main",
    torch_dtype=_dtype,
    safety_checker=_no_safety_check,
)
if _use_cuda:
    model.to("cuda")
    model.enable_xformers_memory_efficient_attention()

# Inpainting pipeline on the same weights, used to outpaint successive sections.
pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    repo_id,
    torch_dtype=_dtype,
    safety_checker=_no_safety_check,
)
pipe_inpaint.scheduler = DPMSolverMultistepScheduler.from_config(pipe_inpaint.scheduler.config)
if _use_cuda:
    pipe_inpaint = pipe_inpaint.to("cuda")
    pipe_inpaint.enable_xformers_memory_efficient_attention()
def get_init_image(image, overlap, feel):
    """Build the starting canvas for the next outpainted section.

    Loads the ``feel`` seed spectrogram and pastes the right-most ``overlap``
    fraction of ``image`` onto its left edge, so the new section begins where
    the previous one ended.

    Args:
        image: Previous section's image (anything with ``.size`` / ``.crop``).
        overlap: Fraction of ``image``'s width to carry over.
        feel: Name of the seed image in ``riffusion/seed_images``.

    Returns:
        The RGB seed image with the carried-over strip pasted at ``(0, 0)``.
    """
    width, height = image.size
    carry_px = int(width * overlap)
    # Right-hand strip of the previous section (crop boxes are right-exclusive).
    tail = image.crop((width - carry_px, 0, width, height))
    canvas = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    canvas.paste(tail, (0, 0))
    return canvas
def get_mask(image, overlap):
    """Build the inpainting mask matching an init image from ``get_init_image``.

    The left ``overlap`` fraction (the strip copied from the previous section)
    is painted black and the rest stays white; the mask is passed to the
    inpainting pipeline as ``mask_image``.

    Args:
        image: Init image whose size the mask should match.
        overlap: Fraction of the width to mark black (kept).

    Returns:
        An RGB mask image the same size as ``image``.
    """
    width, height = image.size
    mask = Image.new("RGB", (width, height), color="white")
    keep_px = int(overlap * width)
    if keep_px > 0:
        # ImageDraw.rectangle includes both corner coordinates, so the right
        # edge is keep_px - 1 to cover exactly keep_px columns -- the same
        # columns get_init_image pastes. The guard avoids x1 < x0, which newer
        # Pillow versions reject.
        ImageDraw.Draw(mask).rectangle((0, 0, keep_px - 1, height - 1), fill="black")
    return mask
def i2i(prompt, steps, feel, seed):
    """Generate the first spectrogram for single-prompt mode.

    Runs ``model.riffuse`` with identical start and end prompts (same text,
    same seed) and ``alpha=1.0``, starting from the ``feel`` seed image.

    Args:
        prompt: Text prompt.
        steps: Number of inference steps.
        feel: Name of the seed image in ``riffusion/seed_images``.
        seed: Seed for both prompt inputs.

    Returns:
        Whatever ``model.riffuse`` returns (used by callers as an image).
    """
    start_prompt = PromptInput(prompt=prompt, seed=seed)
    end_prompt = PromptInput(prompt=prompt, seed=seed)
    request = InferenceInput(
        start=start_prompt,
        end=end_prompt,
        alpha=1.0,
        num_inference_steps=steps,
    )
    seed_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    return model.riffuse(inputs=request, init_image=seed_image)
def outpaint(prompt, init_image, mask, steps):
    """Run the inpainting pipeline once and return its first image.

    Args:
        prompt: Text prompt.
        init_image: Image to inpaint (from ``get_init_image``).
        mask: Mask image (from ``get_mask``); per the diffusers inpainting
            convention, white areas are regenerated and black areas kept.
        steps: Number of inference steps.

    Returns:
        The first image produced by ``pipe_inpaint``.
    """
    result = pipe_inpaint(
        prompt,
        image=init_image,
        mask_image=mask,
        num_inference_steps=steps,
    )
    return result.images[0]
def generate(prompt, steps, num_iterations, feel, seed):
    """Generate a single-prompt clip by chaining outpainted spectrogram sections.

    A first spectrogram is produced by ``i2i``. Each section then reuses the
    right ``overlap`` fraction of the previous image as its left edge
    (``get_init_image`` / ``get_mask``) and lets ``outpaint`` fill in the rest.
    Sections are pasted side by side with that overlap, converted to audio and
    rendered as a waveform video with ``gr.make_waveform``.

    Args:
        prompt: Text prompt used for every section.
        steps: Diffusion steps for the initial image and every outpainted
            section.
        num_iterations: Number of outpainted sections (must be >= 1).
        feel: Name of a seed spectrogram in ``riffusion/seed_images``.
        seed: Seed for the initial image; ``0`` picks a random seed.

    Returns:
        The value returned by ``gr.make_waveform``.

    Raises:
        ValueError: If ``num_iterations`` is less than 1.
    """
    import tempfile  # per-request scratch directory (see below)

    if seed == 0:
        seed = random.randint(0, 4294967295)
    # Slider values may arrive as floats; range() needs an int.
    num_images = int(num_iterations)
    if num_images < 1:
        raise ValueError("num_iterations must be at least 1")

    overlap = 0.5
    image_width, image_height = 512, 512  # dimensions of each output image
    overlap_px = int(overlap * image_width)
    stride = image_width - overlap_px  # horizontal offset between sections
    total_width = image_width + (num_images - 1) * stride

    stitched_image = Image.new("RGB", (total_width, image_height), color="white")
    image = i2i(prompt, steps, feel, seed)
    for section in range(num_images):
        init_image = get_init_image(image, overlap, feel)
        mask = get_mask(init_image, overlap)
        # Use the caller's "steps per section" for every outpainted section.
        image = outpaint(prompt, init_image, mask, steps)
        stitched_image.paste(image, (section * stride, 0))

    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(stitched_image)

    # The queue runs several jobs concurrently, so fixed file names in the
    # working directory would let requests overwrite each other's files.
    # gr.make_waveform reads both inputs synchronously (gradio 3.x writes the
    # video to its own temp file), so the directory can be removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "output.wav")
        bg_path = os.path.join(tmp_dir, "bg_image.png")
        init_image.save(bg_path)
        with open(wav_path, "wb") as f:
            f.write(wav_bytes.read())
        return gr.make_waveform(wav_path, bg_image=bg_path, bar_count=int(duration_s * 25))
###############################################
def riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start=0.75, guidance_start=7.0, prompt_end=None, seed_end=None, denoising_end=0.75, guidance_end=7.0, alpha=0.5):
    """Run one Riffusion interpolation step and decode the result to audio.

    Args:
        steps: Number of inference steps.
        feel: Seed image id forwarded as ``InferenceInput.seed_image_id``.
        init_image: Image passed to ``model.riffuse`` as ``init_image``.
        prompt_start, seed_start, denoising_start, guidance_start: Fields of
            the start ``PromptInput``.
        prompt_end, seed_end, denoising_end, guidance_end: Fields of the end
            ``PromptInput``. ``prompt_end`` / ``seed_end`` default to the
            start values when ``None``.
        alpha: Interpolation weight forwarded to ``InferenceInput``.

    Returns:
        ``(wav_bytes, image)``: the WAV stream from
        ``wav_bytes_from_spectrogram_image`` and the generated spectrogram.
    """
    # A None prompt/seed cannot be used by the pipeline; fall back to the
    # start values, matching generate_riffuse's defaults.
    if prompt_end is None:
        prompt_end = prompt_start
    if seed_end is None:
        seed_end = seed_start

    prompt_input_start = PromptInput(prompt=prompt_start, seed=seed_start, denoising=denoising_start, guidance=guidance_start)
    prompt_input_end = PromptInput(prompt=prompt_end, seed=seed_end, denoising=denoising_end, guidance=guidance_end)
    # Named to avoid shadowing the ``input`` builtin.
    inference_input = InferenceInput(
        start=prompt_input_start,
        end=prompt_input_end,
        alpha=alpha,
        num_inference_steps=steps,
        seed_image_id=feel,
    )
    image = model.riffuse(inputs=inference_input, init_image=init_image)
    wav_bytes, _duration_s = wav_bytes_from_spectrogram_image(image)
    return wav_bytes, image
def generate_riffuse(prompt_start, steps, num_iterations, feel, prompt_end=None, seed_start=None, seed_end=None, denoising_start=0.75, denoising_end=0.75, guidance_start=7.0, guidance_end=7.0):
    """Generate audio that interpolates from ``prompt_start`` to ``prompt_end``.

    Calls ``riffuse`` once per section with alpha sweeping linearly from 0 to
    1; each call starts from the previous call's spectrogram, and the seeds
    advance by one per section. The WAV clips are concatenated with
    ``wav_list_to_wav`` and rendered as a waveform video.

    Args:
        prompt_start: Prompt to start from.
        steps: Diffusion steps per section.
        num_iterations: Number of sections (must be >= 1).
        feel: Name of a seed spectrogram in ``riffusion/seed_images``; also
            forwarded to ``riffuse`` as the seed image id.
        prompt_end: Prompt to end with. Defaults to ``prompt_start``.
        seed_start: Seed for the first section; ``None`` or ``0`` picks a
            random seed.
        seed_end: End-prompt seed for the first section. Defaults to
            ``seed_start``.
        denoising_start, denoising_end: Forwarded to ``PromptInput.denoising``
            for the start/end prompts.
        guidance_start, guidance_end: Forwarded to ``PromptInput.guidance`` for
            the start/end prompts.

    Returns:
        The value returned by ``gr.make_waveform``.

    Raises:
        ValueError: If ``num_iterations`` is less than 1.
    """
    import tempfile  # per-request scratch directory (see below)

    num_sections = int(num_iterations)
    if num_sections < 1:
        raise ValueError("num_iterations must be at least 1")

    # Open the initial seed image and convert it to RGB.
    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    if prompt_end is None:
        prompt_end = prompt_start
    # 0 means "random" in the UI; the None default would otherwise crash on
    # ``seed_start + 1`` below.
    if not seed_start:
        seed_start = random.randint(0, 4294967295)
    if seed_end is None:
        seed_end = seed_start

    # One riffuse() call generates one section (about 5 seconds of audio).
    wav_list = []
    for i in range(num_sections):
        # alpha sweeps 0 -> 1 across the sections; a single section would
        # otherwise divide by zero.
        alpha = i / (num_sections - 1) if num_sections > 1 else 0.0
        wav_bytes, image = riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start, guidance_start, prompt_end, seed_end, denoising_end, guidance_end, alpha=alpha)
        wav_list.append(wav_bytes)
        init_image = image
        seed_start = seed_end
        seed_end = seed_start + 1

    # The queue runs several jobs concurrently, so fixed file names in the
    # working directory would let requests overwrite each other's files.
    # gr.make_waveform reads both inputs synchronously (gradio 3.x writes the
    # video to its own temp file), so the directory can be removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "output.wav")
        bg_path = os.path.join(tmp_dir, "bg_image.png")
        init_image.save(bg_path)  # last generated spectrogram as background
        with open(wav_path, "wb") as f:
            f.write(wav_list_to_wav(wav_list))
        return gr.make_waveform(wav_path, bg_image=bg_path)
def wav_list_to_wav(wav_list):
    """Concatenate PCM WAV streams into a single WAV file.

    Each element of ``wav_list`` is a readable binary file-like object holding
    a complete PCM WAV file (e.g. the ``BytesIO`` returned by
    ``wav_bytes_from_spectrogram_image``); it is read from its current
    position. All inputs must share channel count, sample width and rate.

    Args:
        wav_list: Iterable of binary file-like WAV streams.

    Returns:
        bytes: A complete WAV file containing the inputs' frames in order. An
        empty ``wav_list`` yields a valid, empty mono 16-bit 44.1 kHz WAV.

    Raises:
        ValueError: If the inputs' audio formats differ.
        wave.Error: If an input is not a PCM WAV file.
    """
    import io
    import wave

    # Format used only when there is nothing to concatenate (mono, 16-bit,
    # 44.1 kHz -- what the previously hand-packed header declared).
    channels, sample_width, frame_rate = 1, 2, 44100
    params = None
    chunks = []
    for wav in wav_list:
        # Parse each header instead of assuming it is exactly 44 bytes, so
        # inputs with extra chunks (e.g. LIST) are handled correctly.
        with wave.open(wav, "rb") as reader:
            current = (reader.getnchannels(), reader.getsampwidth(), reader.getframerate())
            if params is None:
                params = current
            elif current != params:
                raise ValueError(f"Cannot concatenate WAV files with different formats: {params} vs {current}")
            chunks.append(reader.readframes(reader.getnframes()))
    if params is not None:
        channels, sample_width, frame_rate = params

    out = io.BytesIO()
    # wave writes correct RIFF/data sizes and a correct byte rate
    # (frame_rate * channels * sample_width); the old header omitted the
    # sample width from the byte rate.
    with wave.open(out, "wb") as writer:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(frame_rate)
        writer.writeframes(b"".join(chunks))
    return out.getvalue()
###############################################
def on_submit(prompt_1, prompt_2, feel, num_iterations, steps=25, seed=0):
    """Gradio submit handler: validate the prompts and dispatch to a generator.

    An empty ``prompt_1`` yields an error message in the info slot. An empty
    ``prompt_2`` selects single-prompt mode (``generate``); otherwise
    bi-prompt mode (``generate_riffuse``) interpolates towards ``prompt_2``.

    Returns:
        A 5-tuple matching the UI outputs: (video, info, community_icon,
        loading_icon, share_button). The last three are visibility updates
        for the share widgets.
    """
    def _share_widgets(visible):
        # Three separate update dicts, one per share widget.
        return tuple(gr.update(visible=visible) for _ in range(3))

    if prompt_1 == "":
        return (None, gr.update(value="First prompt is required.")) + _share_widgets(False)

    if prompt_2 == "":
        video = generate(prompt_1, steps, num_iterations, feel, seed)
    else:
        video = generate_riffuse(prompt_1, steps, num_iterations, feel, prompt_end=prompt_2, seed_start=seed)
    return (video, None) + _share_widgets(True)
def on_num_iterations_change(n, prompt_2):
    """Report the expected total audio length for the current settings.

    Bi-prompt mode (``prompt_2`` non-empty) counts 5 seconds per section;
    single-prompt mode counts 2.5 seconds plus 2.5 seconds per section.
    Clears the message when ``n`` is ``None``.
    """
    if n is None:
        return gr.update(value="")
    if prompt_2 == "":
        seconds = 2.5 + 2.5 * n
    else:
        seconds = 5 * n
    return gr.update(value=f"Total length: {seconds:.2f} seconds")
# Custom CSS passed to gr.Blocks(css=css). It styles the "Share to community"
# widget group (elem_ids share-btn-container / share-btn) and hides that
# container's ".wrap" element.
css = '''
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
# Gradio UI: prompt/settings form, output video with share button, cached
# examples and a footer; then launch with a request queue.
with gr.Blocks(css=css) as app:
    gr.Markdown("## Riffusion Demo")
    gr.Markdown("""Generate audio using the [Riffusion](https://huggingface.co/riffusion/riffusion-model-v1) model.<br>
In single prompt mode you can generate up to ~1 minute of audio with smooth transitions between sections. (beta)<br>
Bi-prompt mode interpolates between two prompts. It can generate up to ~2 minutes of audio, but transitions between sections are more abrupt.""")
    gr.Markdown(f"""Running on {"**GPU 🔥**" if torch.cuda.is_available() else f"**CPU 🥶**. For faster inference it is recommended to **upgrade to GPU in space's Settings**"}<br>
[![Duplicate Space](https://bit.ly/3gLdBN6)](https://huggingface.co/spaces/anzorq/riffusion-demo?duplicate=true)""")

    with gr.Row():
        with gr.Group():
            with gr.Row():
                prompt_1 = gr.Textbox(lines=1, label="Start from", placeholder="Starting prompt", elem_id="riff-prompt_1")
                prompt_2 = gr.Textbox(lines=1, label="End with (optional)", placeholder="Prompt to shift towards at the end", elem_id="riff-prompt_2")
            with gr.Row():
                steps = gr.Slider(minimum=1, maximum=100, value=25, label="Steps per section")
                num_iterations = gr.Slider(minimum=2, maximum=25, value=2, step=1, label="Number of sections")
            with gr.Row():
                feel = gr.Dropdown(["og_beat", "agile", "vibes", "motorway", "marim"], value="og_beat", label="Feel", elem_id="riff-feel")
                seed = gr.Slider(minimum=0, maximum=4294967295, value=0, step=1, label="Seed (0 for random)", elem_id="riff-seed")
            btn_generate = gr.Button(value="Generate").style(full_width=True)
            info = gr.Markdown()

        with gr.Column():
            video = gr.Video(elem_id="riff-video")
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html, elem_id="share-btn-share-icon", visible=False)
                loading_icon = gr.HTML(loading_icon_html, elem_id="share-btn-loading-icon", visible=False)
                share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)

    inputs = [prompt_1, prompt_2, feel, num_iterations, steps, seed]
    outputs = [video, info, community_icon, loading_icon, share_button]

    # The length estimate depends on both the section count and whether a
    # second prompt is set, so refresh it when either changes. It is cheap,
    # so bypass the queue rather than waiting behind generation jobs.
    num_iterations.change(on_num_iterations_change, [num_iterations, prompt_2], [info], queue=False)
    prompt_2.change(on_num_iterations_change, [num_iterations, prompt_2], [info], queue=False)

    prompt_1.submit(on_submit, inputs, outputs)
    prompt_2.submit(on_submit, inputs, outputs)
    btn_generate.click(on_submit, inputs, outputs)
    share_button.click(None, [], [], _js=share_js)

    examples = gr.Examples(
        fn=on_submit,
        examples=[
            ["typing", "dance beat", "og_beat", 10],
            ["synthwave", "jazz", "agile", 10],
            ["rap battle freestyle", "", "og_beat", 10],
            # ["techno club banger", "", "og_beat", 10],
            ["reggae dub beat", "sunset chill", "og_beat", 10],
            ["acoustic folk ballad", "", "agile", 10],
            ["blues guitar riff", "", "agile", 5],
            ["jazzy trumpet solo", "", "og_beat", 5],
            ["classical symphony orchestra", "", "vibes", 10],
            ["rock and roll power chord", "", "motorway", 5],
            ["soulful R&B love song", "", "marim", 10],
            ["country western twangy guitar", "", "agile", 10]],
        inputs=[prompt_1, prompt_2, feel, num_iterations],
        outputs=outputs,
        cache_examples=True)

    gr.HTML("""
<div style="border-top: 1px solid #303030;">
<br>
<p>Space by:<br>
<a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a><br>
<a href="https://github.com/qunash"><img alt="GitHub followers" src="https://img.shields.io/github/followers/qunash?style=social"></a></p><br>
<a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 24px !important;width: 81px !important;" ></a><br><br>
<p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.riffusion-demo" alt="visitors"></p>
</div>
""")

app.queue(max_size=250, concurrency_count=6).launch()