# Spaces: Running on Zero
# Adapted from https://github.com/MichalGeyer/pnp-diffusers/blob/main/preprocess.py
from transformers import CLIPTextModel, CLIPTokenizer, logging | |
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler | |
# suppress partial model loading warning
logging.set_verbosity_error() | |
import os | |
from PIL import Image | |
from tqdm import tqdm, trange | |
import torch | |
import torch.nn as nn | |
import argparse | |
from pathlib import Path | |
from pnp_utils import * | |
import torchvision.transforms as T | |
def get_timesteps(scheduler, num_inference_steps, strength, device):
    """Return the tail of ``scheduler.timesteps`` selected by ``strength``.

    Args:
        scheduler: Object exposing a sliceable ``timesteps`` attribute.
        num_inference_steps: Total number of steps the schedule was built for.
        strength: Fraction of the schedule to keep. Values above 1 keep the
            whole schedule; values below 0 are treated as 0 (nothing kept).
        device: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(timesteps, num_steps)`` where ``timesteps`` is
        ``scheduler.timesteps`` with the first ``t_start`` entries dropped and
        ``num_steps`` is ``num_inference_steps - t_start``.
    """
    # Clamp to [0, num_inference_steps]; without the lower bound a negative
    # strength produced a negative step count.
    init_timestep = max(0, min(int(num_inference_steps * strength), num_inference_steps))

    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start:]

    return timesteps, num_inference_steps - t_start
class Preprocess(nn.Module):
    """Stable Diffusion wrapper that DDIM-inverts an image and saves latents.

    The constructor loads the VAE, tokenizer, text encoder, UNet and DDIM
    scheduler of the selected checkpoint. ``extract_latents`` then encodes an
    image, runs DDIM inversion, runs DDIM sampling back from the inverted
    latent, and returns the decoded reconstruction.

    All inference methods run under ``torch.no_grad()`` so no autograd graph
    is accumulated across denoising steps.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        """Load all model components onto ``device``.

        Args:
            device: Device the VAE, text encoder and UNet are moved to.
            sd_version: One of ``'2.1'``, ``'2.0'``, ``'1.5'``, ``'depth'``,
                ``'1.4'``. Ignored when ``hf_key`` is given.
            hf_key: Optional model identifier that overrides ``sd_version``.

        Raises:
            ValueError: If ``hf_key`` is None and ``sd_version`` is unknown.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version
        self.use_depth = False

        print('[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        elif self.sd_version == 'depth':
            model_key = "stabilityai/stable-diffusion-2-depth"
            self.use_depth = True
        elif self.sd_version == '1.4':
            model_key = "CompVis/stable-diffusion-v1-4"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", revision="fp16",
                                                 torch_dtype=torch.float16).to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder", revision="fp16",
                                                          torch_dtype=torch.float16).to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", revision="fp16",
                                                         torch_dtype=torch.float16).to(self.device)
        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        print('[INFO] loaded stable diffusion!')

        self.inversion_func = self.ddim_inversion

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt, device="cuda"):
        """Encode a prompt and a negative prompt with the text encoder.

        Args:
            prompt: Text passed to the tokenizer as the conditional prompt.
            negative_prompt: Text passed to the tokenizer as the
                unconditional prompt.
            device: Device the token ids are moved to before encoding.

        Returns:
            The unconditional embeddings followed by the conditional
            embeddings, concatenated along the first dimension.
        """
        max_length = self.tokenizer.model_max_length

        text_input = self.tokenizer(prompt, padding='max_length', max_length=max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]

        # Truncate the negative prompt as well, so an over-long negative
        # prompt is handled the same way as an over-long prompt.
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=max_length,
                                      truncation=True, return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

        return torch.cat([uncond_embeddings, text_embeddings])

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode latents with the VAE and map the result into ``[0, 1]``.

        The latents are divided by the same 0.18215 constant that
        ``encode_imgs`` multiplies by.
        """
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            latents = 1 / 0.18215 * latents
            imgs = self.vae.decode(latents).sample
            imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def load_img(self, image_path):
        """Load an image as an RGB tensor with a leading batch dimension.

        The image is resized with ``T.Resize(512)``, converted with
        ``T.ToTensor()`` and moved to ``self.device``.
        """
        # The context manager closes the underlying file once the pixel data
        # has been read by ``convert``.
        with Image.open(image_path) as raw_image:
            image_pil = T.Resize(512)(raw_image.convert("RGB"))
        image = T.ToTensor()(image_pil).unsqueeze(0).to(self.device)
        return image

    @torch.no_grad()
    def encode_imgs(self, imgs):
        """Encode images with the VAE and return the scaled posterior mean.

        ``imgs`` is rescaled with ``2 * imgs - 1`` before encoding; the
        posterior mean (not a sample) is multiplied by 0.18215.
        """
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            imgs = 2 * imgs - 1
            posterior = self.vae.encode(imgs).latent_dist
            latents = posterior.mean * 0.18215
        return latents

    @torch.no_grad()
    def ddim_inversion(self, cond, latent, save_path, save_latents=True,
                       timesteps_to_save=None):
        """Run DDIM inversion, walking the scheduler timesteps in reverse.

        Args:
            cond: Text embedding, repeated to match the latent batch size.
            latent: Starting latent (the encoded image).
            save_path: Directory the ``noisy_latents_{t}.pt`` files go to.
            save_latents: When true, save the latent after every step.
            timesteps_to_save: Accepted for interface compatibility;
                currently not used by this method.

        Returns:
            The latent after the last inversion step. It is always written to
            ``save_path``, even when ``save_latents`` is false.
        """
        timesteps = reversed(self.scheduler.timesteps)
        # The batch size of ``latent`` does not change inside the loop.
        cond_batch = cond.repeat(latent.shape[0], 1, 1)

        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for i, t in enumerate(tqdm(timesteps)):
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i - 1]]
                    if i > 0 else self.scheduler.final_alpha_cumprod
                )

                mu = alpha_prod_t ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

                eps = self.unet(latent, t, encoder_hidden_states=cond_batch).sample

                # Estimate x0 using the coefficients of the preceding entry
                # in the reversed schedule, then move the latent to step t.
                pred_x0 = (latent - sigma_prev * eps) / mu_prev
                latent = mu * pred_x0 + sigma * eps
                if save_latents:
                    torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))

        # Always persist the final inverted latent.
        torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))
        return latent

    @torch.no_grad()
    def ddim_sample(self, x, cond, save_path, save_latents=False, timesteps_to_save=None):
        """Run DDIM sampling over the scheduler timesteps starting from ``x``.

        Args:
            x: Starting latent.
            cond: Text embedding, repeated to match the batch size of ``x``.
            save_path: Directory the ``noisy_latents_{t}.pt`` files go to.
            save_latents: When true, save the latent after every step.
            timesteps_to_save: Accepted for interface compatibility;
                currently not used by this method.

        Returns:
            The latent after the last sampling step.
        """
        timesteps = self.scheduler.timesteps
        # The batch size of ``x`` does not change inside the loop.
        cond_batch = cond.repeat(x.shape[0], 1, 1)

        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for i, t in enumerate(tqdm(timesteps)):
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i + 1]]
                    if i < len(timesteps) - 1
                    else self.scheduler.final_alpha_cumprod
                )

                mu = alpha_prod_t ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

                eps = self.unet(x, t, encoder_hidden_states=cond_batch).sample

                # Estimate x0 at step t, then move to the next timestep in
                # the schedule.
                pred_x0 = (x - sigma * eps) / mu
                x = mu_prev * pred_x0 + sigma_prev * eps

                if save_latents:
                    torch.save(x, os.path.join(save_path, f'noisy_latents_{t}.pt'))
        return x

    @torch.no_grad()
    def extract_latents(self, num_steps, data_path, save_path, timesteps_to_save,
                        inversion_prompt='', extract_reverse=False):
        """Invert an image, reconstruct it, and return the decoded result.

        Args:
            num_steps: Number of timesteps configured on the scheduler.
            data_path: Path of the image to load.
            save_path: Directory latents are written to.
            timesteps_to_save: Forwarded to the inversion and sampling calls.
            inversion_prompt: Prompt used as conditioning for both passes.
            extract_reverse: When true, per-step latents are saved during
                sampling instead of during inversion.

        Returns:
            The decoded reconstruction produced by ``decode_latents``.
        """
        self.scheduler.set_timesteps(num_steps)

        # Entry 0 is the unconditional embedding, entry 1 the prompt
        # embedding (for a single prompt string). Use the model's own device
        # rather than the hard-coded default of get_text_embeds.
        cond = self.get_text_embeds(inversion_prompt, "", device=self.device)[1].unsqueeze(0)
        image = self.load_img(data_path)
        latent = self.encode_imgs(image)

        inverted_x = self.inversion_func(cond, latent, save_path, save_latents=not extract_reverse,
                                         timesteps_to_save=timesteps_to_save)
        latent_reconstruction = self.ddim_sample(inverted_x, cond, save_path, save_latents=extract_reverse,
                                                 timesteps_to_save=timesteps_to_save)
        rgb_reconstruction = self.decode_latents(latent_reconstruction)

        return rgb_reconstruction  # , latent_reconstruction
def run(opt):
    """Invert one image with DDIM and write latents plus a reconstruction.

    Args:
        opt: Namespace with the attributes ``sd_version``, ``save_steps``,
            ``seed``, ``extract_reverse``, ``save_dir``, ``data_path``,
            ``steps`` and ``inversion_prompt``.

    Raises:
        ValueError: If ``opt.sd_version`` is not a supported version. Raised
            before any model is loaded.

    Side effects:
        Creates ``<save_dir>_forward/<image name>`` or
        ``<save_dir>_reverse/<image name>`` and saves ``recon.jpg`` there.
    """
    model_keys = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
        'depth': "stabilityai/stable-diffusion-2-depth",
        '1.4': "CompVis/stable-diffusion-v1-4",
    }
    # Fail early with a clear message instead of hitting an unbound
    # ``model_key`` further down.
    if opt.sd_version not in model_keys:
        raise ValueError(f'Stable-diffusion version {opt.sd_version} not supported.')
    model_key = model_keys[opt.sd_version]

    device = 'cuda'

    # timesteps to save
    toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
    toy_scheduler.set_timesteps(opt.save_steps)
    timesteps_to_save, _ = get_timesteps(toy_scheduler, num_inference_steps=opt.save_steps,
                                         strength=1.0,
                                         device=device)

    seed_everything(opt.seed)

    extraction_path_prefix = "_reverse" if opt.extract_reverse else "_forward"
    save_path = os.path.join(opt.save_dir + extraction_path_prefix,
                             os.path.splitext(os.path.basename(opt.data_path))[0])
    os.makedirs(save_path, exist_ok=True)

    model = Preprocess(device, sd_version=opt.sd_version, hf_key=None)
    recon_image = model.extract_latents(data_path=opt.data_path,
                                        num_steps=opt.steps,
                                        save_path=save_path,
                                        timesteps_to_save=timesteps_to_save,
                                        inversion_prompt=opt.inversion_prompt,
                                        extract_reverse=opt.extract_reverse)

    T.ToPILImage()(recon_image[0]).save(os.path.join(save_path, 'recon.jpg'))
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to run().
    device = 'cuda'

    # (flags, keyword options) pairs, registered in this order.
    cli_arguments = (
        (('--data_path',), dict(type=str, default='data/source_2.png')),
        (('--save_dir',), dict(type=str, default='latents')),
        (('--sd_version',), dict(type=str, default='1.4', choices=['1.5', '2.0', '2.1', '1.4'],
                                 help="stable diffusion version")),
        (('--seed',), dict(type=int, default=1)),
        (('--steps',), dict(type=int, default=50)),
        (('--save-steps',), dict(type=int, default=1000)),
        (('--inversion_prompt',), dict(type=str, default='')),
        (('--extract-reverse',), dict(default=False, action='store_true',
                                      help="extract features during the denoising process")),
    )

    parser = argparse.ArgumentParser()
    for flags, options in cli_arguments:
        parser.add_argument(*flags, **options)

    opt = parser.parse_args()
    run(opt)