|
import math
import os
import sys
import time
import datetime
import pickle

import numpy as np
import imageio
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets
from PIL import Image
from omegaconf import OmegaConf
from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer
|
|
|
|
|
|
|
|
|
# Half precision loads the fp16 fork of the local diffusers code; full precision runs in float64.
use_half_prec = True
if use_half_prec:
    from my_half_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_half_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_half_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
else:
    from my_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler

torch_dtype = torch.float16 if use_half_prec else torch.float64
np_dtype = np.float16 if use_half_prec else np.float64
|
|
|
|
|
|
|
import random |
|
from tqdm.auto import tqdm |
|
from torch import autocast |
|
from difflib import SequenceMatcher |
|
|
|
|
|
model_path_clip = "openai/clip-vit-large-patch14"
clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch_dtype)
# Only the CLIP text encoder is needed for conditioning.
clip = clip_model.text_model
|
|
|
|
|
|
|
# Use the Hugging Face token from the environment if set, otherwise read it from a local 'hf_auth' file.
auth_token = os.environ.get('auth_token')
if auth_token is None:
    with open('hf_auth', 'r') as f:
        auth_token = f.readlines()[0].strip()
model_path_diffusion = "CompVis/stable-diffusion-v1-4"

# Load the fp16 revision of the UNet and VAE weights, cast to torch_dtype.
unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder="unet", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder="vae", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
|
|
|
|
|
device = 'cuda' |
|
if use_half_prec: |
|
unet.to(device) |
|
vae.to(device) |
|
clip.to(device) |
|
else: |
|
unet.double().to(device) |
|
vae.double().to(device) |
|
clip.double().to(device) |
|
print("Loaded all models") |
|
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker |
|
from transformers import AutoFeatureExtractor |
|
|
|
safety_model_id = "CompVis/stable-diffusion-safety-checker" |
|
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) |
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) |
|
def load_replacement(x): |
|
try: |
|
hwc = x.shape |
|
y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) |
|
y = (np.array(y)/255.0).astype(x.dtype) |
|
assert y.shape == x.shape |
|
return y |
|
except Exception: |
|
return x |
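
# check_safety below calls numpy_to_pil, which does not appear in this section.
# A minimal version matching the standard Stable Diffusion helper (an assumption,
# included so the call resolves):
def numpy_to_pil(images):
    # images: float array in [0, 1], shape (N, H, W, C) or (H, W, C)
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]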
|
def check_safety(x_image): |
|
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") |
|
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) |
|
assert x_checked_image.shape[0] == len(has_nsfw_concept) |
|
for i in range(len(has_nsfw_concept)): |
|
if has_nsfw_concept[i]: |
|
|
|
x_checked_image[i] *= 0 |
|
return x_checked_image, has_nsfw_concept |
|
|
|
|
|
def EDICT_editing(im_path, |
|
base_prompt, |
|
edit_prompt, |
|
use_p2p=False, |
|
steps=50, |
|
mix_weight=0.93, |
|
init_image_strength=0.8, |
|
guidance_scale=3, |
|
run_baseline=False, |
|
width=512, height=512): |
|
""" |
|
Main call of our research, performs editing with either EDICT or DDIM |
|
|
|
Args: |
|
im_path: path to image to run on |
|
base_prompt: conditional prompt to deterministically noise with |
|
edit_prompt: desired text conditoining |
|
steps: ddim steps |
|
mix_weight: Weight of mixing layers. |
|
Higher means more consistent generations but divergence in inversion |
|
Lower means opposite |
|
This is fairly tuned and can get good results |
|
init_image_strength: Editing strength. Higher = more dramatic edit. |
|
Typically [0.6, 0.9] is good range. |
|
Definitely tunable per-image/maybe best results are at a different value |
|
guidance_scale: classifier-free guidance scale |
|
3 I've found is the best for both our method and basic DDIM inversion |
|
Higher can result in more distorted results |
|
run_baseline: |
|
VERY IMPORTANT |
|
True is EDICT, False is DDIM |
|
Output: |
|
PAIR of Images (tuple) |
|
If run_baseline=True then [0] will be edit and [1] will be original |
|
If run_baseline=False then they will be two nearly identical edited versions |
|
""" |
|
|
|
    # Load/normalize the input image. Strings are treated as paths (center-cropped to
    # 512x512); PIL Images are rescaled so the long side is <= 1024, the short side is
    # >= 512, and the working dimensions are multiples of 64.
    if isinstance(im_path, str):
        orig_im = load_im_into_format_from_path(im_path)
    elif isinstance(im_path, Image.Image):
        width, height = im_path.size

        # Cap the long side at 1024
        max_dim = max(width, height)
        if max_dim > 1024:
            factor = 1024 / max_dim
            width = int(width * factor)
            height = int(height * factor)
            im_path = im_path.resize((width, height))

        # Ensure the short side is at least 512
        min_dim = min(width, height)
        if min_dim < 512:
            factor = 512 / min_dim
            width = int(width * factor)
            height = int(height * factor)
            im_path = im_path.resize((width, height))

        # Round working dimensions down to multiples of 64 for the UNet/VAE
        width = width - (width % 64)
        height = height - (height % 64)

        orig_im = im_path
    else:
        orig_im = im_path
|
|
|
|
|
    # Step 1: deterministically noise (invert) the image to latents under the base prompt.
    latents = coupled_stablediffusion(base_prompt,
                                      reverse=True,
                                      init_image=orig_im,
                                      init_image_strength=init_image_strength,
                                      steps=steps,
                                      mix_weight=mix_weight,
                                      guidance_scale=guidance_scale,
                                      run_baseline=run_baseline,
                                      width=width, height=height)

    # Step 2: regenerate from those latents under the edit prompt
    # (or with prompt-to-prompt attention control if use_p2p=True).
    gen = coupled_stablediffusion(edit_prompt if (not use_p2p) else base_prompt,
                                  None if (not use_p2p) else edit_prompt,
                                  fixed_starting_latent=latents,
                                  init_image_strength=init_image_strength,
                                  steps=steps,
                                  mix_weight=mix_weight,
                                  guidance_scale=guidance_scale,
                                  run_baseline=run_baseline,
                                  width=width, height=height)

    return gen
|
|
|
|
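# Illustrative usage (the image path and prompts below are placeholders, not files
# shipped with this script):
#
#   edit_a, edit_b = EDICT_editing('example_inputs/dog.jpg',
#                                  'A dog sitting on the grass',
#                                  'A cat sitting on the grass',
#                                  steps=50, init_image_strength=0.8)
#   # with run_baseline=False these are two nearly identical edited images
#   edit_a.save('dog_to_cat.png')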
|
def img2img_editing(im_path, |
|
edit_prompt, |
|
steps=50, |
|
init_image_strength=0.7, |
|
guidance_scale=3): |
|
""" |
|
Basic SDEdit/img2img, given an image add some noise and denoise with prompt |
|
""" |
|
orig_im = load_im_into_format_from_path(im_path) |
|
|
|
return baseline_stablediffusion(edit_prompt, |
|
init_image_strength=init_image_strength, |
|
steps=steps, |
|
init_image=orig_im, |
|
guidance_scale=guidance_scale) |
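
# Illustrative usage (placeholder path and prompt):
#   out = img2img_editing('example_inputs/dog.jpg', 'A watercolor painting of a dog')
#   out.save('dog_watercolor.png')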
|
|
|
|
|
def center_crop(im): |
|
width, height = im.size |
|
min_dim = min(width, height) |
|
left = (width - min_dim)/2 |
|
top = (height - min_dim)/2 |
|
right = (width + min_dim)/2 |
|
bottom = (height + min_dim)/2 |
|
|
|
|
|
im = im.crop((left, top, right, bottom)) |
|
return im |
|
|
|
|
|
|
|
def general_crop(im, target_w, target_h):
    # Note: this trims target_w/2 and target_h/2 pixels from each side
    # (i.e. it removes a border), rather than cropping to (target_w, target_h).
    width, height = im.size
    left = target_w / 2
    top = target_h / 2
    right = width - (target_w / 2)
    bottom = height - (target_h / 2)

    im = im.crop((left, top, right, bottom))
    return im
|
|
|
|
|
|
|
def load_im_into_format_from_path(im_path): |
|
return center_crop(Image.open(im_path)).resize((512,512)) |
|
|
|
|
|
|
|
def init_attention_weights(weight_tuples): |
|
tokens_length = clip_tokenizer.model_max_length |
|
weights = torch.ones(tokens_length) |
|
|
|
for i, w in weight_tuples: |
|
if i < tokens_length and i >= 0: |
|
weights[i] = w |
|
|
|
|
|
for name, module in unet.named_modules(): |
|
module_name = type(module).__name__ |
|
if module_name == "CrossAttention" and "attn2" in name: |
|
module.last_attn_slice_weights = weights.to(device) |
|
if module_name == "CrossAttention" and "attn1" in name: |
|
module.last_attn_slice_weights = None |
|
|
|
|
|
def init_attention_edit(tokens, tokens_edit): |
|
tokens_length = clip_tokenizer.model_max_length |
|
mask = torch.zeros(tokens_length) |
|
indices_target = torch.arange(tokens_length, dtype=torch.long) |
|
indices = torch.zeros(tokens_length, dtype=torch.long) |
|
|
|
tokens = tokens.input_ids.numpy()[0] |
|
tokens_edit = tokens_edit.input_ids.numpy()[0] |
|
|
|
for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes(): |
|
if b0 < tokens_length: |
|
if name == "equal" or (name == "replace" and a1-a0 == b1-b0): |
|
mask[b0:b1] = 1 |
|
indices[b0:b1] = indices_target[a0:a1] |
|
|
|
for name, module in unet.named_modules(): |
|
module_name = type(module).__name__ |
|
if module_name == "CrossAttention" and "attn2" in name: |
|
module.last_attn_slice_mask = mask.to(device) |
|
module.last_attn_slice_indices = indices.to(device) |
|
if module_name == "CrossAttention" and "attn1" in name: |
|
module.last_attn_slice_mask = None |
|
module.last_attn_slice_indices = None |
|
|
|
|
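# The helpers below implement prompt-to-prompt style attention control by
# monkey-patching the UNet's CrossAttention modules: init_attention_func replaces
# their _attention method with a version that can save its attention maps and later
# re-inject them, optionally remapped/masked via the indices and mask set by
# init_attention_edit and re-weighted via init_attention_weights.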
|
def init_attention_func(): |
|
def new_attention(self, query, key, value, sequence_length, dim): |
|
batch_size_attention = query.shape[0] |
|
hidden_states = torch.zeros( |
|
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype |
|
) |
|
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] |
|
for i in range(hidden_states.shape[0] // slice_size): |
|
start_idx = i * slice_size |
|
end_idx = (i + 1) * slice_size |
|
attn_slice = ( |
|
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale |
|
) |
|
attn_slice = attn_slice.softmax(dim=-1) |
|
|
|
if self.use_last_attn_slice: |
|
if self.last_attn_slice_mask is not None: |
|
new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) |
|
attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask |
|
else: |
|
attn_slice = self.last_attn_slice |
|
|
|
self.use_last_attn_slice = False |
|
|
|
if self.save_last_attn_slice: |
|
self.last_attn_slice = attn_slice |
|
self.save_last_attn_slice = False |
|
|
|
if self.use_last_attn_weights and self.last_attn_slice_weights is not None: |
|
attn_slice = attn_slice * self.last_attn_slice_weights |
|
self.use_last_attn_weights = False |
|
|
|
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) |
|
|
|
hidden_states[start_idx:end_idx] = attn_slice |
|
|
|
|
|
hidden_states = self.reshape_batch_dim_to_heads(hidden_states) |
|
return hidden_states |
|
|
|
for name, module in unet.named_modules(): |
|
module_name = type(module).__name__ |
|
if module_name == "CrossAttention": |
|
module.last_attn_slice = None |
|
module.use_last_attn_slice = False |
|
module.use_last_attn_weights = False |
|
module.save_last_attn_slice = False |
|
module._attention = new_attention.__get__(module, type(module)) |
|
|
|
def use_last_tokens_attention(use=True): |
|
for name, module in unet.named_modules(): |
|
module_name = type(module).__name__ |
|
if module_name == "CrossAttention" and "attn2" in name: |
|
module.use_last_attn_slice = use |
|
|
|
def use_last_tokens_attention_weights(use=True): |
|
for name, module in unet.named_modules(): |
|
module_name = type(module).__name__ |
|
if module_name == "CrossAttention" and "attn2" in name: |
|
module.use_last_attn_weights = use |
|
|
|
def use_last_self_attention(use=True): |
|
for name, module in unet.named_modules(): |
|
module_name = type(module).__name__ |
|
if module_name == "CrossAttention" and "attn1" in name: |
|
module.use_last_attn_slice = use |
|
|
|
def save_last_tokens_attention(save=True): |
|
for name, module in unet.named_modules(): |
|
module_name = type(module).__name__ |
|
if module_name == "CrossAttention" and "attn2" in name: |
|
module.save_last_attn_slice = save |
|
|
|
def save_last_self_attention(save=True): |
|
for name, module in unet.named_modules(): |
|
module_name = type(module).__name__ |
|
if module_name == "CrossAttention" and "attn1" in name: |
|
module.save_last_attn_slice = save |
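
# Typical flag protocol during sampling (see baseline_stablediffusion and
# coupled_stablediffusion): call save_last_*_attention() before the UNet pass
# conditioned on the base prompt, then use_last_*_attention() before the pass
# conditioned on the edit prompt so the saved attention maps are substituted in.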
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
def baseline_stablediffusion(prompt="", |
|
prompt_edit=None, |
|
null_prompt='', |
|
prompt_edit_token_weights=[], |
|
prompt_edit_tokens_start=0.0, |
|
prompt_edit_tokens_end=1.0, |
|
prompt_edit_spatial_start=0.0, |
|
prompt_edit_spatial_end=1.0, |
|
clip_start=0.0, |
|
clip_end=1.0, |
|
guidance_scale=7, |
|
steps=50, |
|
seed=1, |
|
width=512, height=512, |
|
init_image=None, init_image_strength=0.5, |
|
fixed_starting_latent = None, |
|
prev_image= None, |
|
grid=None, |
|
clip_guidance=None, |
|
clip_guidance_scale=1, |
|
num_cutouts=4, |
|
cut_power=1, |
|
scheduler_str='lms', |
|
return_latent=False, |
|
one_pass=False, |
|
normalize_noise_pred=False): |
|
width = width - width % 64 |
|
height = height - height % 64 |
|
|
|
|
|
if seed is None: seed = random.randrange(2**32 - 1) |
|
generator = torch.cuda.manual_seed(seed) |
|
|
|
|
|
scheduler_dict = {'ddim':DDIMScheduler, |
|
'lms':LMSDiscreteScheduler, |
|
'pndm':PNDMScheduler, |
|
'ddpm':DDPMScheduler} |
|
scheduler_call = scheduler_dict[scheduler_str] |
|
if scheduler_str == 'ddim': |
|
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, |
|
beta_schedule="scaled_linear", |
|
clip_sample=False, set_alpha_to_one=False) |
|
else: |
|
scheduler = scheduler_call(beta_schedule="scaled_linear", |
|
num_train_timesteps=1000) |
|
|
|
scheduler.set_timesteps(steps) |
|
if prev_image is not None: |
|
prev_scheduler = LMSDiscreteScheduler(beta_start=0.00085, |
|
beta_end=0.012, |
|
beta_schedule="scaled_linear", |
|
num_train_timesteps=1000) |
|
prev_scheduler.set_timesteps(steps) |
|
|
|
|
|
if init_image is not None: |
|
init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS) |
|
init_image = np.array(init_image).astype(np_dtype) / 255.0 * 2.0 - 1.0 |
|
init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2)) |
|
|
|
|
|
if init_image.shape[1] > 3: |
|
init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:]) |
|
|
|
|
|
init_image = init_image.to(device) |
|
|
|
|
|
with autocast(device): |
|
init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215 |
|
|
|
t_start = steps - int(steps * init_image_strength) |
|
|
|
else: |
|
init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device) |
|
t_start = 0 |
|
|
|
|
|
    if fixed_starting_latent is None:
        noise = torch.randn(init_latent.shape, generator=generator, device=device, dtype=unet.dtype)
        if scheduler_str == 'ddim':
            if init_image is not None:
                raise NotImplementedError("DDIM with an init image is not supported here")
            else:
                latent = noise
        else:
            latent = scheduler.add_noise(init_latent, noise,
                                         t_start).to(device)
    else:
        latent = fixed_starting_latent
        t_start = steps - int(steps * init_image_strength)
|
|
|
if prev_image is not None: |
|
|
|
prev_image = prev_image.resize((width, height), resample=Image.Resampling.LANCZOS) |
|
prev_image = np.array(prev_image).astype(np_dtype) / 255.0 * 2.0 - 1.0 |
|
prev_image = torch.from_numpy(prev_image[np.newaxis, ...].transpose(0, 3, 1, 2)) |
|
|
|
|
|
if prev_image.shape[1] > 3: |
|
prev_image = prev_image[:, :3] * prev_image[:, 3:] + (1 - prev_image[:, 3:]) |
|
|
|
|
|
prev_image = prev_image.to(device) |
|
|
|
|
|
with autocast(device): |
|
prev_init_latent = vae.encode(prev_image).latent_dist.sample(generator=generator) * 0.18215 |
|
|
|
t_start = steps - int(steps * init_image_strength) |
|
|
|
prev_latent = prev_scheduler.add_noise(prev_init_latent, noise, t_start).to(device) |
|
else: |
|
prev_latent = None |
|
|
|
|
|
|
|
with autocast(device): |
|
tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True) |
|
embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state |
|
|
|
tokens_conditional = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True) |
|
embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state |
|
|
|
|
|
assert not ((prompt_edit is not None) and (prev_image is not None)) |
|
if prompt_edit is not None: |
|
tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True) |
|
embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state |
|
init_attention_edit(tokens_conditional, tokens_conditional_edit) |
|
elif prev_image is not None: |
|
init_attention_edit(tokens_conditional, tokens_conditional) |
|
|
|
|
|
init_attention_func() |
|
init_attention_weights(prompt_edit_token_weights) |
|
|
|
timesteps = scheduler.timesteps[t_start:] |
|
|
|
|
|
assert isinstance(guidance_scale, int) |
|
num_cycles = 1 |
|
|
|
last_noise_preds = None |
|
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): |
|
t_index = t_start + i |
|
|
|
latent_model_input = latent |
|
if scheduler_str=='lms': |
|
sigma = scheduler.sigmas[t_index] |
|
latent_model_input = (latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype) |
|
else: |
|
assert scheduler_str in ['ddim', 'pndm', 'ddpm'] |
|
|
|
|
|
|
|
if len(t.shape) == 0: |
|
t = t[None].to(unet.device) |
|
noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional, |
|
).sample |
|
|
|
if prev_latent is not None: |
|
prev_latent_model_input = prev_latent |
|
prev_latent_model_input = (prev_latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype) |
|
prev_noise_pred_uncond = unet(prev_latent_model_input, t, |
|
encoder_hidden_states=embedding_unconditional, |
|
).sample |
|
|
|
|
|
|
|
|
|
if prompt_edit is not None or prev_latent is not None: |
|
save_last_tokens_attention() |
|
save_last_self_attention() |
|
else: |
|
|
|
use_last_tokens_attention_weights() |
|
|
|
|
|
if prev_latent is not None: |
|
raise NotImplementedError |
|
prev_noise_pred_cond = unet(prev_latent_model_input, t, encoder_hidden_states=embedding_conditional, |
|
).sample |
|
else: |
|
noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional, |
|
).sample |
|
|
|
|
|
t_scale = t / scheduler.num_train_timesteps |
|
if prompt_edit is not None or prev_latent is not None: |
|
if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end: |
|
use_last_tokens_attention() |
|
if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end: |
|
use_last_self_attention() |
|
|
|
|
|
use_last_tokens_attention_weights() |
|
|
|
|
|
if prompt_edit is not None: |
|
noise_pred_cond = unet(latent_model_input, t, |
|
encoder_hidden_states=embedding_conditional_edit).sample |
|
|
|
|
|
|
|
""" |
|
if cycle_i+1==num_cycles: |
|
noise_pred = noise_pred_uncond |
|
else: |
|
noise_pred = noise_pred_cond - noise_pred_uncond |
|
|
|
""" |
|
if last_noise_preds is not None: |
|
|
|
|
|
|
|
last_grad= last_noise_preds[1] - last_noise_preds[0] |
|
new_grad = noise_pred_cond - noise_pred_uncond |
|
|
|
last_noise_preds = (noise_pred_uncond, noise_pred_cond) |
|
|
|
use_cond_guidance = True |
|
if use_cond_guidance: |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) |
|
else: |
|
noise_pred = noise_pred_uncond |
|
        if clip_guidance is not None and t_scale >= clip_start and t_scale <= clip_end:
            # new_cond_fn (CLIP-guidance correction) is assumed to be defined elsewhere;
            # it is not defined in this file.
            noise_pred, latent = new_cond_fn(latent, t, t_index,
                                             embedding_conditional, noise_pred, clip_guidance,
                                             clip_guidance_scale,
                                             num_cutouts,
                                             scheduler, unet, use_cutouts=True,
                                             cut_power=cut_power)
|
if normalize_noise_pred: |
|
noise_pred = noise_pred * noise_pred_uncond.norm() / noise_pred.norm() |
|
if scheduler_str == 'ddim': |
|
latent = forward_step(scheduler, noise_pred, |
|
t, |
|
latent).prev_sample |
|
else: |
|
latent = scheduler.step(noise_pred, |
|
t_index, |
|
latent).prev_sample |
|
|
|
if prev_latent is not None: |
|
prev_noise_pred = prev_noise_pred_uncond + guidance_scale * (prev_noise_pred_cond - prev_noise_pred_uncond) |
|
prev_latent = prev_scheduler.step(prev_noise_pred, t_index, prev_latent).prev_sample |
|
if one_pass: break |
|
|
|
|
|
if return_latent: return latent |
|
latent = latent / 0.18215 |
|
image = vae.decode(latent.to(vae.dtype)).sample |
|
|
|
image = (image / 2 + 0.5).clamp(0, 1) |
|
image = image.cpu().permute(0, 2, 3, 1).numpy() |
|
|
|
image, _ = check_safety(image) |
|
|
|
image = (image[0] * 255).round().astype("uint8") |
|
return Image.fromarray(image) |
|
|
|
|
|
|
|
|
|
def get_alpha_and_beta(t, scheduler): |
|
|
|
if t.dtype==torch.long: |
|
alpha = scheduler.alphas_cumprod[t] |
|
return alpha, 1-alpha |
|
|
|
if t<0: |
|
return scheduler.final_alpha_cumprod, 1 - scheduler.final_alpha_cumprod |
|
|
|
|
|
low = t.floor().long() |
|
high = t.ceil().long() |
|
rem = t - low |
|
|
|
low_alpha = scheduler.alphas_cumprod[low] |
|
high_alpha = scheduler.alphas_cumprod[high] |
|
interpolated_alpha = low_alpha * rem + high_alpha * (1-rem) |
|
interpolated_beta = 1 - interpolated_alpha |
|
return interpolated_alpha, interpolated_beta |
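
# forward_step and reverse_step below are rearrangements of the deterministic
# (eta=0) DDIM update, with a_t = alphas_cumprod[t]:
#
#   x_{t-1} = sqrt(a_{t-1}/a_t) * (x_t - sqrt(1 - a_t) * eps) + sqrt(1 - a_{t-1}) * eps
#
# forward_step applies this update; reverse_step solves it exactly for x_t given
# x_{t-1}, which is what makes the inversion in coupled_stablediffusion exact
# (up to the mixing layers and floating-point error).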
|
|
|
|
|
|
|
def forward_step( |
|
self, |
|
model_output, |
|
timestep: int, |
|
sample, |
|
eta: float = 0.0, |
|
use_clipped_model_output: bool = False, |
|
generator=None, |
|
return_dict: bool = True, |
|
use_double=False, |
|
) : |
|
if self.num_inference_steps is None: |
|
raise ValueError( |
|
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" |
|
) |
|
|
|
prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps |
|
|
|
if timestep > self.timesteps.max(): |
|
raise NotImplementedError("Need to double check what the overflow is") |
|
|
|
alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self) |
|
alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self) |
|
|
|
|
|
alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5) |
|
first_term = (1./alpha_quotient) * sample |
|
second_term = (1./alpha_quotient) * (beta_prod_t ** 0.5) * model_output |
|
third_term = ((1 - alpha_prod_t_prev)**0.5) * model_output |
|
return first_term - second_term + third_term |
|
|
|
|
|
def reverse_step( |
|
self, |
|
model_output, |
|
timestep: int, |
|
sample, |
|
eta: float = 0.0, |
|
use_clipped_model_output: bool = False, |
|
generator=None, |
|
return_dict: bool = True, |
|
use_double=False, |
|
) : |
|
if self.num_inference_steps is None: |
|
raise ValueError( |
|
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" |
|
) |
|
|
|
prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps |
|
|
|
    if timestep > self.timesteps.max():
        raise NotImplementedError
|
|
|
alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self) |
|
alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self) |
|
|
|
alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5) |
|
|
|
first_term = alpha_quotient * sample |
|
second_term = ((beta_prod_t)**0.5) * model_output |
|
third_term = alpha_quotient * ((1 - alpha_prod_t_prev)**0.5) * model_output |
|
return first_term + second_term - third_term |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
def latent_to_image(latent): |
|
image = vae.decode(latent.to(vae.dtype)/0.18215).sample |
|
image = prep_image_for_return(image) |
|
return image |
|
|
|
def prep_image_for_return(image): |
|
image = (image / 2 + 0.5).clamp(0, 1) |
|
image = image.cpu().permute(0, 2, 3, 1).numpy() |
|
image = (image[0] * 255).round().astype("uint8") |
|
image = Image.fromarray(image) |
|
return image |
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
def coupled_stablediffusion(prompt="", |
|
prompt_edit=None, |
|
null_prompt='', |
|
prompt_edit_token_weights=[], |
|
prompt_edit_tokens_start=0.0, |
|
prompt_edit_tokens_end=1.0, |
|
prompt_edit_spatial_start=0.0, |
|
prompt_edit_spatial_end=1.0, |
|
guidance_scale=7.0, steps=50, |
|
seed=1, width=512, height=512, |
|
init_image=None, init_image_strength=1.0, |
|
run_baseline=False, |
|
use_lms=False, |
|
leapfrog_steps=True, |
|
reverse=False, |
|
return_latents=False, |
|
fixed_starting_latent=None, |
|
beta_schedule='scaled_linear', |
|
mix_weight=0.93): |
|
|
|
if seed is None: seed = random.randrange(2**32 - 1) |
|
generator = torch.cuda.manual_seed(seed) |
|
|
|
def image_to_latent(im): |
|
if isinstance(im, torch.Tensor): |
|
|
|
|
|
init_latent = im.to(device) |
|
else: |
|
|
|
im = im.resize((width, height), resample=Image.Resampling.LANCZOS) |
|
im = np.array(im).astype(np_dtype) / 255.0 * 2.0 - 1.0 |
|
|
|
if len(im.shape) < 3: |
|
im = np.stack([im for _ in range(3)], axis=2) |
|
|
|
im = torch.from_numpy(im[np.newaxis, ...].transpose(0, 3, 1, 2)) |
|
|
|
|
|
if im.shape[1] > 3: |
|
im = im[:, :3] * im[:, 3:] + (1 - im[:, 3:]) |
|
|
|
|
|
im = im.to(device) |
|
|
|
if use_half_prec: |
|
init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215 |
|
else: |
|
with autocast(device): |
|
init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215 |
|
return init_latent |
|
assert not use_lms, "Can't invert LMS the same as DDIM" |
|
if run_baseline: leapfrog_steps=False |
|
|
|
width = width - width % 64 |
|
height = height - height % 64 |
|
|
|
|
|
|
|
if init_image is not None: |
|
assert reverse |
|
|
|
if isinstance(init_image, list): |
|
if isinstance(init_image[0], torch.Tensor): |
|
init_latent = [t.clone() for t in init_image] |
|
else: |
|
init_latent = [image_to_latent(im) for im in init_image] |
|
else: |
|
init_latent = image_to_latent(init_image) |
|
|
|
t_limit = steps - int(steps * init_image_strength) |
|
else: |
|
assert not reverse, 'Need image to reverse from' |
|
init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device) |
|
t_limit = 0 |
|
|
|
if reverse: |
|
latent = init_latent |
|
else: |
|
|
|
noise = torch.randn(init_latent.shape, |
|
generator=generator, |
|
device=device, |
|
dtype=torch_dtype) |
|
if fixed_starting_latent is None: |
|
latent = noise |
|
else: |
|
if isinstance(fixed_starting_latent, list): |
|
latent = [l.clone() for l in fixed_starting_latent] |
|
else: |
|
latent = fixed_starting_latent.clone() |
|
t_limit = steps - int(steps * init_image_strength) |
|
if isinstance(latent, list): |
|
latent_pair = latent |
|
else: |
|
latent_pair = [latent.clone(), latent.clone()] |
|
|
|
|
|
if steps==0: |
|
if init_image is not None: |
|
return image_to_latent(init_image) |
|
else: |
|
image = vae.decode(latent.to(vae.dtype) / 0.18215).sample |
|
return prep_image_for_return(image) |
|
|
|
|
|
schedulers = [] |
|
for i in range(2): |
|
|
|
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, |
|
beta_schedule=beta_schedule, |
|
num_train_timesteps=1000, |
|
clip_sample=False, |
|
set_alpha_to_one=False) |
|
scheduler.set_timesteps(steps) |
|
schedulers.append(scheduler) |
|
|
|
with autocast(device): |
|
|
|
tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length", |
|
max_length=clip_tokenizer.model_max_length, |
|
truncation=True, return_tensors="pt", |
|
return_overflowing_tokens=True) |
|
embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state |
|
|
|
tokens_conditional = clip_tokenizer(prompt, padding="max_length", |
|
max_length=clip_tokenizer.model_max_length, |
|
truncation=True, return_tensors="pt", |
|
return_overflowing_tokens=True) |
|
embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state |
|
|
|
|
|
if prompt_edit is not None: |
|
tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length", |
|
max_length=clip_tokenizer.model_max_length, |
|
truncation=True, return_tensors="pt", |
|
return_overflowing_tokens=True) |
|
embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state |
|
|
|
init_attention_edit(tokens_conditional, tokens_conditional_edit) |
|
|
|
init_attention_func() |
|
init_attention_weights(prompt_edit_token_weights) |
|
|
|
timesteps = schedulers[0].timesteps[t_limit:] |
|
if reverse: timesteps = timesteps.flip(0) |
|
|
|
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): |
|
t_scale = t / schedulers[0].num_train_timesteps |
|
|
|
        if (reverse) and (not run_baseline):
            # Invert the EDICT mixing layer: undo the averaging so the forward
            # (generation) pass exactly reproduces these latents.
            new_latents = [l.clone() for l in latent_pair]
            new_latents[1] = (new_latents[1].clone() - (1-mix_weight)*new_latents[0].clone()) / mix_weight
            new_latents[0] = (new_latents[0].clone() - (1-mix_weight)*new_latents[1].clone()) / mix_weight
            latent_pair = new_latents
|
|
|
|
|
for latent_i in range(2): |
|
if run_baseline and latent_i==1: continue |
|
|
|
|
|
            if reverse and (not run_baseline):
                if leapfrog_steps:
                    # Mirror the update order used during generation so that
                    # inversion visits the pair in exactly the reverse sequence.
                    orig_i = len(timesteps) - (i + 1)
                    offset = (orig_i + 1) % 2
                    latent_i = (latent_i + offset) % 2
                else:
                    # Without leapfrog, update 1 then 0 when reversing.
                    latent_i = (latent_i + 1) % 2
            else:
                if leapfrog_steps:
                    # Alternate which latent is updated first on each step.
                    offset = i % 2
                    latent_i = (latent_i + offset) % 2

            # The other latent of the pair conditions this update (x predicts from y and vice versa).
            latent_j = ((latent_i + 1) % 2) if not run_baseline else latent_i
|
|
|
latent_model_input = latent_pair[latent_j] |
|
latent_base = latent_pair[latent_i] |
|
|
|
|
|
noise_pred_uncond = unet(latent_model_input, t, |
|
encoder_hidden_states=embedding_unconditional).sample |
|
|
|
|
|
if prompt_edit is not None: |
|
save_last_tokens_attention() |
|
save_last_self_attention() |
|
else: |
|
|
|
use_last_tokens_attention_weights() |
|
|
|
|
|
noise_pred_cond = unet(latent_model_input, t, |
|
encoder_hidden_states=embedding_conditional).sample |
|
|
|
|
|
if prompt_edit is not None: |
|
t_scale = t / schedulers[0].num_train_timesteps |
|
if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end: |
|
use_last_tokens_attention() |
|
if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end: |
|
use_last_self_attention() |
|
|
|
|
|
use_last_tokens_attention_weights() |
|
|
|
|
|
noise_pred_cond = unet(latent_model_input, |
|
t, |
|
encoder_hidden_states=embedding_conditional_edit).sample |
|
|
|
|
|
grad = (noise_pred_cond - noise_pred_uncond) |
|
noise_pred = noise_pred_uncond + guidance_scale * grad |
|
|
|
|
|
step_call = reverse_step if reverse else forward_step |
|
new_latent = step_call(schedulers[latent_i], |
|
noise_pred, |
|
t, |
|
latent_base) |
|
new_latent = new_latent.to(latent_base.dtype) |
|
|
|
latent_pair[latent_i] = new_latent |
|
|
|
        if (not reverse) and (not run_baseline):
            # EDICT mixing layer: average each latent toward the other so the
            # coupled pair stays close during generation.
            new_latents = [l.clone() for l in latent_pair]
            new_latents[0] = (mix_weight*new_latents[0] + (1-mix_weight)*new_latents[1]).clone()
            new_latents[1] = ((1-mix_weight)*new_latents[0] + (mix_weight)*new_latents[1]).clone()
            latent_pair = new_latents
|
|
|
|
|
if reverse or return_latents: |
|
results = [latent_pair] |
|
return results if len(results)>1 else results[0] |
|
|
|
|
|
images = [] |
|
for latent_i in range(2): |
|
latent = latent_pair[latent_i] / 0.18215 |
|
image = vae.decode(latent.to(vae.dtype)).sample |
|
images.append(image) |
|
|
|
|
|
return_arr = [] |
|
for image in images: |
|
image = prep_image_for_return(image) |
|
return_arr.append(image) |
|
results = [return_arr] |
|
return results if len(results)>1 else results[0] |
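
# Illustrative direct use of coupled_stablediffusion for text-to-image generation
# (no init image); the prompt and filenames are placeholders:
#
#   im_pair = coupled_stablediffusion('A photo of a mountain lake at sunrise',
#                                     steps=50, guidance_scale=7.0, seed=1)
#   im_pair[0].save('lake_0.png'); im_pair[1].save('lake_1.png')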
|
|
|
|
|
|