import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from typing import Union
from IPython.display import display

import p2p

# ----------------------------------------------------------------------
# Main function to run
# ----------------------------------------------------------------------
def runner(
        model,
        prompt,
        controller,
        solver,
        is_cons_forward=False,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=None,
        latent=None,
        uncond_embeddings=None,
        start_time=50,
        return_type='image',
        dynamic_guidance=False,
        tau1=0.4,
        tau2=0.6,
        w_embed_dim=0,
):
    p2p.register_attention_control(model, controller)
    height = width = 512
    solver.init_prompt(prompt, None)
    latent, latents = init_latent(latent, model, height, width, generator, len(prompt))
    model.scheduler.set_timesteps(num_inference_steps)
    # Dynamic guidance is active whenever either threshold is below 1.0.
    dynamic_guidance = tau1 < 1.0 or tau2 < 1.0
    if not is_cons_forward:
        latents = solver.ddim_loop(latents,
                                   num_inference_steps,
                                   is_forward=False,
                                   guidance_scale=guidance_scale,
                                   dynamic_guidance=dynamic_guidance,
                                   tau1=tau1,
                                   tau2=tau2,
                                   w_embed_dim=w_embed_dim,
                                   uncond_embeddings=uncond_embeddings,
                                   controller=controller)
        latents = latents[-1]
    else:
        latents = solver.cons_generation(
            latents,
            guidance_scale=guidance_scale,
            w_embed_dim=w_embed_dim,
            dynamic_guidance=dynamic_guidance,
            tau1=tau1,
            tau2=tau2,
            controller=controller)
        latents = latents[-1]
    if return_type == 'image':
        image = latent2image(model.vae, latents.to(model.vae.dtype))
    else:
        image = latents
    return image, latent

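# Illustrative usage sketch (assumptions: `pipe` is a diffusers Stable
# Diffusion-style pipeline exposing .unet/.vae/.tokenizer/.text_encoder/
# .scheduler, `solver` is a Generator instance from below, and `controller`
# is a p2p attention controller; all names here are placeholders, not a
# fixed API).
def _demo_runner(pipe, solver, controller):
    prompts = ["a photo of a cat"]
    image, latent = runner(model=pipe,
                           prompt=prompts,
                           controller=controller,
                           solver=solver,
                           num_inference_steps=50,
                           guidance_scale=7.5)
    view_images(image)  # image is a (batch, 512, 512, 3) uint8 array
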
# ----------------------------------------------------------------------
# Utils
# ----------------------------------------------------------------------
def linear_schedule_old(t, guidance_scale, tau1, tau2):
    # Decays the scale linearly from `guidance_scale` (t <= tau1) to 0 (t >= tau2).
    t = t / 1000
    if t <= tau1:
        gamma = 1.0
    elif t >= tau2:
        gamma = 0.0
    else:
        gamma = (tau2 - t) / (tau2 - tau1)
    return gamma * guidance_scale

def linear_schedule(t, guidance_scale, tau1=0.4, tau2=0.8):
    # Decays the scale linearly from `guidance_scale` (t <= tau1) to 1.0 (t >= tau2).
    t = t / 1000
    if t <= tau1:
        return guidance_scale
    if t >= tau2:
        return 1.0
    gamma = (tau2 - t) / (tau2 - tau1) * (guidance_scale - 1.0) + 1.0
    return gamma

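# Illustrative sketch (not part of the original pipeline): how the dynamic
# guidance schedule behaves across the 1000-step timestep range. With
# guidance_scale=7.5, tau1=0.4, tau2=0.8, the scale stays at 7.5 up to
# t=400, interpolates linearly in between, and settles at 1.0 from t=800.
def _demo_linear_schedule():
    for t in (100, 500, 900):
        print(t, linear_schedule(t, guidance_scale=7.5, tau1=0.4, tau2=0.8))
    # Expected output: 7.5, 5.875, 1.0
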
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    Sinusoidal embedding of the guidance scale, following
    https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            guidance scales to embed, shape `(batch,)`
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb

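# Quick shape check (illustrative): each scalar guidance scale maps to one
# `embedding_dim`-dimensional sinusoidal vector, so a batch of 4 scales
# yields a (4, 512) tensor that can be passed to the U-Net as `timestep_cond`.
def _demo_guidance_scale_embedding():
    w = torch.tensor([0.0, 1.0, 7.5, 8.0])
    emb = guidance_scale_embedding(w, embedding_dim=512)
    print(emb.shape)  # torch.Size([4, 512])
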
# ----------------------------------------------------------------------
# Diffusion step with scheduler from diffusers and controller for editing
# ----------------------------------------------------------------------
def extract_into_tensor(a, t, x_shape):
    # Gathers a[t] per sample and reshapes for broadcasting over `x_shape`.
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

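# Illustrative sketch: picking per-sample schedule coefficients. A 1D
# schedule indexed by integer timesteps is broadcast against a 4D latent
# batch, which is exactly how `predicted_origin` below consumes it.
def _demo_extract_into_tensor():
    schedule = torch.linspace(1.0, 0.0, 1000)  # toy 1D schedule
    t = torch.tensor([10, 500])                # one timestep per sample
    out = extract_into_tensor(schedule, t, (2, 4, 64, 64))
    print(out.shape)  # torch.Size([2, 1, 1, 1])
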
def predicted_origin(model_output, timesteps, boundary_timesteps, sample, prediction_type, alphas, sigmas):
    sigmas_s = extract_into_tensor(sigmas, boundary_timesteps, sample.shape)
    alphas_s = extract_into_tensor(alphas, boundary_timesteps, sample.shape)

    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)

    # Set hard boundaries to ensure equivalence with forward (direct) CD
    alphas_s[boundary_timesteps == 0] = 1.0
    sigmas_s[boundary_timesteps == 0] = 0.0

    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas  # x0 prediction
        pred_x_0 = alphas_s * pred_x_0 + sigmas_s * model_output  # Euler step to the boundary step
    elif prediction_type == "v_prediction":
        assert (boundary_timesteps == 0).all(), "v_prediction does not support multiple endpoints at the moment"
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(f"Prediction type {prediction_type} currently not supported.")

    return pred_x_0

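# Illustrative sanity check: with a boundary timestep of 0, the Euler step
# collapses to the plain epsilon-parameterized x0 prediction
# (x - sigma_t * eps) / alpha_t, since alpha_s = 1 and sigma_s = 0 are forced.
def _demo_predicted_origin():
    ac = torch.linspace(0.999, 0.001, 1000)  # toy alphas_cumprod
    alphas, sigmas = torch.sqrt(ac), torch.sqrt(1 - ac)
    x, eps = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
    t = torch.tensor([500, 500])
    s = torch.tensor([0, 0])
    pred = predicted_origin(eps, t, s, x, "epsilon", alphas, sigmas)
    ref = (x - sigmas[500] * eps) / alphas[500]
    print(torch.allclose(pred, ref, atol=1e-5))  # True
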
def guided_step(noise_prediction_text,
                noise_pred_uncond,
                t,
                guidance_scale,
                dynamic_guidance=False,
                tau1=0.4,
                tau2=0.6):
    if dynamic_guidance:
        if not isinstance(t, int):
            t = t.item()
        new_guidance_scale = linear_schedule(t, guidance_scale, tau1=tau1, tau2=tau2)
    else:
        new_guidance_scale = guidance_scale

    noise_pred = noise_pred_uncond + new_guidance_scale * (noise_prediction_text - noise_pred_uncond)
    return noise_pred

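# Illustrative sketch of classifier-free guidance itself: the final noise
# estimate is the unconditional prediction pushed toward the text-conditional
# one. With scale 1.0 the result equals the conditional prediction exactly.
def _demo_guided_step():
    cond, uncond = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
    out = guided_step(cond, uncond, t=500, guidance_scale=1.0)
    print(torch.allclose(out, cond))  # True
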
# ----------------------------------------------------------------------
# DDIM scheduler with inversion
# ----------------------------------------------------------------------
class Generator:

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
                  sample: Union[torch.FloatTensor, np.ndarray]):
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
                  sample: Union[torch.FloatTensor, np.ndarray]):
        timestep, next_timestep = min(
            timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample

    def get_noise_pred_single(self, latents, t, context):
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self,
                       model,
                       latent,
                       t,
                       guidance_scale=1,
                       context=None,
                       w_embed_dim=0,
                       dynamic_guidance=False,
                       tau1=0.4,
                       tau2=0.6):
        latents_input = torch.cat([latent] * 2)
        if context is None:
            context = self.context

        # w embed
        # --------------------------------------
        if w_embed_dim > 0:
            if dynamic_guidance:
                t_item = t if isinstance(t, int) else t.item()
                guidance_scale = linear_schedule_old(t_item, guidance_scale, tau1=tau1, tau2=tau2)  # TODO UPDATE
            if len(latents_input) == 4:
                guidance_scale_tensor = torch.tensor([0.0, 0.0, 0.0, guidance_scale])
            else:
                guidance_scale_tensor = torch.tensor([guidance_scale] * len(latents_input))
            w_embedding = guidance_scale_embedding(guidance_scale_tensor, embedding_dim=w_embed_dim)
            w_embedding = w_embedding.to(device=latent.device, dtype=latent.dtype)
        else:
            w_embedding = None
        # --------------------------------------

        noise_pred = model.unet(latents_input.to(dtype=model.unet.dtype),
                                t,
                                timestep_cond=w_embedding.to(dtype=model.unet.dtype) if w_embed_dim > 0 else None,
                                encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)

        if guidance_scale > 1 and w_embedding is None:
            noise_pred = guided_step(noise_prediction_text, noise_pred_uncond, t, guidance_scale, dynamic_guidance,
                                     tau1, tau2)
        else:
            noise_pred = noise_prediction_text

        return noise_pred

    def latent2image(self, latents, return_type='np'):
        latents = 1 / 0.18215 * latents.detach()
        image = self.model.vae.decode(latents.to(dtype=self.model.dtype))['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    def image2latent(self, image):
        with torch.no_grad():
            # Note: `Image` is the PIL module, so the correct PIL check is
            # against `Image.Image`, not the module itself.
            if isinstance(image, Image.Image):
                image = np.array(image)
            if isinstance(image, torch.Tensor) and image.dim() == 4:
                latents = image
            elif isinstance(image, list):
                image = [np.array(i).reshape(1, 512, 512, 3) for i in image]
                image = np.concatenate(image)
                image = torch.from_numpy(image).float() / 127.5 - 1
                image = image.permute(0, 3, 1, 2).to(self.model.device, dtype=self.model.vae.dtype)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215
            else:
                image = torch.from_numpy(image).float() / 127.5 - 1
                image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device, dtype=self.model.dtype)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215
        return latents

    def init_prompt(self, prompt, uncond_embeddings=None):
        if uncond_embeddings is None:
            uncond_input = self.model.tokenizer(
                [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
                return_tensors="pt"
            )
            uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings.expand(*text_embeddings.shape), text_embeddings])
        self.prompt = prompt

    def ddim_loop(self,
                  latent,
                  n_steps,
                  is_forward=True,
                  guidance_scale=1,
                  dynamic_guidance=False,
                  tau1=0.4,
                  tau2=0.6,
                  w_embed_dim=0,
                  uncond_embeddings=None,
                  controller=None):
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in tqdm(range(n_steps)):
            if uncond_embeddings is not None:
                self.init_prompt(self.prompt, uncond_embeddings[i])
            if is_forward:
                t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            else:
                t = self.model.scheduler.timesteps[i]
            noise_pred = self.get_noise_pred(
                model=self.model,
                latent=latent,
                t=t,
                context=None,
                guidance_scale=guidance_scale,
                dynamic_guidance=dynamic_guidance,
                w_embed_dim=w_embed_dim,
                tau1=tau1,
                tau2=tau2)
            if is_forward:
                latent = self.next_step(noise_pred, t, latent)
            else:
                latent = self.prev_step(noise_pred, t, latent)
            if controller is not None:
                latent = controller.step_callback(latent)
            all_latent.append(latent)
        return all_latent

    @property
    def scheduler(self):
        # Must be a property: prev_step/next_step access self.scheduler.config.
        return self.model.scheduler

    def ddim_inversion(self,
                       image,
                       n_steps=None,
                       guidance_scale=1,
                       dynamic_guidance=False,
                       tau1=0.4,
                       tau2=0.6,
                       w_embed_dim=0):
        if n_steps is None:
            n_steps = self.n_steps
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent,
                                      is_forward=True,
                                      guidance_scale=guidance_scale,
                                      n_steps=n_steps,
                                      dynamic_guidance=dynamic_guidance,
                                      tau1=tau1,
                                      tau2=tau2,
                                      w_embed_dim=w_embed_dim)
        return image_rec, ddim_latents

    def cons_generation(self,
                        latent,
                        guidance_scale=1,
                        dynamic_guidance=False,
                        tau1=0.4,
                        tau2=0.6,
                        w_embed_dim=0,
                        controller=None):
        all_latent = [latent]
        latent = latent.clone().detach()
        alpha_schedule = torch.sqrt(self.model.scheduler.alphas_cumprod).to(self.model.device)
        sigma_schedule = torch.sqrt(1 - self.model.scheduler.alphas_cumprod).to(self.model.device)

        for i, (t, s) in enumerate(tqdm(zip(self.reverse_timesteps, self.reverse_boundary_timesteps))):
            noise_pred = self.get_noise_pred(
                model=self.reverse_cons_model,
                latent=latent,
                t=t.to(self.model.device),
                context=None,
                tau1=tau1, tau2=tau2,
                w_embed_dim=w_embed_dim,
                guidance_scale=guidance_scale,
                dynamic_guidance=dynamic_guidance)

            latent = predicted_origin(
                noise_pred,
                torch.tensor([t] * len(latent), device=self.model.device),
                torch.tensor([s] * len(latent), device=self.model.device),
                latent,
                self.model.scheduler.config.prediction_type,
                alpha_schedule,
                sigma_schedule,
            )
            if controller is not None:
                latent = controller.step_callback(latent)
            all_latent.append(latent)

        return all_latent

    def cons_inversion(self,
                       image,
                       guidance_scale=0.0,
                       w_embed_dim=0,
                       seed=0):
        alpha_schedule = torch.sqrt(self.model.scheduler.alphas_cumprod).to(self.model.device)
        sigma_schedule = torch.sqrt(1 - self.model.scheduler.alphas_cumprod).to(self.model.device)

        # Prepare latent variables
        latent = self.image2latent(image)
        generator = torch.Generator().manual_seed(seed)
        noise = torch.randn(latent.shape, generator=generator).to(latent.device)
        latent = self.noise_scheduler.add_noise(latent, noise, torch.tensor([self.start_timestep]))
        image_rec = self.latent2image(latent)

        for i, (t, s) in enumerate(tqdm(zip(self.forward_timesteps, self.forward_boundary_timesteps))):
            # predict the noise residual
            noise_pred = self.get_noise_pred(
                model=self.forward_cons_model,
                latent=latent,
                t=t.to(self.model.device),
                context=None,
                guidance_scale=guidance_scale,
                w_embed_dim=w_embed_dim,
                dynamic_guidance=False)

            latent = predicted_origin(
                noise_pred,
                torch.tensor([t] * len(latent), device=self.model.device),
                torch.tensor([s] * len(latent), device=self.model.device),
                latent,
                self.model.scheduler.config.prediction_type,
                alpha_schedule,
                sigma_schedule,
            )

        return image_rec, [latent]

    def _create_forward_inverse_timesteps(self,
                                          num_endpoints,
                                          n_steps,
                                          max_inverse_timestep_index):
        timestep_interval = n_steps // num_endpoints + int(n_steps % num_endpoints > 0)
        endpoint_idxs = torch.arange(timestep_interval, n_steps, timestep_interval) - 1
        inverse_endpoint_idxs = torch.tensor(endpoint_idxs.tolist() + [max_inverse_timestep_index])
        endpoints = torch.tensor([0] + self.ddim_timesteps[endpoint_idxs].tolist())
        inverse_endpoints = self.ddim_timesteps[inverse_endpoint_idxs]
        return endpoints, inverse_endpoints

    def __init__(self,
                 model,
                 n_steps,
                 noise_scheduler,
                 forward_cons_model=None,
                 reverse_cons_model=None,
                 num_endpoints=1,
                 num_forward_endpoints=1,
                 reverse_timesteps=None,
                 forward_timesteps=None,
                 max_forward_timestep_index=49,
                 start_timestep=19):
        self.model = model
        self.forward_cons_model = forward_cons_model
        self.reverse_cons_model = reverse_cons_model
        self.noise_scheduler = noise_scheduler

        self.n_steps = n_steps
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(n_steps)
        self.prompt = None
        self.context = None

        step_ratio = 1000 // n_steps
        self.ddim_timesteps = (np.arange(1, n_steps + 1) * step_ratio).round().astype(np.int64) - 1
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.start_timestep = start_timestep

        if reverse_timesteps is None or forward_timesteps is None:
            # Set endpoints for direct CTM
            endpoints, inverse_endpoints = self._create_forward_inverse_timesteps(num_endpoints, n_steps,
                                                                                  max_forward_timestep_index)
            self.reverse_timesteps, self.reverse_boundary_timesteps = inverse_endpoints.flip(0), endpoints.flip(0)

            # Set endpoints for forward CTM
            endpoints, inverse_endpoints = self._create_forward_inverse_timesteps(num_forward_endpoints, n_steps,
                                                                                  max_forward_timestep_index)
            self.forward_timesteps, self.forward_boundary_timesteps = endpoints, inverse_endpoints
            self.forward_timesteps[0] = self.start_timestep
        else:
            # Custom endpoints are given as Python lists. Note that the two
            # names initially alias the same list, so the in-place reverse()
            # below affects both before the boundary list is rotated.
            self.reverse_timesteps, self.reverse_boundary_timesteps = reverse_timesteps, reverse_timesteps
            self.reverse_timesteps.reverse()
            self.reverse_boundary_timesteps = self.reverse_boundary_timesteps[1:] + [self.reverse_boundary_timesteps[0]]
            self.reverse_boundary_timesteps[-1] = 0
            self.reverse_timesteps, self.reverse_boundary_timesteps = torch.tensor(reverse_timesteps), torch.tensor(
                self.reverse_boundary_timesteps)

            self.forward_timesteps, self.forward_boundary_timesteps = forward_timesteps, forward_timesteps
            self.forward_boundary_timesteps = self.forward_boundary_timesteps[1:] + [self.forward_boundary_timesteps[0]]
            self.forward_boundary_timesteps[-1] = 999
            self.forward_timesteps, self.forward_boundary_timesteps = torch.tensor(
                self.forward_timesteps), torch.tensor(self.forward_boundary_timesteps)

        print(f"Endpoints reverse CTM: {self.reverse_timesteps}, {self.reverse_boundary_timesteps}")
        print(f"Endpoints forward CTM: {self.forward_timesteps}, {self.forward_boundary_timesteps}")

# ----------------------------------------------------------------------
# 3rd party utils
# ----------------------------------------------------------------------
def latent2image(vae, latents):
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents)['sample']
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).astype(np.uint8)
    return image

def init_latent(latent, model, height, width, generator, batch_size):
    if latent is None:
        latent = torch.randn(
            (1, model.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
    # The same initial latent is expanded across the batch so that all
    # prompts (e.g. source and edited) start from identical noise.
    latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
    return latent, latents

def load_512(image_path, left=0, right=0, top=0, bottom=0):
    # Cropping via left/right/top/bottom is currently disabled; the image is
    # simply loaded as RGB and resized to 512x512.
    image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image

def to_pil_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    return pil_img

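# Illustrative sketch: tiling three HxWx3 uint8 arrays into a single-row
# grid with a small white gap between cells.
def _demo_to_pil_images():
    tiles = [np.full((64, 64, 3), v, dtype=np.uint8) for v in (0, 128, 255)]
    grid = to_pil_images(tiles, num_rows=1)
    print(grid.size)  # (194, 64): 3 * 64 px wide plus 2 gaps of int(64 * 0.02) = 1 px
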
def view_images(images, num_rows=1, offset_ratio=0.02):
    # Same grid layout as to_pil_images (of which this was a verbatim copy),
    # but displayed inline via IPython instead of returned.
    display(to_pil_images(images, num_rows=num_rows, offset_ratio=offset_ratio))

# ----------------------------------------------------------------------