# Adapted from https://github.com/MichalGeyer/pnp-diffusers/blob/main/pnp.py

import argparse
import os
import sys

import spaces
import torch
import torch.nn as nn
import torchvision.transforms as T
import yaml
from PIL import Image
from tqdm import tqdm

from transformers import logging as transformers_logging
from diffusers import DDIMScheduler
from diffusers import logging as diffusers_logging

from pnp_utils import *
from unet2d_custom import UNet2DConditionModel
from pipeline_stable_diffusion_custom import StableDiffusionPipeline
from ldm.modules.encoders.audio_projector_res import Adapter

# suppress partial model loading warnings
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()


class PNP(nn.Module):
    def __init__(
        self,
        sd_version="1.4",
        n_timesteps=50,
        audio_projector_ckpt_path="ckpts/audio_projector_gh.pth",
        adapter_ckpt_path="ckpts/greatest_hits.pt",
        device="cuda",
        clap_path="CLAP/msclap",
        clap_weights="ckpts/CLAP_weights_2022.pth",
    ):
        super().__init__()
        self.device = device

        if sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        elif sd_version == '1.4':
            model_key = "CompVis/stable-diffusion-v1-4"
        else:
            raise ValueError(f'Stable-diffusion version {sd_version} not supported.')
        print(f"model key is {model_key}")

        # Create SD models
        print('Loading SD model')
        pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.float16).to(self.device)

        # The adapter-equipped UNet is always loaded from SD v1.4 weights,
        # which is what the audio adapter checkpoints were trained against.
        model_id = "CompVis/stable-diffusion-v1-4"
        self.unet = UNet2DConditionModel.from_pretrained(
            model_id,
            subfolder="unet",
            use_adapter_list=[False, True, True],
            low_cpu_mem_usage=False,
            device_map=None,
        ).to(self.device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder.cuda()

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        self.scheduler.set_timesteps(n_timesteps, device=self.device)

        self.latents_path = "latents_forward"
        self.output_path = "PNP-results/home"
        os.makedirs(self.output_path, exist_ok=True)

        # CLAP audio encoder plus the projector that maps CLAP embeddings
        # into 77 cross-attention tokens
        sys.path.append(clap_path)
        from CLAPWrapper import CLAPWrapper
        self.audio_encoder = CLAPWrapper(clap_weights, use_cuda=True)
        self.audio_projector = Adapter(audio_token_count=77, transformer_layer_count=4).cuda()
        self.sr = 44100

        # adapter / projector weights are loaded lazily in run_pnp, inside a
        # @spaces.GPU context
        self.audio_projector_ckpt_path = audio_projector_ckpt_path
        self.adapter_ckpt_path = adapter_ckpt_path
        self.changed_model = False

    @spaces.GPU
    def set_audio_projector(self, adapter_ckpt_path, audio_projector_ckpt_path):
        # copy the trained adapter gate weights into the UNet's adapter parameters
        gate_dict = torch.load(adapter_ckpt_path)
        for name, param in self.unet.named_parameters():
            if "adapter" in name:
                param.data = gate_dict[name]
        self.unet.eval()
        self.unet = self.unet.cuda()
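        # load the trained audio projector weights and move the projector to the GPU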
        self.audio_projector.load_state_dict(torch.load(audio_projector_ckpt_path))
        self.audio_projector.eval()
        self.audio_projector = self.audio_projector.cuda()

    @spaces.GPU
    def set_text_embeds(self, prompt, negative_prompt=""):
        self.text_encoder = self.text_encoder.cuda()
        self.text_embeds = self.get_text_embeds(prompt, negative_prompt)
        self.pnp_guidance_embeds = self.get_text_embeds("", "").chunk(2)[0]

    @spaces.GPU
    def set_audio_context(self, audio_path):
        self.audio_projector = self.audio_projector.cuda()
        self.audio_encoder.clap.audio_encoder = self.audio_encoder.clap.audio_encoder.to("cuda")

        # project the CLAP audio embedding into cross-attention tokens
        audio_emb, _ = self.audio_encoder.get_audio_embeddings([audio_path], resample=self.sr)
        audio_emb = audio_emb.cuda()
        audio_proj = self.audio_projector(audio_emb.unsqueeze(1))

        # an all-zero embedding serves as the unconditional audio context
        audio_emb = torch.zeros(1, 1024).cuda()
        audio_uc = self.audio_projector(audio_emb.unsqueeze(1))

        # order matches the 3-way latent batch in denoise_step:
        # [source, unconditional, conditional]
        self.audio_context = torch.cat([audio_uc, audio_uc, audio_proj]).cuda()

    @torch.no_grad()
    @spaces.GPU
    def get_text_embeds(self, prompt, negative_prompt, batch_size=1):
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to("cuda"))[0]

        # Do the same for unconditional embeddings
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      truncation=True, return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to("cuda"))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)
        return text_embeddings

    @torch.no_grad()
    @spaces.GPU
    def decode_latent(self, latent):
        self.vae = self.vae.cuda()
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            latent = 1 / 0.18215 * latent
            img = self.vae.decode(latent).sample
            img = (img / 2 + 0.5).clamp(0, 1)
        return img

    @spaces.GPU
    def get_data(self, image_path):
        self.image_path = image_path

        # load image
        image = Image.open(image_path).convert('RGB')
        image = image.resize((512, 512), resample=Image.Resampling.LANCZOS)
        image = T.ToTensor()(image).to(self.device)

        # load the noisy latent saved for the first timestep of the schedule
        latents_path = os.path.join(self.latents_path, f'noisy_latents_{self.scheduler.timesteps[0]}.pt')
        noisy_latent = torch.load(latents_path).to(self.device)
        return image, noisy_latent

    @torch.no_grad()
    @spaces.GPU
    def denoise_step(self, x, t, guidance_scale):
        # register the time step and features in pnp injection modules
        source_latents = load_source_latents_t(t, self.latents_path)
        latent_model_input = torch.cat([source_latents] + ([x] * 2))
        register_time(self, t.item())

        # compute text embeddings
        text_embed_input = torch.cat([self.pnp_guidance_embeds, self.text_embeds], dim=0)

        # apply the denoising network
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input,
                               audio_context=self.audio_context)['sample']

        # perform guidance
        _, noise_pred_uncond, noise_pred_cond = noise_pred.chunk(3)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        # compute the denoising step with the reference model
        denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample']
        return denoised_latent

    @spaces.GPU
    def init_pnp(self, conv_injection_t, qk_injection_t):
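        # Plug-and-Play feature injection: q/k attention features and conv
        # features extracted from the source image are injected for the first
        # qk_injection_t / conv_injection_t steps of the schedule.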
        self.qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else []
        self.conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
        register_attention_control_efficient(self, self.qk_injection_timesteps)
        register_conv_control_efficient(self, self.conv_injection_timesteps)

    @spaces.GPU
    def run_pnp(
        self,
        n_timesteps=50,
        pnp_f_t=0.5,
        pnp_attn_t=0.5,
        prompt="",
        negative_prompt="",
        audio_path="",
        image_path="",
        audio_projector_path="ckpts/audio_projector_landscape.pth",
        adapter_ckpt_path="ckpts/landscape.pt",
        cfg_scale=5,
    ):
        self.set_audio_projector(adapter_ckpt_path, audio_projector_path)
        self.audio_projector = self.audio_projector.cuda()

        self.set_text_embeds(prompt, negative_prompt)
        self.set_audio_context(audio_path=audio_path)
        self.image, self.eps = self.get_data(image_path=image_path)
        self.unet = self.unet.cuda()

        # convert the injection fractions into numbers of denoising steps
        pnp_f_t = int(n_timesteps * pnp_f_t)
        pnp_attn_t = int(n_timesteps * pnp_attn_t)
        self.init_pnp(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)

        edited_img = self.sample_loop(self.eps, cfg_scale=cfg_scale)
        return T.ToPILImage()(edited_img[0])

    @spaces.GPU
    def sample_loop(self, x, cfg_scale):
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="Sampling")):
                x = self.denoise_step(x, t, cfg_scale)
            decoded_latent = self.decode_latent(x)
            T.ToPILImage()(decoded_latent[0]).save(f'{self.output_path}/output.png')
        return decoded_latent


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='config_pnp.yaml')
    opt = parser.parse_args()

    with open(opt.config_path, "r") as f:
        config = yaml.safe_load(f)

    os.makedirs(config["output_path"], exist_ok=True)
    with open(os.path.join(config["output_path"], "config.yaml"), "w") as f:
        yaml.dump(config, f)

    seed_everything(config["seed"])
    print(config)

    # Pass only the keyword arguments each call accepts; the original code
    # passed the whole config dict positionally, binding it to sd_version.
    init_keys = {"sd_version", "n_timesteps", "audio_projector_ckpt_path",
                 "adapter_ckpt_path", "device", "clap_path", "clap_weights"}
    pnp = PNP(**{k: v for k, v in config.items() if k in init_keys})

    run_keys = {"n_timesteps", "pnp_f_t", "pnp_attn_t", "prompt", "negative_prompt",
                "audio_path", "image_path", "audio_projector_path",
                "adapter_ckpt_path", "cfg_scale"}
    edited = pnp.run_pnp(**{k: v for k, v in config.items() if k in run_keys})
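# Illustrative config_pnp.yaml for the entry point above. Only "seed" and
# "output_path" are read directly in __main__; the remaining key names are
# assumptions derived from the PNP.__init__ and run_pnp signatures:
#
#   seed: 1
#   output_path: PNP-results/home
#   sd_version: "1.4"
#   prompt: "a photo of a waterfall"
#   audio_path: "inputs/waterfall.wav"
#   image_path: "inputs/landscape.png"
#   cfg_scale: 5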