# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/app.py
import spaces
import os
import json
import torch
import random
import gradio as gr

from glob import glob
from omegaconf import OmegaConf
from datetime import datetime
from safetensors import safe_open
from PIL import Image

from unet2d_custom import UNet2DConditionModel
from pipeline_stable_diffusion_custom import StableDiffusionPipeline
from diffusers import DDIMScheduler

from pnp_utils import *
import torchvision.transforms as T
from preprocess import get_timesteps
from preprocess import Preprocess
from pnp import PNP

sample_idx = 0

css = """
.toolbutton {
    margin-bottom: 0em 0em 0em 0em;
    max-width: 1.5em;
    min-width: 1.5em !important;
    height: 1.5em;
}
"""


class AnimateController:
    def __init__(self):
        self.sr = 44100
        self.save_steps = 50
        self.device = 'cuda'
        self.seed = 42
        self.extract_reverse = False
        self.save_dir = 'latents'
        self.steps = 50
        self.inversion_prompt = ''

        seed_everything(self.seed)

        # Plug-and-Play pipeline with the audio-conditioned adapter layers.
        self.pnp = PNP(sd_version="1.4")
        self.pnp.unet.to(self.device)
        self.pnp.audio_projector.to(self.device)

        # audio_projector_path = "ckpts/audio_projector_landscape.pth"
        # gate_dict_path = "ckpts/landscape.pt"
        # self.pnp.set_audio_projector(gate_dict_path, audio_projector_path)

        # Checkpoint paths are filled in by update_audio_model / generate.
        self.audio_projector_path = None  # "ckpts/audio_projector_landscape.pth"
        self.adapter_ckpt_path = None     # "ckpts/landscape.pt"

    @spaces.GPU
    def preprocess(self, image=None):
        """DDIM-invert the input image and save the intermediate latents."""
        model_key = "CompVis/stable-diffusion-v1-4"
        toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        toy_scheduler.set_timesteps(self.save_steps)

        timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler,
                                                               num_inference_steps=self.save_steps,
                                                               strength=1.0,
                                                               device=self.device)

        save_path = self.save_dir + "_forward"
        os.makedirs(save_path, exist_ok=True)

        model = Preprocess(self.device, sd_version='1.4', hf_key=None)
        recon_image = model.extract_latents(data_path=image,
                                            num_steps=self.steps,
                                            save_path=save_path,
                                            timesteps_to_save=timesteps_to_save,
                                            inversion_prompt=self.inversion_prompt,
                                            extract_reverse=False)

        recon_pil = T.ToPILImage()(recon_image[0])
        recon_pil.save(os.path.join(save_path, 'recon.jpg'))
        # Return the reconstruction so the Gradio output component can display it.
        return recon_pil

    @spaces.GPU
    def generate(self,
                 file=None,          # dropdown selection; checkpoints are set in update_audio_model
                 audio=None,
                 prompt=None,
                 cfg_scale=5,
                 image_path=None,
                 pnp_f_t=0.8,
                 pnp_attn_t=0.8,
                 ):
        # Fall back to the Landscape checkpoints if no model has been selected yet.
        if self.audio_projector_path is None:
            self.audio_projector_path = "ckpts/audio_projector_landscape.pth"
            self.adapter_ckpt_path = "ckpts/landscape.pt"

        image = self.pnp.run_pnp(
            n_timesteps=50,
            pnp_f_t=pnp_f_t,
            pnp_attn_t=pnp_attn_t,
            prompt=prompt,
            negative_prompt="",
            audio_path=audio,
            image_path=image_path,
            audio_projector_path=self.audio_projector_path,
            adapter_ckpt_path=self.adapter_ckpt_path,
            cfg_scale=cfg_scale,
        )

        return image

    def update_audio_model(self, audio_model_update):
        # Select which pretrained audio projector / adapter checkpoints to use.
        if audio_model_update == "Landscape Model":
            self.audio_projector_path = "ckpts/audio_projector_landscape.pth"
            self.adapter_ckpt_path = "ckpts/landscape.pt"
        else:
            self.audio_projector_path = "ckpts/audio_projector_gh.pth"
            self.adapter_ckpt_path = "ckpts/greatest_hits.pt"

        # Alternative: load the checkpoints into the models here (kept for reference):
        # self.pnp.set_audio_projector(gate_dict_path, audio_projector_path)
        # self.pnp.changed_model = True
        # gate_dict = torch.load(gate_dict_path)
        # for name, param in self.pnp.unet.named_parameters():
        #     if "adapter" in name:
        #         param.data = gate_dict[name]
        # self.pnp.audio_projector.load_state_dict(torch.load(audio_projector_path))
        # self.pnp.unet.to(self.device)
        # self.pnp.audio_projector.to(self.device)

        return gr.Dropdown()


controller = AnimateController()


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # SonicDiffusion: Audio-Driven Image Generation and Editing with Pretrained Diffusion Models

            Adjust CFG Scale to control the conditioning strength.
            Adjust the PNP injection sliders to control the amount of features injected from the source image.
            """
        )

        with gr.Row():
            audio_input = gr.Audio(sources="upload", type="filepath")
            prompt_textbox = gr.Textbox(label="Prompt", lines=2)

        with gr.Row():
            with gr.Column():
                pnp_f_t = gr.Slider(label="PNP Residual Injection", step=0.1, value=0.8, minimum=0.0, maximum=1.0)
                pnp_attn_t = gr.Slider(label="PNP Attention Injection", step=0.1, value=0.8, minimum=0.0, maximum=1.0)
            with gr.Column():
                audio_model_dropdown = gr.Dropdown(
                    label="Select SonicDiffusion model",
                    value="Landscape Model",
                    choices=["Landscape Model", "Greatest Hits Model"],
                    interactive=True,
                )
                audio_model_dropdown.change(fn=controller.update_audio_model,
                                            inputs=[audio_model_dropdown],
                                            outputs=[audio_model_dropdown])
                cfg_scale_slider = gr.Slider(label="CFG Scale", step=0.5, value=7.5, minimum=0, maximum=20)

        with gr.Row():
            preprocess_button = gr.Button(value="Preprocess", variant='primary')
            generate_button = gr.Button(value="Generate", variant='primary')

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image Component", sources="upload", type="filepath")
            with gr.Column():
                output = gr.Image(label="Output Image Component", height=512, width=512)

        with gr.Row():
            examples_img_1 = [
                [Image.open("assets/corridor.png")],
                [Image.open("assets/desert.png")],
                [Image.open("assets/forest.png")],
                [Image.open("assets/forest_painting.png")],
                [Image.open("assets/golf_field.png")],
                [Image.open("assets/human.png")],
                [Image.open("assets/wood.png")],
                [Image.open("assets/house.png")],
                [Image.open("assets/apple.png")],
                [Image.open("assets/chair.png")],
                [Image.open("assets/hands.png")],
                [Image.open("assets/pineapple.png")],
                [Image.open("assets/table.png")],
            ]
            gr.Examples(examples=examples_img_1, inputs=[image_input], label="Images")

            examples2 = [
                ['./assets/fire_crackling.wav'],
                ['./assets/forest_birds.wav'],
                ['./assets/forest_stepping_on_branches.wav'],
                ['./assets/howling_wind.wav'],
                ['./assets/rain.wav'],
                ['./assets/splashing_water.wav'],
                ['./assets/splashing_water_soft.wav'],
                ['./assets/steps_on_snow.wav'],
                ['./assets/thunder.wav'],
                ['./assets/underwater.wav'],
                ['./assets/waterfall_burble.wav'],
                ['./assets/wind_noise_birds.wav'],
            ]
            gr.Examples(examples=examples2, inputs=[audio_input], label="Landscape Audios")

            examples3 = [
                ['./assets/cardboard.wav'],
                ['./assets/carpet.wav'],
                ['./assets/ceramic.wav'],
                ['./assets/cloth.wav'],
                ['./assets/gravel.wav'],
                ['./assets/leaf.wav'],
                ['./assets/metal.wav'],
                ['./assets/plastic_bag.wav'],
                ['./assets/plastic.wav'],
                ['./assets/rock.wav'],
                ['./assets/wood.wav'],
            ]
            gr.Examples(examples=examples3, inputs=[audio_input], label="Greatest Hits Audios")

        preprocess_button.click(
            fn=controller.preprocess,
            inputs=[image_input],
            outputs=output
        )

        generate_button.click(
            fn=controller.generate,
            inputs=[
                audio_model_dropdown,
                audio_input,
                prompt_textbox,
                cfg_scale_slider,
                image_input,
                pnp_f_t,
                pnp_attn_t,
            ],
            outputs=output
        )

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.launch(share=True)
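
# ----------------------------------------------------------------------
# Minimal sketch of calling the controller without the Gradio UI, assuming
# the bundled example assets ("assets/forest.png", "assets/rain.wav") and the
# default Landscape checkpoints under ckpts/ are present; the prompt is
# illustrative only. Kept as a comment so it is not executed:
#
#   controller.preprocess(image="assets/forest.png")   # DDIM inversion of the source image
#   edited = controller.generate(
#       audio="assets/rain.wav",
#       prompt="a rainy forest",                        # hypothetical prompt
#       cfg_scale=7.5,
#       image_path="assets/forest.png",
#       pnp_f_t=0.8,
#       pnp_attn_t=0.8,
#   )
# ----------------------------------------------------------------------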