"""Gradio demo that generates audio with the Riffusion model.

Two modes are exposed through ``on_submit``:

* single-prompt mode (``generate``): one spectrogram is produced with the
  Riffusion pipeline and then extended section by section with an inpainting
  pipeline, each section overlapping the previous one;
* bi-prompt mode (``generate_riffuse``): several spectrograms are produced
  while interpolating from the first prompt to the second, and the resulting
  WAV clips are concatenated.
"""
from diffusers import (
    DPMSolverMultistepScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
)
import torch
from PIL import Image, ImageDraw
import os
import numpy as np
from scipy.io.wavfile import read
from share_btn import community_icon_html, loading_icon_html, share_js

# NOTE: statement order matters here. The pinned gradio version is installed
# and the riffusion repository is cloned *before* the imports that need them.
os.system('pip install gradio==3.15.0')
import gradio as gr

os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion')
from riffusion.riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.riffusion.datatypes import PromptInput, InferenceInput
from riffusion.riffusion.audio import wav_bytes_from_spectrogram_image
import struct
import random

repo_id = "riffusion/riffusion-model-v1"

# Both pipelines replace the safety checker with a pass-through callable.
model = RiffusionPipeline.from_pretrained(
    repo_id,
    revision="main",
    torch_dtype=torch.float16,
    safety_checker=lambda images, **kwargs: (images, False),
)

if torch.cuda.is_available():
    model.to("cuda")
    model.enable_xformers_memory_efficient_attention()

pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,
    safety_checker=lambda images, **kwargs: (images, False),
)
pipe_inpaint.scheduler = DPMSolverMultistepScheduler.from_config(pipe_inpaint.scheduler.config)
# pipe_inpaint.enable_xformers_memory_efficient_attention()

if torch.cuda.is_available():
    pipe_inpaint = pipe_inpaint.to("cuda")
    pipe_inpaint.enable_xformers_memory_efficient_attention()


def get_init_image(image, overlap, feel):
    """Build the starting image for the next outpainting step.

    Args:
        image: PIL image of the previously generated section.
        overlap: Fraction of the image width that is carried over.
        feel: Name of the seed image (``riffusion/seed_images/{feel}.png``).

    Returns:
        The seed image (RGB) with the right-hand ``overlap`` fraction of
        ``image`` pasted at its top-left corner.
    """
    width, height = image.size
    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    # Carry the right edge of the previous section over to the left edge.
    carried_over = image.crop((width - int(width * overlap), 0, width, height))
    init_image.paste(carried_over, (0, 0))
    return init_image


def get_mask(image, overlap):
    """Return an inpainting mask matching ``image``'s size.

    The mask is white everywhere except for a black rectangle covering the
    left ``overlap`` fraction of the width.
    """
    width, height = image.size
    mask = Image.new("RGB", (width, height), color="white")
    ImageDraw.Draw(mask).rectangle((0, 0, int(overlap * width), height), fill="black")
    return mask


def i2i(prompt, steps, feel, seed):
    """Generate one spectrogram image from ``prompt`` with the Riffusion pipeline.

    The same prompt and seed are used for both ends of the interpolation,
    and the seed image selected by ``feel`` is used as the init image.
    """
    # return pipe_i2i(
    #     prompt,
    #     num_inference_steps=steps,
    #     image=Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB"),
    # ).images[0]
    prompt_input_start = PromptInput(prompt=prompt, seed=seed)
    prompt_input_end = PromptInput(prompt=prompt, seed=seed)
    return model.riffuse(
        inputs=InferenceInput(
            start=prompt_input_start,
            end=prompt_input_end,
            alpha=1.0,
            num_inference_steps=steps,
        ),
        init_image=Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB"),
    )


def outpaint(prompt, init_image, mask, steps):
    """Run the inpainting pipeline and return its first output image."""
    return pipe_inpaint(
        prompt,
        num_inference_steps=steps,
        image=init_image,
        mask_image=mask,
    ).images[0]


def generate(prompt, steps, num_iterations, feel, seed):
    """Single-prompt mode: generate overlapping sections and stitch them.

    Args:
        prompt: Text prompt used for every section.
        steps: Inference steps for the initial Riffusion image.
        num_iterations: Number of outpainted sections to stitch (>= 1).
        feel: Name of the seed image to start from.
        seed: Seed for the initial image; 0 picks a random seed.

    Returns:
        The result of ``gr.make_waveform`` for the written ``output.wav``.

    Raises:
        ValueError: If ``num_iterations`` is less than 1.

    Side effects:
        Writes ``output.wav`` and ``bg_image.png`` to the working directory.
    """
    if seed == 0:
        seed = random.randint(0, 4294967295)

    # Slider values may arrive as floats; range() and Image.new need ints.
    num_images = int(num_iterations)
    if num_images < 1:
        raise ValueError("num_iterations must be at least 1")

    overlap = 0.5
    image_width, image_height = 512, 512  # dimensions of each output image
    # Each section after the first only adds its non-overlapping part.
    total_width = num_images * image_width - (num_images - 1) * int(overlap * image_width)
    stitched_image = Image.new("RGB", (total_width, image_height), color="white")

    # NOTE(review): the outpaint passes always use 25 steps regardless of
    # `steps`; kept as found - confirm this is intentional.
    outpaint_steps = 25
    stride = int((1 - overlap) * image_width)

    image = i2i(prompt, steps, feel, seed)
    x_pos = 0
    for _ in range(num_images):
        init_image = get_init_image(image, overlap, feel)
        mask = get_mask(init_image, overlap)
        image = outpaint(prompt, init_image, mask, outpaint_steps)
        stitched_image.paste(image, (x_pos, 0))
        x_pos += stride

    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(stitched_image)

    # mask = Image.new("RGB", (512, 512), color="white")
    # bg_image = outpaint(prompt, init_image, mask, steps)
    # bg_image.save("bg_image.png")
    init_image.save("bg_image.png")

    # return read(wav_bytes)
    with open("output.wav", "wb") as f:
        f.write(wav_bytes.read())
    return gr.make_waveform("output.wav", bg_image="bg_image.png", bar_count=int(duration_s * 25))


###############################################

def riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start=0.75, guidance_start=7.0, prompt_end=None, seed_end=None, denoising_end=0.75, guidance_end=7.0, alpha=0.5):
    """Run one Riffusion interpolation step.

    Builds the start/end ``PromptInput`` objects and an ``InferenceInput``
    from the arguments, runs ``model.riffuse`` on ``init_image`` and converts
    the resulting spectrogram image to WAV bytes.

    Returns:
        A ``(wav_bytes, image)`` tuple.
    """
    prompt_input_start = PromptInput(prompt=prompt_start, seed=seed_start, denoising=denoising_start, guidance=guidance_start)
    prompt_input_end = PromptInput(prompt=prompt_end, seed=seed_end, denoising=denoising_end, guidance=guidance_end)
    inference_input = InferenceInput(
        start=prompt_input_start,
        end=prompt_input_end,
        alpha=alpha,
        num_inference_steps=steps,
        seed_image_id=feel,
        # mask_image_id="mask_beat_lines_80.png"
    )
    image = model.riffuse(inputs=inference_input, init_image=init_image)
    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
    return wav_bytes, image


def generate_riffuse(prompt_start, steps, num_iterations, feel, prompt_end=None, seed_start=None, seed_end=None, denoising_start=0.75, denoising_end=0.75, guidance_start=7.0, guidance_end=7.0):
    """Bi-prompt mode: interpolate from ``prompt_start`` to ``prompt_end``.

    Args:
        prompt_start (str): Prompt to start with.
        steps: Inference steps per section.
        num_iterations: Number of sections to generate.
        feel (str): Name of the seed image to start from.
        prompt_end (str, optional): Prompt to end with. Defaults to
            ``prompt_start``.
        seed_start (int, optional): Seed for the first section; 0 or ``None``
            picks a random seed.
        seed_end (int, optional): End seed of the first section. Defaults to
            ``seed_start``.
        denoising_start, denoising_end: Denoising strengths passed through
            to ``riffuse``.
        guidance_start, guidance_end: Guidance scales passed through to
            ``riffuse``.

    Returns:
        The result of ``gr.make_waveform`` for the written ``output.wav``.

    Side effects:
        Writes ``output.wav`` and ``bg_image.png`` to the working directory.
    """
    # open the initial image and convert it to RGB
    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    if prompt_end is None:
        prompt_end = prompt_start
    # None used to crash later at `seed_start + 1`; treat it like 0 (random).
    if seed_start is None or seed_start == 0:
        seed_start = random.randint(0, 4294967295)
    if seed_end is None:
        seed_end = seed_start

    # one riffuse() generates 5 seconds of audio
    section_count = int(num_iterations)
    wav_list = []
    for i in range(section_count):
        # A single section has nothing to interpolate towards; avoid 0 / 0.
        alpha = i / (section_count - 1) if section_count > 1 else 0.0
        print(alpha)
        wav_bytes, image = riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start, guidance_start, prompt_end, seed_end, denoising_end, guidance_end, alpha=alpha)
        wav_list.append(wav_bytes)
        # The next section starts from this section's image and end seed.
        init_image = image
        seed_start = seed_end
        seed_end = seed_start + 1

    # return read(wav_bytes)
    # return wav_list_to_wav(wav_list)

    # mask = Image.new("RGB", (512, 512), color="white")
    # bg_image = outpaint(f"{prompt_start} and {prompt_end}", init_image, mask, steps)
    # bg_image.save("bg_image.png")
    init_image.save("bg_image.png")

    with open("output.wav", "wb") as f:
        f.write(wav_list_to_wav(wav_list))
    return gr.make_waveform("output.wav", bg_image="bg_image.png")


def wav_list_to_wav(wav_list):
    """Concatenate several WAV streams into one WAV file.

    Args:
        wav_list: Iterable of binary file-like objects, each positioned at
            the start of a WAV file. Every stream is fully consumed.

    Returns:
        bytes: A new 44-byte RIFF/WAVE header followed by the joined sample
        data.

    NOTE(review): assumes every input is mono, 16-bit PCM at 44100 Hz with a
    plain 44-byte header - confirm against ``wav_bytes_from_spectrogram_image``.
    """
    header_size = 44
    channels = 1
    sample_rate = 44100
    bits_per_sample = 16
    block_align = channels * bits_per_sample // 8
    # Byte rate is bytes per second of audio, not samples per second.
    byte_rate = sample_rate * block_align

    # remove headers from the WAV files and concatenate the data
    pcm_data = b"".join(wav.read()[header_size:] for wav in wav_list)

    new_header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        len(pcm_data) + header_size - 8,
        b"WAVE",
        b"fmt ",
        16,  # size of the fmt chunk
        1,  # audio format: PCM
        channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b"data",
        len(pcm_data),
    )
    return new_header + pcm_data


###############################################

def on_submit(prompt_1, prompt_2, feel, num_iterations, steps=25, seed=0):
    """Gradio handler: validate the prompts and dispatch to a generator.

    Returns a 5-tuple matching ``outputs``: the video value, the info text
    update, and visibility updates for the three share widgets.
    """
    if prompt_1 == "":
        return (
            None,
            gr.update(value="First prompt is required."),
            *(gr.update(visible=False) for _ in range(3)),
        )

    # Slider values may arrive as floats; downstream code needs ints.
    steps = int(steps)
    num_iterations = int(num_iterations)
    seed = int(seed)

    if prompt_2 == "":
        video_value = generate(prompt_1, steps, num_iterations, feel, seed)
    else:
        video_value = generate_riffuse(prompt_1, steps, num_iterations, feel, prompt_end=prompt_2, seed_start=seed)
    return (video_value, None, *(gr.update(visible=True) for _ in range(3)))


def on_num_iterations_change(n, prompt_2):
    """Gradio handler: show the expected audio length for ``n`` sections."""
    if n is None:
        return gr.update(value="")

    if prompt_2 != "":
        # Bi-prompt mode: 5 seconds per section.
        total_length = 5 * n
    else:
        # Single-prompt mode: sections overlap by half.
        total_length = 2.5 + 2.5 * n

    return gr.update(value=f"Total length: {total_length:.2f} seconds")


# Stylesheet for the share button; split across literals for readability
# only, the joined value keeps the original whitespace.
css = (
    " #share-btn-container { "
    "display: flex; "
    "padding-left: 0.5rem !important; "
    "padding-right: 0.5rem !important; "
    "background-color: #000000; "
    "justify-content: center; "
    "align-items: center; "
    "border-radius: 9999px !important; "
    "width: 13rem; "
    "} "
    "#share-btn { "
    "all: initial; "
    "color: #ffffff;font-weight: 600; "
    "cursor:pointer; "
    "font-family: 'IBM Plex Sans', sans-serif; "
    "margin-left: 0.5rem !important; "
    "padding-top: 0.25rem !important; "
    "padding-bottom: 0.25rem !important;right:0; "
    "} "
    "#share-btn * { all: unset; } "
    "#share-btn-container div:nth-child(-n+2){ "
    "width: auto !important; "
    "min-height: 0px !important; "
    "} "
    "#share-btn-container .wrap { display: none !important; } "
)

with gr.Blocks(css=css) as app:
    gr.Markdown("## Riffusion Demo")
    gr.Markdown(
        "Generate audio using the [Riffusion](https://huggingface.co/riffusion/riffusion-model-v1) model.\n"
        "In single prompt mode you can generate up to ~1 minute of audio with smooth transitions between sections. (beta)\n"
        "Bi-prompt mode interpolates between two prompts. It can generate up to ~2 minutes of audio, but transitions between sections are more abrupt."
    )
    device_note = (
        "**GPU 🔥**"
        if torch.cuda.is_available()
        else "**CPU 🥶**. For faster inference it is recommended to **upgrade to GPU in space's Settings**"
    )
    gr.Markdown(
        f"Running on {device_note}\n"
        "[![Duplicate Space](https://bit.ly/3gLdBN6)](https://huggingface.co/spaces/anzorq/riffusion-demo?duplicate=true)"
    )

    with gr.Row():
        with gr.Group():
            with gr.Row():
                prompt_1 = gr.Textbox(lines=1, label="Start from", placeholder="Starting prompt", elem_id="riff-prompt_1")
                prompt_2 = gr.Textbox(lines=1, label="End with (optional)", placeholder="Prompt to shift towards at the end", elem_id="riff-prompt_2")

            with gr.Row():
                steps = gr.Slider(minimum=1, maximum=100, value=25, label="Steps per section")
                num_iterations = gr.Slider(minimum=2, maximum=25, value=2, step=1, label="Number of sections")

            with gr.Row():
                feel = gr.Dropdown(["og_beat", "agile", "vibes", "motorway", "marim"], value="og_beat", label="Feel", elem_id="riff-feel")
                seed = gr.Slider(minimum=0, maximum=4294967295, value=0, step=1, label="Seed (0 for random)", elem_id="riff-seed")

            btn_generate = gr.Button(value="Generate").style(full_width=True)
            info = gr.Markdown()

        with gr.Column():
            video = gr.Video(elem_id="riff-video")

            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html, elem_id="share-btn-share-icon", visible=False)
                loading_icon = gr.HTML(loading_icon_html, elem_id="share-btn-loading-icon", visible=False)
                share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)

    inputs = [prompt_1, prompt_2, feel, num_iterations, steps, seed]
    outputs = [video, info, community_icon, loading_icon, share_button]

    num_iterations.change(on_num_iterations_change, [num_iterations, prompt_2], [info])
    prompt_1.submit(on_submit, inputs, outputs)
    prompt_2.submit(on_submit, inputs, outputs)
    btn_generate.click(on_submit, inputs, outputs)
    share_button.click(None, [], [], _js=share_js)

    # Examples only supply the first four inputs; steps and seed fall back
    # to on_submit's defaults.
    examples = gr.Examples(
        fn=on_submit,
        examples=[
            ["typing", "dance beat", "og_beat", 10],
            ["synthwave", "jazz", "agile", 10],
            ["rap battle freestyle", "", "og_beat", 10],
            # ["techno club banger", "", "og_beat", 10],
            ["reggae dub beat", "sunset chill", "og_beat", 10],
            ["acoustic folk ballad", "", "agile", 10],
            ["blues guitar riff", "", "agile", 5],
            ["jazzy trumpet solo", "", "og_beat", 5],
            ["classical symphony orchestra", "", "vibes", 10],
            ["rock and roll power chord", "", "motorway", 5],
            ["soulful R&B love song", "", "marim", 10],
            ["country western twangy guitar", "", "agile", 10],
        ],
        inputs=[prompt_1, prompt_2, feel, num_iterations],
        outputs=outputs,
        cache_examples=True,
    )

    # NOTE(review): this footer contains text only; presumably the original
    # link/badge markup was lost at some point - confirm and restore.
    gr.HTML(
        "\n"
        "\n"
        "Space by:\n"
        "Twitter Follow\n"
        "GitHub followers\n"
        "\n"
        "\n"
        "Buy Me A Coffee\n"
        "\n"
        "visitors\n"
        "\n"
    )

app.queue(max_size=250, concurrency_count=6).launch()