# src/utils/edit_pipeline.py
import sys

import numpy as np
import torch
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

# local helpers: BasePipeline provides the standard Stable Diffusion pipeline machinery,
# prep_unet makes the unet expose its cross-attention probabilities (attn_probs)
sys.path.insert(0, "src/utils")
from base_pipeline import BasePipeline
from cross_attention import prep_unet

class EditingPipeline(BasePipeline):
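    """Stable Diffusion pipeline for pix2pix-zero style editing.

    A first denoising pass reconstructs the image from the input latent and caches
    the cross-attention maps of every `attn2` layer; a second pass adds `edit_dir`
    to the conditional text embedding and, at every step, nudges the latent so that
    its cross-attention maps stay close to the cached reference maps.
    """
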
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # pix2pix parameters
        guidance_amount=0.1,
        edit_dir=None,
        x_in=None,
    ):
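        """Generate a reconstruction and an edited image from the latent `x_in`.

        Besides the standard Stable Diffusion arguments, the pix2pix parameters are:
          * guidance_amount: SGD learning rate for the per-step cross-attention
            guidance applied to the latent in the second (editing) pass.
          * edit_dir: edit direction in text-embedding space; it is added to the
            conditional prompt embedding (index 1) before the second pass.
          * x_in: starting latent code, e.g. obtained by inverting the input image;
            it is forwarded to `prepare_latents`.

        Returns a tuple `(image_rec, image_edit)` of lists of PIL images.
        """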
        # move the input latent onto the unet's device and dtype
        x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)

        # 0. modify the unet so that it exposes its cross-attention probabilities
        self.unet = prep_unet(self.unet)

        # 1. setup all caching objects
        d_ref_t2attn = {}  # reference cross attention maps, keyed by timestep and layer name

        # 2. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # TODO: add the input checker function
        # self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds )
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        # 3. Encode input prompt = 2x77x1024
        prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        # randomly sample a latent code if not provided
        latents = self.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,)
        latents_init = latents.clone()

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 7. First denoising loop: reconstruct the image and cache the reference cross attention maps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with torch.no_grad():
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs,).sample

                    # add the cross attention maps of every attn2 layer to the dictionary
                    d_ref_t2attn[t.item()] = {}
                    for name, module in self.unet.named_modules():
                        module_name = type(module).__name__
                        if module_name == "CrossAttention" and 'attn2' in name:
                            attn_mask = module.attn_probs  # size is (num_channels, s*s, 77)
                            d_ref_t2attn[t.item()][name] = attn_mask.detach().cpu()

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                    # update the progress bar
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()

        # make the reference image (reconstruction)
        image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))
        # add the edit direction only to the conditional prompt embedding (index 1);
        # index 0 is the negative / unconditional prompt
        prompt_embeds_edit = prompt_embeds.clone()
        prompt_embeds_edit[1:2] += edit_dir

        # Second denoising loop: edit the image while staying close to the reference attention maps
        latents = latents_init
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # make the current latent a leaf tensor that can be optimized directly
                x_in = latent_model_input.detach().clone()
                x_in.requires_grad = True
                opt = torch.optim.SGD([x_in], lr=guidance_amount)

                # predict the noise residual
                noise_pred = self.unet(x_in, t, encoder_hidden_states=prompt_embeds_edit.detach(), cross_attention_kwargs=cross_attention_kwargs,).sample

                # cross attention guidance: penalize deviation from the reference attention maps
                loss = 0.0
                for name, module in self.unet.named_modules():
                    module_name = type(module).__name__
                    if module_name == "CrossAttention" and 'attn2' in name:
                        curr = module.attn_probs  # size is (num_channels, s*s, 77)
                        ref = d_ref_t2attn[t.item()][name].detach().to(device)
                        loss += ((curr - ref) ** 2).sum((1, 2)).mean(0)
                loss.backward(retain_graph=False)
                opt.step()

                # recompute the noise with the updated latent
                with torch.no_grad():
                    noise_pred = self.unet(x_in.detach(), t, encoder_hidden_states=prompt_embeds_edit, cross_attention_kwargs=cross_attention_kwargs,).sample

                # both halves of the CFG batch hold the same latent, so keep the first chunk
                latents = x_in.detach().chunk(2)[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
        # 8. Post-processing
        image = self.decode_latents(latents.detach())

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 10. Convert to PIL and return both the reconstruction and the edited image
        image_edit = self.numpy_to_pil(image)
        return image_rec, image_edit
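

# Minimal usage sketch (illustrative only, not part of the pipeline): it assumes a
# Stable Diffusion checkpoint compatible with this repo, an inverted latent `x_inv`
# (e.g. produced by the repo's inversion utilities) and a precomputed edit direction
# `edit_dir` in text-embedding space; `model_path`, `x_inv` and `edit_dir` are
# placeholder names.
#
#   pipe = EditingPipeline.from_pretrained(model_path).to("cuda")
#   image_rec, image_edit = pipe(
#       prompt="a photo of a cat",
#       num_inference_steps=50,
#       guidance_scale=7.5,
#       guidance_amount=0.1,
#       edit_dir=edit_dir,
#       x_in=x_inv,
#   )
#   image_edit[0].save("edit.png")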