EDICT / edict_functions.py
bram-w
safety check
93d448d
import torch
from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer
from omegaconf import OmegaConf
import math
import imageio
from PIL import Image
import torchvision
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import time
import datetime
import torch
import sys
import os
from torchvision import datasets
import pickle
# StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
# Global precision switch. It selects which diffusers fork is imported below and
# which torch / numpy dtypes the rest of the module uses.
# NOTE(review): `my_half_diffusers` / `my_diffusers` are presumably half- and
# full-precision forks of diffusers shipped alongside this file -- confirm.
use_half_prec = True
if use_half_prec:
    from my_half_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_half_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_half_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
else:
    from my_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
# float16 when running in half precision, float64 otherwise (there is no float32 mode).
torch_dtype = torch.float16 if use_half_prec else torch.float64
np_dtype = np.float16 if use_half_prec else np.float64
import random
from tqdm.auto import tqdm
from torch import autocast
from difflib import SequenceMatcher
# Build our CLIP model (only its text encoder, `clip`, is used by the samplers below)
model_path_clip = "openai/clip-vit-large-patch14"
clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch_dtype)
clip = clip_model.text_model
# Getting our HF Auth token: prefer the `auth_token` environment variable and
# fall back to the first line of a local `hf_auth` file.
auth_token = os.environ.get('auth_token')
if auth_token is None:
    with open('hf_auth', 'r') as f:
        auth_token = f.readlines()[0].strip()
model_path_diffusion = "CompVis/stable-diffusion-v1-4"
# Build our SD model (the fp16 weight revision is requested regardless of use_half_prec)
unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder="unet", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder="vae", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
# Push to the device; the models are cast to double precision only when
# half precision is disabled.
device = 'cuda'
if use_half_prec:
    unet.to(device)
    vae.to(device)
    clip.to(device)
else:
    unet.double().to(device)
    vae.double().to(device)
    clip.double().to(device)
print("Loaded all models")
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
# load safety model
# Both objects are created at import time and are used by check_safety().
# They are not moved to `device` here, so they stay wherever from_pretrained puts them.
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
def load_replacement(x):
    """
    Best-effort placeholder for a flagged image.

    Loads "assets/rick.jpeg", resizes it to the height/width of `x`, scales it
    to [0, 1] and casts it to `x.dtype`. If anything goes wrong (missing file,
    shape mismatch, ...) the input `x` is returned unchanged.
    """
    try:
        height, width = x.shape[0], x.shape[1]
        placeholder = Image.open("assets/rick.jpeg")
        placeholder = placeholder.convert("RGB")
        placeholder = placeholder.resize((width, height))
        replacement = (np.array(placeholder)/255.0).astype(x.dtype)
        assert replacement.shape == x.shape
    except Exception:
        # deliberate fallback: keep the caller's image when no replacement can be built
        return x
    return replacement
def _numpy_to_pil(images):
    """Convert a float image array in [0, 1], laid out (N, H, W, C) or (H, W, C), to a list of PIL images."""
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]


def check_safety(x_image):
    """
    Run the Stable Diffusion safety checker on a batch of images.

    Args:
        x_image: numpy image batch; the caller in this file passes values in
            [0, 1] with layout (N, H, W, C).
    Returns:
        (x_checked_image, has_nsfw_concept): the image batch as returned by the
        safety checker, with every flagged image zeroed out in place, and the
        per-image flags.
    """
    # The original called an undefined `numpy_to_pil`; the conversion now lives
    # in the private helper above so this function no longer raises NameError.
    safety_checker_input = safety_feature_extractor(_numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for i, flagged in enumerate(has_nsfw_concept):
        if flagged:
            # black out the flagged image (load_replacement is intentionally not used)
            x_checked_image[i] *= 0
    return x_checked_image, has_nsfw_concept
def EDICT_editing(im_path,
                  base_prompt,
                  edit_prompt,
                  use_p2p=False,
                  steps=50,
                  mix_weight=0.93,
                  init_image_strength=0.8,
                  guidance_scale=3,
                  run_baseline=False,
                  width=512, height=512):
    """
    Main call of our research, performs editing with either EDICT or DDIM

    Args:
        im_path: path to the image to run on, or an already loaded PIL image.
            A path is center-cropped and resized to 512x512. A PIL image is
            resized so its longest side is at most 1024 and its shortest side
            at least 512, and `width`/`height` are then taken from it (rounded
            down to a multiple of 64). Any other object is passed through as-is.
        base_prompt: conditional prompt to deterministically noise with
        edit_prompt: desired text conditioning
        use_p2p: if True, denoise with `base_prompt` as the prompt and
            `edit_prompt` as the prompt-to-prompt edit; otherwise denoise
            directly with `edit_prompt`.
        steps: ddim steps
        mix_weight: Weight of mixing layers.
            Higher means more consistent generations but divergence in inversion
            Lower means opposite
            This is fairly tuned and can get good results
        init_image_strength: Editing strength. Higher = more dramatic edit.
            Typically [0.6, 0.9] is good range.
            Definitely tunable per-image/maybe best results are at a different value
        guidance_scale: classifier-free guidance scale
            3 I've found is the best for both our method and basic DDIM inversion
            Higher can result in more distorted results
        run_baseline:
            VERY IMPORTANT
            False is EDICT (coupled pair of latents), True is the DDIM baseline
            (a single latent sequence, no mixing layers)
        width, height: working resolution when `im_path` is not a PIL image;
            overridden by the image size when it is.

    Output:
        PAIR of Images (list of two PIL images)
        If run_baseline=True then [0] will be the edit and [1] is decoded from
            the second latent, which the baseline never updates (the original)
        If run_baseline=False then they will be two nearly identical edited versions
    """
    # Resize/center crop to 512x512 (Can do higher res. if desired)
    if isinstance(im_path, str):
        orig_im = load_im_into_format_from_path(im_path)
    elif isinstance(im_path, Image.Image):
        # isinstance replaces the deprecated Image.isImageType check
        width, height = im_path.size

        # add max dim for sake of memory
        max_dim = max(width, height)
        if max_dim > 1024:
            factor = 1024 / max_dim
            width = int(width * factor)
            height = int(height * factor)
            im_path = im_path.resize((width, height))

        # upscale small images so the shortest side is at least 512
        min_dim = min(width, height)
        if min_dim < 512:
            factor = 512 / min_dim
            width = int(width * factor)
            height = int(height * factor)
            im_path = im_path.resize((width, height))

        # the model needs dimensions that are multiples of 64
        width = width - (width % 64)
        height = height - (height % 64)

        orig_im = im_path
    else:
        orig_im = im_path

    # compute latent pair (second one will be original latent if run_baseline=True)
    latents = coupled_stablediffusion(base_prompt,
                                      reverse=True,
                                      init_image=orig_im,
                                      init_image_strength=init_image_strength,
                                      steps=steps,
                                      mix_weight=mix_weight,
                                      guidance_scale=guidance_scale,
                                      run_baseline=run_baseline,
                                      width=width, height=height)
    # Denoise intermediate state with new conditioning
    gen = coupled_stablediffusion(edit_prompt if (not use_p2p) else base_prompt,
                                  None if (not use_p2p) else edit_prompt,
                                  fixed_starting_latent=latents,
                                  init_image_strength=init_image_strength,
                                  steps=steps,
                                  mix_weight=mix_weight,
                                  guidance_scale=guidance_scale,
                                  run_baseline=run_baseline,
                                  width=width, height=height)

    return gen
def img2img_editing(im_path,
                    edit_prompt,
                    steps=50,
                    init_image_strength=0.7,
                    guidance_scale=3):
    """
    Basic SDEdit/img2img: load the image at `im_path` (center-cropped, 512x512),
    then let baseline_stablediffusion noise it and denoise it with `edit_prompt`.
    Returns whatever baseline_stablediffusion returns.
    """
    source_image = load_im_into_format_from_path(im_path)
    sdedit_kwargs = dict(
        init_image_strength=init_image_strength,
        steps=steps,
        init_image=source_image,
        guidance_scale=guidance_scale,
    )
    return baseline_stablediffusion(edit_prompt, **sdedit_kwargs)
def center_crop(im):
    """Return the largest centered square crop of `im` (side = min(width, height))."""
    width, height = im.size
    side = min(width, height)
    # (left, top, right, bottom) of the centered square
    box = (
        (width - side)/2,
        (height - side)/2,
        (width + side)/2,
        (height + side)/2,
    )
    return im.crop(box)
def general_crop(im, target_w, target_h):
    """
    Return the centered crop of `im` with size (target_w, target_h).

    The previous box (target/2 .. size - target/2) produced a region of size
    (width - target_w, height - target_h), which matches the target only when
    the image is exactly twice the target size. The box is now centered on the
    image with the requested width and height.
    """
    width, height = im.size   # Get dimensions

    left = (width - target_w) / 2
    top = (height - target_h) / 2
    right = (width + target_w) / 2
    bottom = (height + target_h) / 2

    # Crop the center of the image
    return im.crop((left, top, right, bottom))
def load_im_into_format_from_path(im_path):
    """
    Open the image at `im_path`, center-crop it to a square and resize it to 512x512.

    The file is opened in a `with` block so its handle is released as soon as
    the cropped/resized copy has been produced, instead of staying open until
    garbage collection.
    """
    with Image.open(im_path) as im:
        return center_crop(im).resize((512, 512))
#### P2P STUFF ####
def init_attention_weights(weight_tuples):
    """
    Install per-token attention weights on the UNet's CrossAttention modules.

    `weight_tuples` is an iterable of (token_index, weight); indices outside
    [0, clip_tokenizer.model_max_length) are ignored and every other token keeps
    weight 1. Modules whose name contains "attn2" receive the weight vector,
    modules whose name contains "attn1" get `None`.
    """
    max_tokens = clip_tokenizer.model_max_length
    token_weights = torch.ones(max_tokens)

    for position, value in weight_tuples:
        if 0 <= position < max_tokens:
            token_weights[position] = value

    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in module_path:
            layer.last_attn_slice_weights = token_weights.to(device)
        if "attn1" in module_path:
            layer.last_attn_slice_weights = None
def init_attention_edit(tokens, tokens_edit):
    """
    Build the token mask / index map for a prompt-to-prompt edit and install
    them on the UNet's CrossAttention modules.

    The two tokenizations are diffed with SequenceMatcher. For every "equal"
    span, and every "replace" span of unchanged length, the edited positions are
    marked in the mask and mapped back to the matching positions of the original
    prompt. "attn2" modules receive mask and indices; "attn1" modules get `None`.
    """
    max_tokens = clip_tokenizer.model_max_length
    keep_mask = torch.zeros(max_tokens)
    source_positions = torch.arange(max_tokens, dtype=torch.long)
    remapped = torch.zeros(max_tokens, dtype=torch.long)

    base_ids = tokens.input_ids.numpy()[0]
    edit_ids = tokens_edit.input_ids.numpy()[0]

    opcodes = SequenceMatcher(None, base_ids, edit_ids).get_opcodes()
    for tag, a0, a1, b0, b1 in opcodes:
        if b0 >= max_tokens:
            continue
        same_length_replace = tag == "replace" and a1-a0 == b1-b0
        if tag == "equal" or same_length_replace:
            keep_mask[b0:b1] = 1
            remapped[b0:b1] = source_positions[a0:a1]

    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in module_path:
            layer.last_attn_slice_mask = keep_mask.to(device)
            layer.last_attn_slice_indices = remapped.to(device)
        if "attn1" in module_path:
            layer.last_attn_slice_mask = None
            layer.last_attn_slice_indices = None
def init_attention_func():
    """
    Replace `_attention` on every CrossAttention module of the UNet with a
    version that can save, reuse and re-weight attention maps, and reset the
    per-module control flags (`last_attn_slice`, `use_last_attn_slice`,
    `use_last_attn_weights`, `save_last_attn_slice`).
    """
    def new_attention(self, query, key, value, sequence_length, dim):
        n_rows = query.shape[0]
        hidden_states = torch.zeros(
            (n_rows, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        # without an explicit slice size the whole batch is processed in one slice
        slice_size = hidden_states.shape[0] if self._slice_size is None else self._slice_size
        n_slices = hidden_states.shape[0] // slice_size
        for slice_no in range(n_slices):
            start_idx = slice_no * slice_size
            end_idx = (slice_no + 1) * slice_size
            q_part = query[start_idx:end_idx]
            k_part = key[start_idx:end_idx]
            attn_slice = (
                torch.einsum("b i d, b j d -> b i j", q_part, k_part) * self.scale
            )
            attn_slice = attn_slice.softmax(dim=-1)

            # Each flag below is one-shot: it is cleared as soon as it has been honoured.
            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is None:
                    # no mask: reuse the saved attention wholesale
                    attn_slice = self.last_attn_slice
                else:
                    # blend: masked token positions come from the saved (re-indexed) attention
                    new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
                self.use_last_attn_slice = False

            if self.save_last_attn_slice:
                self.last_attn_slice = attn_slice
                self.save_last_attn_slice = False

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                attn_slice = attn_slice * self.last_attn_slice_weights
                self.use_last_attn_weights = False

            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
            hidden_states[start_idx:end_idx] = attn_slice

        # reshape hidden_states
        return self.reshape_batch_dim_to_heads(hidden_states)

    for _, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        layer.last_attn_slice = None
        layer.use_last_attn_slice = False
        layer.use_last_attn_weights = False
        layer.save_last_attn_slice = False
        # bind the function as a method of this particular module instance
        layer._attention = new_attention.__get__(layer, type(layer))
def use_last_tokens_attention(use=True):
    """Set `use_last_attn_slice` to `use` on every CrossAttention module whose name contains "attn2"."""
    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in module_path:
            layer.use_last_attn_slice = use
def use_last_tokens_attention_weights(use=True):
    """Set `use_last_attn_weights` to `use` on every CrossAttention module whose name contains "attn2"."""
    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in module_path:
            layer.use_last_attn_weights = use
def use_last_self_attention(use=True):
    """Set `use_last_attn_slice` to `use` on every CrossAttention module whose name contains "attn1"."""
    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn1" in module_path:
            layer.use_last_attn_slice = use
def save_last_tokens_attention(save=True):
    """Set `save_last_attn_slice` to `save` on every CrossAttention module whose name contains "attn2"."""
    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in module_path:
            layer.save_last_attn_slice = save
def save_last_self_attention(save=True):
    """Set `save_last_attn_slice` to `save` on every CrossAttention module whose name contains "attn1"."""
    for module_path, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn1" in module_path:
            layer.save_last_attn_slice = save
####################################
##### BASELINE ALGORITHM, ONLY USED NOW FOR SDEDIT #####
@torch.no_grad()
def baseline_stablediffusion(prompt="",
                             prompt_edit=None,
                             null_prompt='',
                             prompt_edit_token_weights=[],
                             prompt_edit_tokens_start=0.0,
                             prompt_edit_tokens_end=1.0,
                             prompt_edit_spatial_start=0.0,
                             prompt_edit_spatial_end=1.0,
                             clip_start=0.0,
                             clip_end=1.0,
                             guidance_scale=7,
                             steps=50,
                             seed=1,
                             width=512, height=512,
                             init_image=None, init_image_strength=0.5,
                             fixed_starting_latent = None,
                             prev_image= None,
                             grid=None,
                             clip_guidance=None,
                             clip_guidance_scale=1,
                             num_cutouts=4,
                             cut_power=1,
                             scheduler_str='lms',
                             return_latent=False,
                             one_pass=False,
                             normalize_noise_pred=False):
    """
    Baseline Stable Diffusion sampler (txt2img / img2img / prompt-to-prompt).

    Args:
        prompt: conditional text prompt.
        prompt_edit: optional edited prompt. When given, the attention maps of
            the `prompt` pass are saved and reused while predicting noise for
            `prompt_edit` (prompt-to-prompt).
        null_prompt: text used for the unconditional branch of guidance.
        prompt_edit_token_weights: iterable of (token_index, weight) handed to
            init_attention_weights. The default list is never mutated here.
        prompt_edit_tokens_start / prompt_edit_tokens_end: range of
            t / num_train_timesteps in which saved "attn2" attention is reused.
        prompt_edit_spatial_start / prompt_edit_spatial_end: same range for
            the saved "attn1" attention.
        clip_start / clip_end: range of t / num_train_timesteps in which CLIP
            guidance is applied (only if `clip_guidance` is not None).
        guidance_scale: classifier-free guidance scale; must be an int.
        steps: number of scheduler inference steps.
        seed: RNG seed; None picks a random one.
        width, height: output size, rounded down to a multiple of 64.
        init_image: optional image with a PIL-style `.resize` for img2img.
        init_image_strength: fraction of `steps` that is actually run when
            starting from `init_image` or `fixed_starting_latent`.
        fixed_starting_latent: optional latent to start from instead of noise.
        prev_image: experimental second image; the sampling loop raises
            NotImplementedError when it is supplied.
        grid: unused.
        clip_guidance, clip_guidance_scale, num_cutouts, cut_power: forwarded
            to `new_cond_fn`. NOTE(review): `new_cond_fn` is not defined in the
            part of the file visible here -- confirm it exists before using
            CLIP guidance.
        scheduler_str: one of 'ddim', 'lms', 'pndm', 'ddpm'.
        return_latent: return the final latent instead of a decoded image.
        one_pass: stop after the first sampling step.
        normalize_noise_pred: rescale the guided noise prediction to the norm
            of the unconditional prediction.

    Returns:
        The final latent if `return_latent`, otherwise the decoded first image
        of the batch as a PIL image (after the safety check).

    Raises:
        NotImplementedError: for scheduler_str='ddim' combined with
            `init_image`, and whenever `prev_image` is supplied.
    """
    width = width - width % 64
    height = height - height % 64

    #If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
    # NOTE(review): torch.cuda.manual_seed seeds the device's global RNG; check
    # what it returns before relying on `generator` being a torch.Generator.
    generator = torch.cuda.manual_seed(seed)

    #Set inference timesteps to scheduler
    scheduler_dict = {'ddim':DDIMScheduler,
                      'lms':LMSDiscreteScheduler,
                      'pndm':PNDMScheduler,
                      'ddpm':DDPMScheduler}
    scheduler_call = scheduler_dict[scheduler_str]
    if scheduler_str == 'ddim':
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                  beta_schedule="scaled_linear",
                                  clip_sample=False, set_alpha_to_one=False)
    else:
        scheduler = scheduler_call(beta_schedule="scaled_linear",
                                   num_train_timesteps=1000)

    scheduler.set_timesteps(steps)
    if prev_image is not None:
        prev_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
                                              beta_end=0.012,
                                              beta_schedule="scaled_linear",
                                              num_train_timesteps=1000)
        prev_scheduler.set_timesteps(steps)

    #Preprocess image if it exists (img2img)
    if init_image is not None:
        # numpy b h w c in [-1, 1] -> torch b c h w
        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        init_image = np.array(init_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if init_image.shape[1] > 3:
            init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])

        #Move image to GPU
        init_image = init_image.to(device)

        #Encode image
        with autocast(device):
            init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215

        t_start = steps - int(steps * init_image_strength)
    else:
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_start = 0

    #Generate random normal noise
    if fixed_starting_latent is None:
        noise = torch.randn(init_latent.shape, generator=generator, device=device, dtype=unet.dtype)
        if scheduler_str == 'ddim':
            if init_image is not None:
                # Was `raise notImplementedError` (a NameError); the intended
                # exception is raised now.
                raise NotImplementedError("init_image is not supported with the 'ddim' scheduler")
            else:
                latent = noise
        else:
            latent = scheduler.add_noise(init_latent, noise,
                                         t_start).to(device)
    else:
        latent = fixed_starting_latent
        t_start = steps - int(steps * init_image_strength)

    if prev_image is not None:
        #Resize and prev_image for numpy b h w c -> torch b c h w
        prev_image = prev_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        prev_image = np.array(prev_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
        prev_image = torch.from_numpy(prev_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if prev_image.shape[1] > 3:
            prev_image = prev_image[:, :3] * prev_image[:, 3:] + (1 - prev_image[:, 3:])

        #Move image to GPU
        prev_image = prev_image.to(device)

        #Encode image
        with autocast(device):
            prev_init_latent = vae.encode(prev_image).latent_dist.sample(generator=generator) * 0.18215

        t_start = steps - int(steps * init_image_strength)

        prev_latent = prev_scheduler.add_noise(prev_init_latent, noise, t_start).to(device)
    else:
        prev_latent = None

    #Process clip
    with autocast(device):
        tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
        embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state

        tokens_conditional = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
        embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state

        #Process prompt editing
        assert not ((prompt_edit is not None) and (prev_image is not None))
        if prompt_edit is not None:
            tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
            embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state
            init_attention_edit(tokens_conditional, tokens_conditional_edit)
        elif prev_image is not None:
            init_attention_edit(tokens_conditional, tokens_conditional)

        init_attention_func()
        init_attention_weights(prompt_edit_token_weights)

        timesteps = scheduler.timesteps[t_start:]

        assert isinstance(guidance_scale, int)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_index = t_start + i

            latent_model_input = latent
            if scheduler_str=='lms':
                sigma = scheduler.sigmas[t_index] # last is first and first is last
                latent_model_input = (latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
            else:
                assert scheduler_str in ['ddim', 'pndm', 'ddpm']

            #Predict the unconditional noise residual
            if len(t.shape) == 0:
                t = t[None].to(unet.device)
            noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional,
                                     ).sample

            if prev_latent is not None:
                prev_latent_model_input = prev_latent
                prev_latent_model_input = (prev_latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
                prev_noise_pred_uncond = unet(prev_latent_model_input, t,
                                              encoder_hidden_states=embedding_unconditional,
                                              ).sample

            #Prepare the Cross-Attention layers
            if prompt_edit is not None or prev_latent is not None:
                save_last_tokens_attention()
                save_last_self_attention()
            else:
                #Use weights on non-edited prompt when edit is None
                use_last_tokens_attention_weights()

            #Predict the conditional noise residual and save the cross-attention layer activations
            if prev_latent is not None:
                raise NotImplementedError # I totally lost track of what this is
                # unreachable; kept to record what the prev_image path was meant to compute
                prev_noise_pred_cond = unet(prev_latent_model_input, t, encoder_hidden_states=embedding_conditional,
                                            ).sample
            else:
                noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional,
                                       ).sample

            #Edit the Cross-Attention layer activations
            t_scale = t / scheduler.num_train_timesteps
            if prompt_edit is not None or prev_latent is not None:
                if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
                    use_last_tokens_attention()
                if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
                    use_last_self_attention()

                #Use weights on edited prompt
                use_last_tokens_attention_weights()

                #Predict the edited conditional noise residual using the cross-attention masks
                if prompt_edit is not None:
                    noise_pred_cond = unet(latent_model_input, t,
                                           encoder_hidden_states=embedding_conditional_edit).sample

            #Perform guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            if clip_guidance is not None and t_scale >= clip_start and t_scale <= clip_end:
                noise_pred, latent = new_cond_fn(latent, t, t_index,
                                                 embedding_conditional, noise_pred,clip_guidance,
                                                 clip_guidance_scale,
                                                 num_cutouts,
                                                 scheduler, unet,use_cutouts=True,
                                                 cut_power=cut_power)
            if normalize_noise_pred:
                noise_pred = noise_pred * noise_pred_uncond.norm() / noise_pred.norm()

            if scheduler_str == 'ddim':
                # forward_step returns the new sample tensor directly, so there
                # is no `.prev_sample` attribute to read (it used to raise
                # AttributeError here).
                latent = forward_step(scheduler, noise_pred,
                                      t,
                                      latent)
            else:
                latent = scheduler.step(noise_pred,
                                        t_index,
                                        latent).prev_sample

            if prev_latent is not None:
                prev_noise_pred = prev_noise_pred_uncond + guidance_scale * (prev_noise_pred_cond - prev_noise_pred_uncond)
                prev_latent = prev_scheduler.step(prev_noise_pred, t_index, prev_latent).prev_sample
            if one_pass: break

        #scale and decode the image latents with vae
        if return_latent: return latent
        latent = latent / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample

    # [-1, 1] b c h w tensor -> [0, 1] b h w c numpy -> uint8 PIL image
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()

    image, _ = check_safety(image)

    image = (image[0] * 255).round().astype("uint8")
    return Image.fromarray(image)
####################################
#### HELPER FUNCTIONS FOR OUR METHOD #####
def get_alpha_and_beta(t, scheduler):
    """
    Return (alpha_cumprod, 1 - alpha_cumprod) for timestep `t`.

    Args:
        t: torch tensor holding one timestep. Integer (long) timesteps index
            `scheduler.alphas_cumprod` directly; negative timesteps map to
            `scheduler.final_alpha_cumprod`; fractional timesteps are linearly
            interpolated between the two neighbouring integer timesteps.
        scheduler: object exposing `alphas_cumprod` and `final_alpha_cumprod`.
    """
    # want to run this for both current and previous timestep
    if t.dtype==torch.long:
        alpha = scheduler.alphas_cumprod[t]
        return alpha, 1-alpha

    if t<0:
        return scheduler.final_alpha_cumprod, 1 - scheduler.final_alpha_cumprod

    low = t.floor().long()
    high = t.ceil().long()
    rem = t - low

    low_alpha = scheduler.alphas_cumprod[low]
    high_alpha = scheduler.alphas_cumprod[high]
    # `rem` is the distance from `low`, so it is the weight of the *high*
    # neighbour (the weights were previously swapped).
    interpolated_alpha = low_alpha * (1-rem) + high_alpha * rem
    interpolated_beta = 1 - interpolated_alpha
    return interpolated_alpha, interpolated_beta
# A DDIM forward step function
def forward_step(
    self,
    model_output,
    timestep: int,
    sample,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    return_dict: bool = True,
    use_double=False,
) :
    """
    One DDIM step from `timestep` to the previous timestep, written so that
    reverse_step is its exact algebraic inverse.

    `self` is the scheduler; only `model_output`, `timestep` and `sample` are
    used, the remaining parameters are accepted and ignored. Returns the new
    sample directly (not a scheduler output object).
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # spacing between inference timesteps; may be fractional
    step_size = self.config.num_train_timesteps / self.num_inference_steps
    prev_timestep = timestep - step_size

    if timestep > self.timesteps.max():
        raise NotImplementedError("Need to double check what the overflow is")

    alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
    alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)

    alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)
    inverse_quotient = (1./alpha_quotient)

    rescaled_sample = inverse_quotient * sample
    removed_noise = inverse_quotient * (beta_prod_t ** 0.5) * model_output
    added_noise = ((1 - alpha_prod_t_prev)**0.5) * model_output
    return rescaled_sample - removed_noise + added_noise
# A DDIM reverse step function, the inverse of above
def reverse_step(
    self,
    model_output,
    timestep: int,
    sample,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    return_dict: bool = True,
    use_double=False,
) :
    """
    Algebraic inverse of forward_step: maps a sample at the previous timestep
    back to `timestep`, given the same `model_output`.

    `self` is the scheduler; only `model_output`, `timestep` and `sample` are
    used, the remaining parameters are accepted and ignored. Returns the new
    sample directly (not a scheduler output object).
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps

    if timestep > self.timesteps.max():
        raise NotImplementedError("Need to double check what the overflow is")

    # A direct `self.alphas_cumprod[timestep]` lookup used to sit here; its
    # result was overwritten by the next line, so it was dropped.
    alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
    alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)

    alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)

    first_term = alpha_quotient * sample
    second_term = ((beta_prod_t)**0.5) * model_output
    third_term = alpha_quotient * ((1 - alpha_prod_t_prev)**0.5) * model_output
    return first_term + second_term - third_term
@torch.no_grad()
def latent_to_image(latent):
    """Decode a scaled VAE latent (undoing the 0.18215 factor) and return it as a PIL image."""
    unscaled = latent.to(vae.dtype)/0.18215
    decoded = vae.decode(unscaled).sample
    return prep_image_for_return(decoded)
def prep_image_for_return(image):
    """
    Convert a decoded image batch to a PIL image.

    Values are mapped from [-1, 1] to [0, 1] and clamped, the tensor is moved
    to the CPU and permuted with (0, 2, 3, 1), and only the first image of the
    batch is converted to uint8 and returned.
    """
    unit_range = (image / 2 + 0.5).clamp(0, 1)
    channels_last = unit_range.cpu().permute(0, 2, 3, 1).numpy()
    first_uint8 = (channels_last[0] * 255).round().astype("uint8")
    return Image.fromarray(first_uint8)
#############################
##### MAIN EDICT FUNCTION #######
# Use EDICT_editing to perform calls
@torch.no_grad()
def coupled_stablediffusion(prompt="",
                            prompt_edit=None,
                            null_prompt='',
                            prompt_edit_token_weights=[],
                            prompt_edit_tokens_start=0.0,
                            prompt_edit_tokens_end=1.0,
                            prompt_edit_spatial_start=0.0,
                            prompt_edit_spatial_end=1.0,
                            guidance_scale=7.0, steps=50,
                            seed=1, width=512, height=512,
                            init_image=None, init_image_strength=1.0,
                            run_baseline=False,
                            use_lms=False,
                            leapfrog_steps=True,
                            reverse=False,
                            return_latents=False,
                            fixed_starting_latent=None,
                            beta_schedule='scaled_linear',
                            mix_weight=0.93):
    """
    Coupled (two-latent) DDIM sampling or inversion.

    Two modes, selected by `reverse`:
      * reverse=True: `init_image` (an image, a latent tensor, or a list of
        two of either) is encoded and stepped through the flipped timesteps
        with reverse_step. The resulting pair of latents is returned.
      * reverse=False: sampling starts from noise, or from
        `fixed_starting_latent` (a latent or a list of two), and uses
        forward_step. Both latents are decoded and returned as a list of two
        PIL images, or as the latent pair when `return_latents` is True.

    With run_baseline=True only latent 0 is updated, no mixing layers are
    applied and leapfrog alternation is disabled; latent 1 is left untouched.

    Other arguments:
        prompt / prompt_edit / null_prompt: conditional, optional
            prompt-to-prompt edit, and unconditional prompts.
        prompt_edit_token_weights: (token_index, weight) pairs for
            init_attention_weights. The default list is never mutated here.
        prompt_edit_*_start / _end: ranges of t / num_train_timesteps in which
            saved attention is reused (only when `prompt_edit` is given).
        guidance_scale: classifier-free guidance scale.
        steps: DDIM steps; steps == 0 short-circuits to encoding `init_image`
            or decoding the starting latent.
        seed: RNG seed; None picks a random one.
        width, height: rounded down to a multiple of 64.
        init_image_strength: fraction of `steps` that is run when an image or
            a fixed starting latent is supplied.
        use_lms: must be False (asserted).
        leapfrog_steps: alternate which latent is updated first on each step.
        beta_schedule: forwarded to DDIMScheduler.
        mix_weight: weight of the mixing layers between the two latents.
    """
    #If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
    # NOTE(review): torch.cuda.manual_seed seeds the device's global RNG; check
    # what it returns before relying on `generator` being a torch.Generator.
    generator = torch.cuda.manual_seed(seed)

    def image_to_latent(im):
        # Encode a PIL-style image into a scaled VAE latent; tensors pass through as latents.
        if isinstance(im, torch.Tensor):
            # assume it's the latent
            # used to avoid clipping new generation before inversion
            init_latent = im.to(device)
        else:
            #Resize and transpose for numpy b h w c -> torch b c h w
            im = im.resize((width, height), resample=Image.Resampling.LANCZOS)
            im = np.array(im).astype(np_dtype) / 255.0 * 2.0 - 1.0
            # check if black and white
            if len(im.shape) < 3:
                im = np.stack([im for _ in range(3)], axis=2) # putting at end b/c channels

            im = torch.from_numpy(im[np.newaxis, ...].transpose(0, 3, 1, 2))

            #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
            if im.shape[1] > 3:
                im = im[:, :3] * im[:, 3:] + (1 - im[:, 3:])

            #Move image to GPU
            im = im.to(device)
            #Encode image
            if use_half_prec:
                init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
            else:
                with autocast(device):
                    init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
        return init_latent

    assert not use_lms, "Can't invert LMS the same as DDIM"
    if run_baseline: leapfrog_steps=False
    #Change size to multiple of 64 to prevent size mismatches inside model
    width = width - width % 64
    height = height - height % 64

    #Preprocess image if it exists (img2img)
    if init_image is not None:
        assert reverse # want to be performing deterministic noising
        # can take either pair (output of generative process) or single image
        if isinstance(init_image, list):
            if isinstance(init_image[0], torch.Tensor):
                init_latent = [t.clone() for t in init_image]
            else:
                init_latent = [image_to_latent(im) for im in init_image]
        else:
            init_latent = image_to_latent(init_image)
        # this is t_start for forward, t_end for reverse
        t_limit = steps - int(steps * init_image_strength)
    else:
        assert not reverse, 'Need image to reverse from'
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_limit = 0

    if reverse:
        latent = init_latent
    else:
        #Generate random normal noise
        noise = torch.randn(init_latent.shape,
                            generator=generator,
                            device=device,
                            dtype=torch_dtype)
        if fixed_starting_latent is None:
            latent = noise
        else:
            if isinstance(fixed_starting_latent, list):
                latent = [l.clone() for l in fixed_starting_latent]
            else:
                latent = fixed_starting_latent.clone()
            t_limit = steps - int(steps * init_image_strength)
    if isinstance(latent, list): # initializing from pair of images
        latent_pair = latent
    else: # initializing from noise
        latent_pair = [latent.clone(), latent.clone()]

    # Degenerate case: nothing to sample, just encode or decode once.
    if steps==0:
        if init_image is not None:
            return image_to_latent(init_image)
        else:
            image = vae.decode(latent.to(vae.dtype) / 0.18215).sample
            return prep_image_for_return(image)

    #Set inference timesteps to scheduler
    # one scheduler per latent sequence
    schedulers = []
    for i in range(2):
        # num_raw_timesteps = max(1000, steps)
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                  beta_schedule=beta_schedule,
                                  num_train_timesteps=1000,
                                  clip_sample=False,
                                  set_alpha_to_one=False)
        scheduler.set_timesteps(steps)
        schedulers.append(scheduler)

    with autocast(device):
        # CLIP Text Embeddings
        tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length",
                                              max_length=clip_tokenizer.model_max_length,
                                              truncation=True, return_tensors="pt",
                                              return_overflowing_tokens=True)
        embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state

        tokens_conditional = clip_tokenizer(prompt, padding="max_length",
                                            max_length=clip_tokenizer.model_max_length,
                                            truncation=True, return_tensors="pt",
                                            return_overflowing_tokens=True)
        embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state

        #Process prompt editing (if running Prompt-to-Prompt)
        if prompt_edit is not None:
            tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length",
                                                     max_length=clip_tokenizer.model_max_length,
                                                     truncation=True, return_tensors="pt",
                                                     return_overflowing_tokens=True)
            embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state

            init_attention_edit(tokens_conditional, tokens_conditional_edit)

        init_attention_func()
        init_attention_weights(prompt_edit_token_weights)

        timesteps = schedulers[0].timesteps[t_limit:]
        # inversion walks the same timesteps in the opposite order
        if reverse: timesteps = timesteps.flip(0)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_scale = t / schedulers[0].num_train_timesteps

            if (reverse) and (not run_baseline):
                # Reverse mixing layer
                # Undoes the generative mixing layer at the bottom of this loop,
                # in the opposite order: latent 1 first, then latent 0 using
                # the already un-mixed latent 1.
                new_latents = [l.clone() for l in latent_pair]
                new_latents[1] = (new_latents[1].clone() - (1-mix_weight)*new_latents[0].clone()) / mix_weight
                new_latents[0] = (new_latents[0].clone() - (1-mix_weight)*new_latents[1].clone()) / mix_weight
                latent_pair = new_latents

            # alternate EDICT steps
            for latent_i in range(2):
                if run_baseline and latent_i==1: continue # just have one sequence for baseline
                # this modifies latent_pair[i] while using
                # latent_pair[(i+1)%2]
                if reverse and (not run_baseline):
                    if leapfrog_steps:
                        # what i would be from going other way
                        orig_i = len(timesteps) - (i+1)
                        offset = (orig_i+1) % 2
                        latent_i = (latent_i + offset) % 2
                    else:
                        # Do 1 then 0
                        latent_i = (latent_i+1)%2
                else:
                    if leapfrog_steps:
                        offset = i%2
                        latent_i = (latent_i + offset) % 2

                # The noise prediction is made on the *other* latent
                # (the same one in the baseline).
                latent_j = ((latent_i+1) % 2) if not run_baseline else latent_i

                latent_model_input = latent_pair[latent_j]
                latent_base = latent_pair[latent_i]

                #Predict the unconditional noise residual
                noise_pred_uncond = unet(latent_model_input, t,
                                         encoder_hidden_states=embedding_unconditional).sample

                #Prepare the Cross-Attention layers
                if prompt_edit is not None:
                    save_last_tokens_attention()
                    save_last_self_attention()
                else:
                    #Use weights on non-edited prompt when edit is None
                    use_last_tokens_attention_weights()

                #Predict the conditional noise residual and save the cross-attention layer activations
                noise_pred_cond = unet(latent_model_input, t,
                                       encoder_hidden_states=embedding_conditional).sample

                #Edit the Cross-Attention layer activations
                if prompt_edit is not None:
                    t_scale = t / schedulers[0].num_train_timesteps
                    if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
                        use_last_tokens_attention()
                    if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
                        use_last_self_attention()

                    #Use weights on edited prompt
                    use_last_tokens_attention_weights()

                    #Predict the edited conditional noise residual using the cross-attention masks
                    noise_pred_cond = unet(latent_model_input,
                                           t,
                                           encoder_hidden_states=embedding_conditional_edit).sample

                #Perform guidance
                grad = (noise_pred_cond - noise_pred_uncond)
                noise_pred = noise_pred_uncond + guidance_scale * grad

                # forward_step / reverse_step return the new sample tensor
                # directly, hence no `.prev_sample` access.
                step_call = reverse_step if reverse else forward_step
                new_latent = step_call(schedulers[latent_i],
                                       noise_pred,
                                       t,
                                       latent_base)# .prev_sample
                new_latent = new_latent.to(latent_base.dtype)

                latent_pair[latent_i] = new_latent

            if (not reverse) and (not run_baseline):
                # Mixing layer (contraction) during generative process
                # Sequential on purpose: the second line reads the already
                # mixed new_latents[0]; the reverse mixing layer above depends on that.
                new_latents = [l.clone() for l in latent_pair]
                new_latents[0] = (mix_weight*new_latents[0] + (1-mix_weight)*new_latents[1]).clone()
                new_latents[1] = ((1-mix_weight)*new_latents[0] + (mix_weight)*new_latents[1]).clone()
                latent_pair = new_latents

        #scale and decode the image latents with vae, can return latents instead of images
        if reverse or return_latents:
            # `results` always has one element, so this returns latent_pair itself
            results = [latent_pair]
            return results if len(results)>1 else results[0]

        # decode latents to images
        images = []
        for latent_i in range(2):
            latent = latent_pair[latent_i] / 0.18215
            image = vae.decode(latent.to(vae.dtype)).sample
            images.append(image)

    # Return images
    return_arr = []
    for image in images:
        image = prep_image_for_return(image)
        return_arr.append(image)
    # `results` always has one element, so this returns return_arr (list of two PIL images)
    results = [return_arr]
    return results if len(results)>1 else results[0]