# iCD-image-editing / inversion.py
# (Author: dbaranchuk — "Main update", commit bf363c0. The original scrape
#  included HuggingFace page chrome here; converted to a comment header so
#  the module parses as Python.)
import torch.nn.functional as nnf
import torch
import numpy as np
from tqdm import tqdm
from torch.optim.adam import Adam
from PIL import Image
from generation import load_512
from p2p import register_attention_control
def null_optimization(solver,
                      latents,
                      guidance_scale,
                      num_inner_steps,
                      epsilon):
    """Null-text inversion: optimize the unconditional ("null") text embedding
    per diffusion timestep so that classifier-free-guided sampling reproduces
    the inversion latent trajectory.

    Args:
        solver: inversion helper exposing ``context`` (concatenated
            uncond/cond embeddings), ``n_steps``, ``model.scheduler.timesteps``,
            ``get_noise_pred_single``, ``get_noise_pred`` and ``prev_step``.
        latents: latent trajectory from inversion; ``latents[-1]`` is the most
            noised latent, earlier entries are the targets to reconstruct.
        guidance_scale: classifier-free guidance weight used while optimizing.
        num_inner_steps: maximum Adam steps per diffusion timestep.
        epsilon: early-stop threshold on the reconstruction MSE; loosened by
            ``i * 2e-5`` at later timesteps.

    Returns:
        list of optimized null embeddings (first batch element only,
        detached), one per diffusion timestep.
    """
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    uncond_embeddings_list = []
    latent_cur = latents[-1]
    bar = tqdm(total=num_inner_steps * solver.n_steps)
    for i in range(solver.n_steps):
        # Re-detach so each timestep optimizes a fresh leaf tensor,
        # warm-started from the previous timestep's result.
        uncond_embeddings = uncond_embeddings.clone().detach()
        uncond_embeddings.requires_grad = True
        # Learning rate anneals linearly; assumes at most ~100 steps so the
        # factor stays positive — TODO confirm n_steps <= 100 upstream.
        optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
        latent_prev = latents[len(latents) - i - 2]
        t = solver.model.scheduler.timesteps[i]
        # The conditional branch does not depend on the optimized embedding,
        # so compute it once outside the inner loop.
        with torch.no_grad():
            noise_pred_cond = solver.get_noise_pred_single(latent_cur, t, cond_embeddings)
        steps_done = 0
        for j in range(num_inner_steps):
            noise_pred_uncond = solver.get_noise_pred_single(latent_cur, t, uncond_embeddings)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents_prev_rec = solver.prev_step(noise_pred, t, latent_cur)
            loss = nnf.mse_loss(latents_prev_rec, latent_prev)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            bar.update()
            steps_done = j + 1
            if loss.item() < epsilon + i * 2e-5:
                break
        # Keep the bar total consistent after an early stop. (Bug fix: the
        # original re-used the loop variable `j` in a catch-up loop, which is
        # unbound — NameError — when num_inner_steps == 0.)
        bar.update(num_inner_steps - steps_done)
        uncond_embeddings_list.append(uncond_embeddings[:1].detach())
        # Advance to the next timestep's latent with the optimized null
        # embedding; no gradients needed here.
        with torch.no_grad():
            context = torch.cat([uncond_embeddings, cond_embeddings])
            noise_pred = solver.get_noise_pred(solver.model, latent_cur, t, guidance_scale, context)
            latent_cur = solver.prev_step(noise_pred, t, latent_cur)
    bar.close()
    return uncond_embeddings_list
def invert(solver,
           stop_step,
           is_cons_inversion=False,
           inv_guidance_scale=1,
           nti_guidance_scale=8,
           dynamic_guidance=False,
           tau1=0.4,
           tau2=0.6,
           w_embed_dim=0,
           image_path=None,
           prompt='',
           offsets=(0, 0, 0, 0),
           do_nti=False,
           do_npi=False,
           num_inner_steps=10,
           early_stop_epsilon=1e-5,
           seed=0,
           ):
    """Invert an image to the diffusion latent space under *prompt*.

    Runs either consistency inversion (``is_cons_inversion``) or DDIM
    inversion, then optionally refines the unconditional embeddings with
    null-text inversion (``do_nti``) or negative-prompt inversion
    (``do_npi``).

    Returns:
        ``((image_gt, image_rec), start_latent, uncond_embeddings)`` where
        ``uncond_embeddings`` is a per-step list (NTI/NPI) or ``None``.
    """
    solver.init_prompt(prompt)
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    register_attention_control(solver.model, None)

    # Resolve the ground-truth image: a list of paths, a single path, or a
    # raw array (resized to the model's 512x512 working resolution).
    if isinstance(image_path, list):
        image_gt = [load_512(p, *offsets) for p in image_path]
    elif isinstance(image_path, str):
        image_gt = load_512(image_path, *offsets)
    else:
        image_gt = np.array(Image.fromarray(image_path).resize((512, 512)))

    if is_cons_inversion:
        image_rec, ddim_latents = solver.cons_inversion(
            image_gt,
            w_embed_dim=w_embed_dim,
            guidance_scale=inv_guidance_scale,
            seed=seed,
        )
    else:
        image_rec, ddim_latents = solver.ddim_inversion(
            image_gt,
            n_steps=stop_step,
            guidance_scale=inv_guidance_scale,
            dynamic_guidance=dynamic_guidance,
            tau1=tau1,
            tau2=tau2,
            w_embed_dim=w_embed_dim,
        )

    if do_nti:
        # Null-text inversion: optimize per-step null embeddings.
        print("Null-text optimization...")
        uncond_embeddings = null_optimization(
            solver, ddim_latents, nti_guidance_scale, num_inner_steps, early_stop_epsilon
        )
    elif do_npi:
        # Negative-prompt inversion: reuse the conditional embedding as the
        # "null" embedding at every step, no optimization.
        uncond_embeddings = [cond_embeddings] * solver.n_steps
    else:
        uncond_embeddings = None

    return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings