"""SAMPLING ONLY.""" import os import torch from torch import nn import torchvision import numpy as np from tqdm import tqdm from functools import partial from PIL import Image import shutil from ldm.modules.diffusionmodules.util import ( make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, ) import clip from einops import rearrange import random class VGGPerceptualLoss(torch.nn.Module): def __init__(self, resize=True): super(VGGPerceptualLoss, self).__init__() blocks = [] blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) for bl in blocks: for p in bl.parameters(): p.requires_grad = False self.blocks = torch.nn.ModuleList(blocks) self.transform = torch.nn.functional.interpolate self.resize = resize self.register_buffer( "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) self.register_buffer( "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]): input = (input - self.mean) / self.std target = (target - self.mean) / self.std if self.resize: input = self.transform( input, mode="bilinear", size=(224, 224), align_corners=False ) target = self.transform( target, mode="bilinear", size=(224, 224), align_corners=False ) loss = 0.0 x = input y = target for i, block in enumerate(self.blocks): x = block(x) y = block(y) if i in feature_layers: loss += torch.nn.functional.l1_loss(x, y) if i in style_layers: act_x = x.reshape(x.shape[0], x.shape[1], -1) act_y = y.reshape(y.shape[0], y.shape[1], -1) gram_x = act_x @ act_x.permute(0, 2, 1) gram_y = act_y @ act_y.permute(0, 2, 1) loss += torch.nn.functional.l1_loss(gram_x, gram_y) return loss class DCLIPLoss(torch.nn.Module): def __init__(self): super(DCLIPLoss, self).__init__() self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") self.upsample = torch.nn.Upsample(scale_factor=7) self.avg_pool = torch.nn.AvgPool2d(kernel_size=16) def forward(self, image1, image2, text1, text2): text1 = clip.tokenize([text1]).to("cuda") text2 = clip.tokenize([text2]).to("cuda") image1 = image1.unsqueeze(0).cuda() image2 = image2.unsqueeze(0) image1 = self.avg_pool(self.upsample(image1)) image2 = self.avg_pool(self.upsample(image2)) image1_feat = self.model.encode_image(image1) image2_feat = self.model.encode_image(image2) text1_feat = self.model.encode_text(text1) text2_feat = self.model.encode_text(text2) d_image_feat = image1_feat - image2_feat d_text_feat = text1_feat - text2_feat similarity = torch.nn.CosineSimilarity()(d_image_feat, d_text_feat) return 1 - similarity class PLMSSampler(object): def __init__(self, model, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) def make_schedule( self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True ): if ddim_eta != 0: raise ValueError("ddim_eta must be 0 for PLMS") self.ddim_timesteps = make_ddim_timesteps( ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose, ) alphas_cumprod = 
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=0.0,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )
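
    # Added note: with ddim_eta fixed to 0 all PLMS sigmas are zero, so sampling is
    # deterministic. make_ddim_sampling_parameters gathers, for each kept step i,
    # ddim_alphas[i] = alphas_cumprod[ddim_timesteps[i]] and ddim_alphas_prev[i] as the
    # cumulative alpha of the previously kept timestep; the expression above is the
    # standard DDIM sigma, eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)).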
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        # this has to come in the same format as the conditioning,
        # e.g. as encoded tokens, ...
        dynamic_threshold=None,
        **kwargs,
    ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f"Data shape for PLMS sampling is {size}")

        samples, intermediates = self.plms_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
        )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
    ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (
            list(reversed(range(0, timesteps)))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full(
                (b,),
                time_range[min(i + 1, len(time_range) - 1)],
                device=device,
                dtype=torch.long,
            )

            if mask is not None:
                assert x0 is not None
                # import ipdb; ipdb.set_trace()
                img_orig = self.model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
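                # Blend with the noised reference: keep x0's content where mask == 1
                # and the current sample elsewhere (inpainting-style conditioning).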
                img = img_orig * mask + (1.0 - mask) * img

            outs = self.p_sample_plms(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                old_eps=old_eps,
                t_next=ts_next,
            )
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % 1 == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        old_eps=None,
        t_next=None,
    ):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if (
                unconditional_conditioning is None
                or unconditional_guidance_scale == 1.0
            ):
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                if isinstance(c, dict):
                    c_in = {
                        key: [torch.cat([unconditional_conditioning[key][0], c[key][0]])]
                        for key in c
                    }
                else:
                    c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(
                    self.model, e_t, x, t, c, **corrector_kwargs
                )

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full(
                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
            )

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.0:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (
                55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
            ) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t

    ###### Above are original stable-diffusion code ############

    ###### Encode Image ########################################
    @torch.no_grad()
    def sample_encode_save_noise(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        input_image=None,
        noise_save_path=None,
        # this has to come in the same format as the conditioning,
        # e.g. as encoded tokens, ...
        **kwargs,
    ):
        assert conditioning is not None
        # assert not isinstance(conditioning, dict)
        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        if verbose:
            print(f"Data shape for PLMS sampling is {size}")

        samples, intermediates, x0_loop = self.plms_sampling_enc_save_noise(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            input_image=input_image,
            noise_save_path=noise_save_path,
            verbose=verbose,
        )
        return samples, intermediates, x0_loop

    @torch.no_grad()
    def plms_sampling_enc_save_noise(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        input_image=None,
        noise_save_path=None,
        verbose=True,
    ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (
            list(reversed(range(0, timesteps)))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        # encoding walks the timesteps in ascending order; this overrides the reversed
        # range assigned above
        time_range = list(range(0, timesteps)) if ddim_use_original_steps else timesteps
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        if verbose:
            print(f"Running PLMS Sampling with {total_steps} timesteps")
            iterator = tqdm(time_range[:-1], desc="PLMS Sampler", total=total_steps)
        else:
            iterator = time_range[:-1]

        old_eps = []
        noise_images = []
        for each_time in time_range:
            noised_image = self.model.q_sample(
                input_image, torch.tensor([each_time]).to(device)
            )
            noise_images.append(noised_image)
            # torch.save(noised_image, noise_save_path + "_image_time%d.pt" % (each_time))
        # import pudb; pudb.set_trace()

        x0_loop = input_image.clone()
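
        # Deterministic forward loop below: starting from the clean latent, each step
        # re-estimates epsilon with classifier-free guidance and maps x_t to x_{t+1}
        # via the sigma = 0 DDIM update, i.e. a DDIM-style inversion of the input image.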
        alphas = (
            self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas
        )
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if ddim_use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if ddim_use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if ddim_use_original_steps
            else self.ddim_sigmas
        )

        def get_model_output(x, t):
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(cond, dict):
                c_in = {
                    key: [torch.cat([unconditional_conditioning[key][0], cond[key][0]])]
                    for key in cond
                }
            else:
                c_in = torch.cat([unconditional_conditioning, cond])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            return e_t

        def get_x_prev_and_pred_x0(e_t, index, curr_x0):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full(
                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
            )

            # current prediction for x_0
            pred_x0 = (curr_x0 - sqrt_one_minus_at * e_t) / a_t.sqrt()

            a_t = torch.full((b, 1, 1, 1), alphas[index + 1], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index + 1], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index + 1], device=device)
            sqrt_one_minus_at = torch.full(
                (b, 1, 1, 1), sqrt_one_minus_alphas[index + 1], device=device
            )
            dir_xt = (1.0 - a_t - sigma_t ** 2).sqrt() * e_t
            x_prev = a_t.sqrt() * pred_x0 + dir_xt
            return x_prev, pred_x0

        for i, step in enumerate(iterator):
            index = i
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full(
                (b,),
                time_range[min(i + 1, len(time_range) - 1)],
                device=device,
                dtype=torch.long,
            )
            e_t = get_model_output(x0_loop, ts)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index, x0_loop)
            x0_loop = x_prev
        # torch.save(x0_loop, noise_save_path + "_final_latent.pt")

        # Reconstruction
        img = x0_loop.clone()
        time_range = (
            list(reversed(range(0, timesteps)))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        if verbose:
            print(f"Running PLMS Sampling with {total_steps} timesteps")
            iterator = tqdm(
                time_range,
                desc="PLMS Sampler",
                total=total_steps,
                miniters=total_steps + 1,
                mininterval=600,
            )
        else:
            iterator = time_range

        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full(
                (b,),
                time_range[min(i + 1, len(time_range) - 1)],
                device=device,
                dtype=torch.long,
            )

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
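                # Blend with the noised reference (same inpainting-style masking as in
                # plms_sampling above).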
                img = img_orig * mask + (1.0 - mask) * img

            outs = self.p_sample_plms_dec_save_noise(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                old_eps=old_eps,
                t_next=ts_next,
                input_image=input_image,
                noise_save_path=noise_save_path,
                noise_image=noise_images.pop(),
            )
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)

        return img, intermediates, x0_loop

    @torch.no_grad()
    def p_sample_plms_dec_save_noise(
        self,
        x,
        c1,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        old_eps=None,
        t_next=None,
        input_image=None,
        noise_save_path=None,
        noise_image=None,
    ):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if (
                unconditional_conditioning is None
                or unconditional_guidance_scale == 1.0
            ):
                e_t = self.model.apply_model(x, t, c1)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                if isinstance(c1, dict):
                    c_in = {
                        key: [
                            torch.cat([unconditional_conditioning[key][0], c1[key][0]])
                        ]
                        for key in c1
                    }
                else:
                    c_in = torch.cat([unconditional_conditioning, c1])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full(
                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
            )

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t

            time_curr = index * 20 + 1
            # img_prev = torch.load(noise_save_path + "_image_time%d.pt" % (time_curr))
            img_prev = noise_image
            noise = img_prev - a_prev.sqrt() * pred_x0 - dir_xt
            # torch.save(noise, noise_save_path + "_time%d.pt" % (time_curr))
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (
                55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
            ) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t

    ################## Encode Image End ###############################

    def p_sample_plms_sampling(
        self,
        x,
        c1,
        c2,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        old_eps=None,
        t_next=None,
        input_image=None,
        optimizing_weight=None,
        noise_save_path=None,
    ):
        b, *_, device = *x.shape, x.device

        def optimize_model_output(x, t):
            # weight_for_pencil = torch.nn.Sigmoid()(optimizing_weight)
            # condition = weight_for_pencil * c1 + (1 - weight_for_pencil) * c2
            condition = optimizing_weight * c1 + (1 - optimizing_weight) * c2
            if (
                unconditional_conditioning is None
                or unconditional_guidance_scale == 1.0
            ):
                e_t = self.model.apply_model(x, t, condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                if isinstance(condition, dict):
                    c_in = {
                        key: [
                            torch.cat(
                                [unconditional_conditioning[key][0], condition[key][0]]
                            )
                        ]
                        for key in condition
                    }
                else:
                    c_in = torch.cat([unconditional_conditioning, condition])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full(
                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
            )

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t

            time_curr = index * 20 + 1
            if noise_save_path and index > 16:
                noise = torch.load(noise_save_path + "_time%d.pt" % (time_curr))[:1]
            else:
                noise = torch.zeros_like(dir_xt)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = optimize_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            # e_t_next = get_model_output(x_prev, t_next)
            e_t_next = optimize_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (
                55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
            ) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t

    ################## Edit Input Image ###############################
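
    # The two methods below edit a real image: the image is first inverted with
    # sample_encode_save_noise (which stores per-step noise tensors), then a learnable
    # per-step weight lambda_t that blends conditioning1/conditioning2 is optimized so
    # the re-sampled image stays close to the original (VGG perceptual loss) while
    # moving toward the target text (directional CLIP loss).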
    def sample_optimize_intrinsic_edit(
        self,
        S,
        batch_size,
        shape,
        conditioning1=None,
        conditioning2=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        input_image=None,
        noise_save_path=None,
        lambda_t=None,
        lambda_save_path=None,
        image_save_path=None,
        original_text=None,
        new_text=None,
        otext=None,
        noise_saved_path=None,
        # this has to come in the same format as the conditioning,
        # e.g. as encoded tokens, ...
        **kwargs,
    ):
        assert conditioning1 is not None
        assert conditioning2 is not None
        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f"Data shape for PLMS sampling is {size}")

        self.plms_sampling_optimize_intrinsic_edit(
            conditioning1,
            conditioning2,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            input_image=input_image,
            noise_save_path=noise_save_path,
            lambda_t=lambda_t,
            lambda_save_path=lambda_save_path,
            image_save_path=image_save_path,
            original_text=original_text,
            new_text=new_text,
            otext=otext,
            noise_saved_path=noise_saved_path,
        )
        return None

    def plms_sampling_optimize_intrinsic_edit(
        self,
        cond1,
        cond2,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        input_image=None,
        noise_save_path=None,
        lambda_t=None,
        lambda_save_path=None,
        image_save_path=None,
        original_text=None,
        new_text=None,
        otext=None,
        noise_saved_path=None,
    ):
        # Different from the plain sampling above, the intrinsic edit version needs the
        # per-step noise saved during encoding (noise_saved_path) so the input image is
        # reproduced while the blending weights are optimized.
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T
        img_clone = img.clone()

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (
            list(reversed(range(0, timesteps)))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )

        weighting_parameter = lambda_t
        weighting_parameter.requires_grad = True
        from torch import optim

        optimizer = optim.Adam([weighting_parameter], lr=0.05)

        print("Original image")
        with torch.no_grad():
            img = img_clone.clone()
            total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
            iterator = time_range
            old_eps = []
            for i, step in enumerate(iterator):
                index = total_steps - i - 1
                ts = torch.full((b,), step, device=device, dtype=torch.long)
                ts_next = torch.full(
                    (b,),
                    time_range[min(i + 1, len(time_range) - 1)],
                    device=device,
                    dtype=torch.long,
                )
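                # Fixed weight of 1.0 here, so this pass follows cond1 only; its output
                # is saved as "original.png" and used as the perceptual reference below.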
                outs = self.p_sample_plms_sampling(
                    img,
                    cond1,
                    cond2,
                    ts,
                    index=index,
                    use_original_steps=ddim_use_original_steps,
                    quantize_denoised=quantize_denoised,
                    temperature=temperature,
                    noise_dropout=noise_dropout,
                    score_corrector=score_corrector,
                    corrector_kwargs=corrector_kwargs,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=unconditional_conditioning,
                    old_eps=old_eps,
                    t_next=ts_next,
                    input_image=input_image,
                    optimizing_weight=torch.ones(50)[i],
                    noise_save_path=noise_saved_path,
                )
                img, pred_x0, e_t = outs
                old_eps.append(e_t)
                if len(old_eps) >= 4:
                    old_eps.pop(0)

            img_temp = self.model.decode_first_stage(img)
            img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
            img_temp_ddim = img_temp_ddim.cpu().permute(0, 2, 3, 1).permute(0, 3, 1, 2)

        # save image
        with torch.no_grad():
            x_sample = 255.0 * rearrange(
                img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
            )
            imgsave = Image.fromarray(x_sample.astype(np.uint8))
            imgsave.save(image_save_path + "original.png")
        readed_image = (
            torchvision.io.read_image(image_save_path + "original.png").float() / 255
        )

        print("Optimizing start")
        for epoch in tqdm(range(10)):
            img = img_clone.clone()
            total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
            iterator = time_range
            old_eps = []
            for i, step in enumerate(iterator):
                index = total_steps - i - 1
                ts = torch.full((b,), step, device=device, dtype=torch.long)
                ts_next = torch.full(
                    (b,),
                    time_range[min(i + 1, len(time_range) - 1)],
                    device=device,
                    dtype=torch.long,
                )
                outs = self.p_sample_plms_sampling(
                    img,
                    cond1,
                    cond2,
                    ts,
                    index=index,
                    use_original_steps=ddim_use_original_steps,
                    quantize_denoised=quantize_denoised,
                    temperature=temperature,
                    noise_dropout=noise_dropout,
                    score_corrector=score_corrector,
                    corrector_kwargs=corrector_kwargs,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=unconditional_conditioning,
                    old_eps=old_eps,
                    t_next=ts_next,
                    input_image=input_image,
                    optimizing_weight=weighting_parameter[i],
                    noise_save_path=noise_saved_path,
                )
                img, pred_x0, e_t = outs
                old_eps.append(e_t)
                if len(old_eps) >= 4:
                    old_eps.pop(0)

            img_temp = self.model.decode_first_stage(img)
            img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
            img_temp_ddim = img_temp_ddim.cpu()
            # save image
            # with torch.no_grad():
            #     x_sample = 255.0 * rearrange(
            #         img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
            #     )
            #     imgsave = Image.fromarray(x_sample.astype(np.uint8))
            #     imgsave.save(image_save_path + "/%d.png" % (epoch))

            loss1 = VGGPerceptualLoss()(img_temp_ddim[0], readed_image)
            loss2 = DCLIPLoss()(
                readed_image, img_temp_ddim[0].float().cuda(), otext, new_text
            )
            loss = 0.05 * loss1 + loss2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # torch.save(
            #     weighting_parameter, lambda_save_path + "/weightingParam%d.pt" % (epoch)
            # )

            if epoch < 9:
                del img
            else:
                # save image
                with torch.no_grad():
                    x_sample = 255.0 * rearrange(
                        img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
                    )
                    imgsave = Image.fromarray(x_sample.astype(np.uint8))
                    imgsave.save(image_save_path + "/final.png")
                    torch.save(
                        weighting_parameter,
                        lambda_save_path + "/weightingParam_final.pt",
                    )
            torch.cuda.empty_cache()
        # shutil.rmtree("noise")
        return None

    ################ Edit Image End ######################

    ################ Disentangle #########################
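    # The disentanglement variant below runs the same lambda_t optimization, but starts
    # from pure Gaussian noise (saved inversion noise is only loaded when
    # noise_save_path is given), so it learns weights that separate the attribute
    # described by the two prompts rather than editing one specific input image.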
    def sample_optimize_intrinsic(
        self,
        S,
        batch_size,
        shape,
        conditioning1=None,
        conditioning2=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        input_image=None,
        noise_save_path=None,
        lambda_t=None,
        lambda_save_path=None,
        image_save_path=None,
        original_text=None,
        new_text=None,
        otext=None,
        # this has to come in the same format as the conditioning,
        # e.g. as encoded tokens, ...
        **kwargs,
    ):
        assert conditioning1 is not None
        assert conditioning2 is not None
        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f"Data shape for PLMS sampling is {size}")

        self.plms_sampling_optimize_intrinsic(
            conditioning1,
            conditioning2,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            input_image=input_image,
            noise_save_path=noise_save_path,
            lambda_t=lambda_t,
            lambda_save_path=lambda_save_path,
            image_save_path=image_save_path,
            original_text=original_text,
            new_text=new_text,
            otext=otext,
        )
        return None

    def plms_sampling_optimize_intrinsic(
        self,
        cond1,
        cond2,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        input_image=None,
        noise_save_path=None,
        lambda_t=None,
        lambda_save_path=None,
        image_save_path=None,
        original_text=None,
        new_text=None,
        otext=None,
    ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T
        img_clone = img.clone()

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        time_range = (
            list(reversed(range(0, timesteps)))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )

        weighting_parameter = lambda_t
        weighting_parameter.requires_grad = True
        from torch import optim

        optimizer = optim.Adam([weighting_parameter], lr=0.05)

        print("Original image")
        with torch.no_grad():
            img = img_clone.clone()
            total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
            iterator = time_range
            old_eps = []
            for i, step in enumerate(iterator):
                index = total_steps - i - 1
                ts = torch.full((b,), step, device=device, dtype=torch.long)
                ts_next = torch.full(
                    (b,),
                    time_range[min(i + 1, len(time_range) - 1)],
                    device=device,
                    dtype=torch.long,
                )
                outs = self.p_sample_plms_sampling(
                    img,
                    cond1,
                    cond2,
                    ts,
                    index=index,
                    use_original_steps=ddim_use_original_steps,
                    quantize_denoised=quantize_denoised,
                    temperature=temperature,
                    noise_dropout=noise_dropout,
                    score_corrector=score_corrector,
                    corrector_kwargs=corrector_kwargs,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=unconditional_conditioning,
                    old_eps=old_eps,
                    t_next=ts_next,
                    input_image=input_image,
                    optimizing_weight=torch.ones(50)[i],
                    noise_save_path=noise_save_path,
                )
                img, pred_x0, e_t = outs
                old_eps.append(e_t)
                if len(old_eps) >= 4:
                    old_eps.pop(0)

            img_temp = self.model.decode_first_stage(img)
            del img
            img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
            img_temp_ddim = img_temp_ddim.cpu().permute(0, 2, 3, 1).permute(0, 3, 1, 2)

        # save image
        with torch.no_grad():
            x_sample = 255.0 * rearrange(
                img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
            )
            imgsave = Image.fromarray(x_sample.astype(np.uint8))
            imgsave.save(image_save_path + "original.png")
        readed_image = (
            torchvision.io.read_image(image_save_path + "original.png").float() / 255
        )

        print("Optimizing start")
        for epoch in tqdm(range(10)):
            img = img_clone.clone()
            total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
            iterator = time_range
            old_eps = []
            for i, step in enumerate(iterator):
                index = total_steps - i - 1
                ts = torch.full((b,), step, device=device, dtype=torch.long)
                ts_next = torch.full(
                    (b,),
                    time_range[min(i + 1, len(time_range) - 1)],
                    device=device,
                    dtype=torch.long,
                )
                outs = self.p_sample_plms_sampling(
                    img,
                    cond1,
                    cond2,
                    ts,
                    index=index,
                    use_original_steps=ddim_use_original_steps,
                    quantize_denoised=quantize_denoised,
                    temperature=temperature,
                    noise_dropout=noise_dropout,
                    score_corrector=score_corrector,
                    corrector_kwargs=corrector_kwargs,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=unconditional_conditioning,
                    old_eps=old_eps,
                    t_next=ts_next,
                    input_image=input_image,
                    optimizing_weight=weighting_parameter[i],
                    noise_save_path=noise_save_path,
                )
                img, _, e_t = outs
                old_eps.append(e_t)
                if len(old_eps) >= 4:
                    old_eps.pop(0)

            img_temp = self.model.decode_first_stage(img)
            del img
            img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
            img_temp_ddim = img_temp_ddim.cpu()
            # # save image
            # with torch.no_grad():
            #     x_sample = 255. * rearrange(img_temp_ddim[0].detach().cpu().numpy(), 'c h w -> h w c')
            #     imgsave = Image.fromarray(x_sample.astype(np.uint8))
            #     imgsave.save(image_save_path + "/%d.png" % (epoch))

            loss1 = VGGPerceptualLoss()(img_temp_ddim[0], readed_image)
            loss2 = DCLIPLoss()(
                readed_image, img_temp_ddim[0].float().cuda(), otext, new_text
            )
            loss = (
                0.05 * loss1 + loss2
            )  # perceptual weight of 0.05 (or 0.03); tune depending on whether the attribute concerns scenes or people
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # torch.save(weighting_parameter, lambda_save_path + "/weightingParam%d.pt" % (epoch))

            with torch.no_grad():
                if epoch == 9:
                    # save image
                    x_sample = 255.0 * rearrange(
                        img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
                    )
                    imgsave = Image.fromarray(x_sample.astype(np.uint8))
                    imgsave.save(image_save_path + "/final.png")
                    torch.save(
                        weighting_parameter,
                        lambda_save_path + "/weightingParam_final.pt",
                    )
            torch.cuda.empty_cache()
        return None

    ################ Disentangle End #########################
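

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): how the sampler above
# might be driven. The config/checkpoint paths, prompts, guidance scale, and
# latent shape are placeholders; `instantiate_from_config` and
# `get_learned_conditioning` are the standard latent-diffusion helpers this
# module is normally used with.
# ---------------------------------------------------------------------------
# from omegaconf import OmegaConf
# from ldm.util import instantiate_from_config
#
# config = OmegaConf.load("<path/to/model_config.yaml>")           # placeholder path
# model = instantiate_from_config(config.model)
# model.load_state_dict(torch.load("<path/to/ckpt>")["state_dict"], strict=False)
# model = model.cuda().eval()
#
# sampler = PLMSSampler(model)
# uc = model.get_learned_conditioning([""])                        # unconditional embedding
# c1 = model.get_learned_conditioning(["a photo of a person"])     # placeholder prompt
# c2 = model.get_learned_conditioning(["a photo of a person smiling"])
# lambda_t = torch.ones(50, requires_grad=True)                    # one weight per PLMS step
#
# # Disentanglement: optimize lambda_t so that blending c1/c2 moves the sample
# # toward the second prompt while staying perceptually close to the first.
# sampler.sample_optimize_intrinsic(
#     S=50, batch_size=1, shape=(4, 64, 64),
#     conditioning1=c1, conditioning2=c2,
#     unconditional_guidance_scale=5.0, unconditional_conditioning=uc,
#     lambda_t=lambda_t,
#     lambda_save_path="<out_dir>", image_save_path="<out_dir>/",
#     otext="a photo of a person", new_text="a photo of a person smiling",
# )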