# Adapted from https://github.com/MichalGeyer/pnp-diffusers/blob/main/preprocess.py from transformers import CLIPTextModel, CLIPTokenizer, logging from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler # suppress partial model loading warning logging.set_verbosity_error() import os from PIL import Image from tqdm import tqdm, trange import torch import torch.nn as nn import argparse from pathlib import Path from pnp_utils import * import torchvision.transforms as T def get_timesteps(scheduler, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start class Preprocess(nn.Module): def __init__(self, device, sd_version='2.0', hf_key=None): super().__init__() self.device = device self.sd_version = sd_version self.use_depth = False print(f'[INFO] loading stable diffusion...') if hf_key is not None: print(f'[INFO] using hugging face custom model key: {hf_key}') model_key = hf_key elif self.sd_version == '2.1': model_key = "stabilityai/stable-diffusion-2-1-base" elif self.sd_version == '2.0': model_key = "stabilityai/stable-diffusion-2-base" elif self.sd_version == '1.5': model_key = "runwayml/stable-diffusion-v1-5" elif self.sd_version == 'depth': model_key = "stabilityai/stable-diffusion-2-depth" self.use_depth = True elif self.sd_version == '1.4': model_key = "CompVis/stable-diffusion-v1-4" else: raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.') # Create model self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", revision="fp16", torch_dtype=torch.float16).to(self.device) self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder", revision="fp16", torch_dtype=torch.float16).to(self.device) self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", revision="fp16", torch_dtype=torch.float16).to(self.device) self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") print(f'[INFO] loaded stable diffusion!') self.inversion_func = self.ddim_inversion @torch.no_grad() def get_text_embeds(self, prompt, negative_prompt, device="cuda"): text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt') text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0] uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) return text_embeddings @torch.no_grad() def decode_latents(self, latents): with torch.autocast(device_type='cuda', dtype=torch.float32): latents = 1 / 0.18215 * latents imgs = self.vae.decode(latents).sample imgs = (imgs / 2 + 0.5).clamp(0, 1) return imgs def load_img(self, image_path): image_pil = T.Resize(512)(Image.open(image_path).convert("RGB")) image = T.ToTensor()(image_pil).unsqueeze(0).to(self.device) return image @torch.no_grad() def encode_imgs(self, imgs): with torch.autocast(device_type='cuda', dtype=torch.float32): imgs = 2 * imgs - 1 posterior = self.vae.encode(imgs).latent_dist latents = posterior.mean * 0.18215 return latents @torch.no_grad() def ddim_inversion(self, cond, latent, save_path, save_latents=True, timesteps_to_save=None): timesteps = reversed(self.scheduler.timesteps) with torch.autocast(device_type='cuda', dtype=torch.float32): for i, t in enumerate(tqdm(timesteps)): cond_batch = cond.repeat(latent.shape[0], 1, 1) alpha_prod_t = self.scheduler.alphas_cumprod[t] alpha_prod_t_prev = ( self.scheduler.alphas_cumprod[timesteps[i - 1]] if i > 0 else self.scheduler.final_alpha_cumprod ) mu = alpha_prod_t ** 0.5 mu_prev = alpha_prod_t_prev ** 0.5 sigma = (1 - alpha_prod_t) ** 0.5 sigma_prev = (1 - alpha_prod_t_prev) ** 0.5 eps = self.unet(latent, t, encoder_hidden_states=cond_batch).sample pred_x0 = (latent - sigma_prev * eps) / mu_prev latent = mu * pred_x0 + sigma * eps if save_latents: torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt')) torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt')) return latent @torch.no_grad() def ddim_sample(self, x, cond, save_path, save_latents=False, timesteps_to_save=None): timesteps = self.scheduler.timesteps with torch.autocast(device_type='cuda', dtype=torch.float32): for i, t in enumerate(tqdm(timesteps)): cond_batch = cond.repeat(x.shape[0], 1, 1) alpha_prod_t = self.scheduler.alphas_cumprod[t] alpha_prod_t_prev = ( self.scheduler.alphas_cumprod[timesteps[i + 1]] if i < len(timesteps) - 1 else self.scheduler.final_alpha_cumprod ) mu = alpha_prod_t ** 0.5 sigma = (1 - alpha_prod_t) ** 0.5 mu_prev = alpha_prod_t_prev ** 0.5 sigma_prev = (1 - alpha_prod_t_prev) ** 0.5 eps = self.unet(x, t, encoder_hidden_states=cond_batch).sample pred_x0 = (x - sigma * eps) / mu x = mu_prev * pred_x0 + sigma_prev * eps if save_latents: torch.save(x, os.path.join(save_path, f'noisy_latents_{t}.pt')) return x @torch.no_grad() def extract_latents(self, num_steps, data_path, save_path, timesteps_to_save, inversion_prompt='', extract_reverse=False): self.scheduler.set_timesteps(num_steps) cond = self.get_text_embeds(inversion_prompt, "")[1].unsqueeze(0) image = self.load_img(data_path) latent = self.encode_imgs(image) inverted_x = self.inversion_func(cond, latent, save_path, save_latents=not extract_reverse, timesteps_to_save=timesteps_to_save) latent_reconstruction = self.ddim_sample(inverted_x, cond, save_path, save_latents=extract_reverse, timesteps_to_save=timesteps_to_save) rgb_reconstruction = self.decode_latents(latent_reconstruction) return rgb_reconstruction # , latent_reconstruction def run(opt): device = 'cuda' # timesteps to save if opt.sd_version == '2.1': model_key = "stabilityai/stable-diffusion-2-1-base" elif opt.sd_version == '2.0': model_key = "stabilityai/stable-diffusion-2-base" elif opt.sd_version == '1.5': model_key = "runwayml/stable-diffusion-v1-5" elif opt.sd_version == 'depth': model_key = "stabilityai/stable-diffusion-2-depth" elif opt.sd_version == '1.4': model_key = "CompVis/stable-diffusion-v1-4" toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") toy_scheduler.set_timesteps(opt.save_steps) timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=opt.save_steps, strength=1.0, device=device) seed_everything(opt.seed) extraction_path_prefix = "_reverse" if opt.extract_reverse else "_forward" save_path = os.path.join(opt.save_dir + extraction_path_prefix, os.path.splitext(os.path.basename(opt.data_path))[0]) os.makedirs(save_path, exist_ok=True) model = Preprocess(device, sd_version=opt.sd_version, hf_key=None) recon_image = model.extract_latents(data_path=opt.data_path, num_steps=opt.steps, save_path=save_path, timesteps_to_save=timesteps_to_save, inversion_prompt=opt.inversion_prompt, extract_reverse=opt.extract_reverse) T.ToPILImage()(recon_image[0]).save(os.path.join(save_path, f'recon.jpg')) if __name__ == "__main__": device = 'cuda' parser = argparse.ArgumentParser() parser.add_argument('--data_path', type=str, default='data/source_2.png') parser.add_argument('--save_dir', type=str, default='latents') parser.add_argument('--sd_version', type=str, default='1.4', choices=['1.5', '2.0', '2.1', '1.4'], help="stable diffusion version") parser.add_argument('--seed', type=int, default=1) parser.add_argument('--steps', type=int, default=50) parser.add_argument('--save-steps', type=int, default=1000) parser.add_argument('--inversion_prompt', type=str, default='') parser.add_argument('--extract-reverse', default=False, action='store_true', help="extract features during the denoising process") opt = parser.parse_args() run(opt)