import numpy as np
import torch
import torch.nn.functional as nnf
from torch.optim import Adam
from tqdm import tqdm
from PIL import Image

from generation import load_512
from p2p import register_attention_control


def null_optimization(solver, latents, guidance_scale, num_inner_steps, epsilon):
    """Null-text inversion: optimize the unconditional ("null") embedding at each
    diffusion step so that classifier-free-guided sampling reproduces the DDIM
    inversion trajectory stored in `latents`."""
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    uncond_embeddings_list = []
    latent_cur = latents[-1]
    bar = tqdm(total=num_inner_steps * solver.n_steps)
    for i in range(solver.n_steps):
        # Re-initialize the null embedding from the previous step and make it trainable.
        uncond_embeddings = uncond_embeddings.clone().detach()
        uncond_embeddings.requires_grad = True
        # Linearly decayed learning rate over the diffusion steps.
        optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
        latent_prev = latents[len(latents) - i - 2]
        t = solver.model.scheduler.timesteps[i]
        # The conditional prediction does not depend on the null embedding,
        # so compute it once per step without gradients.
        with torch.no_grad():
            noise_pred_cond = solver.get_noise_pred_single(latent_cur, t, cond_embeddings)
        for j in range(num_inner_steps):
            noise_pred_uncond = solver.get_noise_pred_single(latent_cur, t, uncond_embeddings)
            # Classifier-free guidance.
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents_prev_rec = solver.prev_step(noise_pred, t, latent_cur)
            # Match the reconstructed latent to the inversion trajectory.
            loss = nnf.mse_loss(latents_prev_rec, latent_prev)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_item = loss.item()
            bar.update()
            # Early stopping: the tolerance is relaxed at later steps.
            if loss_item < epsilon + i * 2e-5:
                break
        # Keep the progress bar consistent when the inner loop stopped early.
        for _ in range(j + 1, num_inner_steps):
            bar.update()
        uncond_embeddings_list.append(uncond_embeddings[:1].detach())
        # Advance the latent one step using the optimized null embedding.
        with torch.no_grad():
            context = torch.cat([uncond_embeddings, cond_embeddings])
            noise_pred = solver.get_noise_pred(solver.model, latent_cur, t, guidance_scale, context)
            latent_cur = solver.prev_step(noise_pred, t, latent_cur)
    bar.close()
    return uncond_embeddings_list


def invert(solver,
           stop_step,
           is_cons_inversion=False,
           inv_guidance_scale=1,
           nti_guidance_scale=8,
           dynamic_guidance=False,
           tau1=0.4,
           tau2=0.6,
           w_embed_dim=0,
           image_path=None,
           prompt='',
           offsets=(0, 0, 0, 0),
           do_nti=False,
           do_npi=False,
           num_inner_steps=10,
           early_stop_epsilon=1e-5,
           seed=0):
    """Invert an image into the diffusion latent space with either consistency
    (cons) inversion or plain DDIM inversion, optionally followed by null-text
    inversion (NTI) or negative-prompt inversion (NPI)."""
    solver.init_prompt(prompt)
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    register_attention_control(solver.model, None)
    # Accept a list of paths, a single path, or a raw image array.
    if isinstance(image_path, list):
        image_gt = [load_512(path, *offsets) for path in image_path]
    elif isinstance(image_path, str):
        image_gt = load_512(image_path, *offsets)
    else:
        image_gt = np.array(Image.fromarray(image_path).resize((512, 512)))

    if is_cons_inversion:
        image_rec, ddim_latents = solver.cons_inversion(image_gt,
                                                        w_embed_dim=w_embed_dim,
                                                        guidance_scale=inv_guidance_scale,
                                                        seed=seed)
    else:
        image_rec, ddim_latents = solver.ddim_inversion(image_gt,
                                                        n_steps=stop_step,
                                                        guidance_scale=inv_guidance_scale,
                                                        dynamic_guidance=dynamic_guidance,
                                                        tau1=tau1,
                                                        tau2=tau2,
                                                        w_embed_dim=w_embed_dim)
    if do_nti:
        print("Null-text optimization...")
        uncond_embeddings = null_optimization(solver, ddim_latents, nti_guidance_scale,
                                              num_inner_steps, early_stop_epsilon)
    elif do_npi:
        # Negative-prompt inversion: reuse the conditional embedding as the null embedding.
        uncond_embeddings = [cond_embeddings] * solver.n_steps
    else:
        uncond_embeddings = None
    return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
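

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how `invert` might be called. The solver construction is
# an assumption: this module only consumes a `solver` object exposing
# init_prompt / context / n_steps / ddim_inversion / cons_inversion /
# get_noise_pred(_single) / prev_step, and its actual class lives elsewhere in
# the repo. Adjust the names below to the real solver class and paths.
#
#   solver = Solver(model)                      # hypothetical constructor
#   (image_gt, image_rec), last_latent, uncond = invert(
#       solver,
#       stop_step=50,
#       image_path='examples/cat.jpg',          # hypothetical path
#       prompt='a photo of a cat',
#       do_nti=True,                            # refine null embeddings with NTI
#   )
#   # When do_nti=True, `uncond` is a per-step list of optimized null embeddings
#   # that can be fed back into guided sampling to reconstruct or edit the image.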