File size: 7,992 Bytes
0305ee7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import torch
import numpy as np
from . import utils
from utils import torch_device
import matplotlib.pyplot as plt
def get_unscaled_latents(batch_size, in_channels, height, width, generator, dtype):
    """Sample unscaled Gaussian initial latents.

    The noise is always drawn in ``torch.float32`` and only afterwards cast to
    ``dtype`` and moved to ``torch_device``. As a result, the same ``generator``
    state yields the same latents whatever precision is requested, and sampling
    does not rely on half-precision ``randn`` support.

    Args:
        batch_size: Number of latent samples.
        in_channels: Latent channel count, often obtained with
            ``unet.config.in_channels``.
        height, width: Image size in pixels. The latent spatial size is
            ``height // 8`` by ``width // 8``.
        generator: ``torch.Generator`` used for sampling.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(batch_size, in_channels, height // 8, width // 8)`` on
        ``torch_device`` with dtype ``dtype``. It is NOT multiplied by the
        scheduler's ``init_noise_sigma``.
    """
    # Obtain with torch.float32 and cast to the target dtype afterwards:
    # drawing directly in float16 would lead to different latents for the same seed.
    latents_base = torch.randn(
        (batch_size, in_channels, height // 8, width // 8),
        generator=generator, dtype=torch.float32,
    )
    return latents_base.to(torch_device, dtype=dtype)
def get_scaled_latents(batch_size, in_channels, height, width, generator, dtype, scheduler):
    """Sample initial latents already multiplied by the scheduler's noise scale.

    Forwards every argument except ``scheduler`` to ``get_unscaled_latents`` and
    multiplies the result by ``scheduler.init_noise_sigma``.
    """
    unscaled = get_unscaled_latents(batch_size, in_channels, height, width, generator, dtype)
    return unscaled * scheduler.init_noise_sigma
def blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=0.01):
    """Mix foreground noise into background noise inside a mask.

    Where ``fg_mask`` is 0 the result equals ``latents_bg``. Where it is 1 the
    result is ``sqrt(1 - r) * latents_bg + sqrt(r) * latents_fg`` with
    ``r = fg_blending_ratio``. For independent unit-variance inputs, the
    square-root weights keep the variance of the mix at 1.

    Args:
        latents_bg: Background latents.
        latents_fg: Foreground latents. Must not be (all-)close to ``latents_bg``.
        fg_mask: Mask broadcastable against the latents. 1 selects the blended
            value and 0 keeps the background.
        fg_blending_ratio: Ratio ``r`` in ``[0, 1]`` used for the weights above.

    Returns:
        Blended latents with the dtype of ``latents_bg``.

    Raises:
        AssertionError: If ``latents_bg`` and ``latents_fg`` are all-close.
        ValueError: If ``fg_blending_ratio`` is outside ``[0, 1]`` (or NaN),
            which would otherwise silently produce NaN latents.
    """
    assert not torch.allclose(latents_bg, latents_fg), "latents_bg should be independent with latents_fg"
    if not 0. <= fg_blending_ratio <= 1.:
        raise ValueError(f"fg_blending_ratio must be in [0, 1], got {fg_blending_ratio!r}")
    dtype = latents_bg.dtype
    bg_weight = np.sqrt(1. - fg_blending_ratio)
    fg_weight = np.sqrt(fg_blending_ratio)
    latents = latents_bg * (1. - fg_mask) + (latents_bg * bg_weight + latents_fg * fg_weight) * fg_mask
    # Mixed-dtype arithmetic (e.g. a float32 mask with float16 latents) promotes; restore the input dtype.
    return latents.to(dtype=dtype)
@torch.no_grad()
def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, latents_bg=None, bg_seed=None, compose_box_to_bg=True):
    """Paste per-object latent trajectories into one composed trajectory.

    Objects are composed in order of decreasing mask area. Where masks overlap,
    the object composed later (the smaller one) overwrites the earlier one.

    Args:
        model_dict: Object exposing ``unet``, ``scheduler`` and ``dtype`` attributes.
        latents_all_list: One latent trajectory per object, indexed along its
            first axis by step. Index 0 is used as the initial latents.
            Presumably shaped ``(num_inference_steps + 1, batch, C, H, W)`` — TODO confirm.
        mask_tensor_list: One mask per object, aligned with ``latents_all_list``.
            ``~mask_tensor`` is used below, so a boolean ``(H, W)`` mask is
            assumed — TODO confirm.
        num_inference_steps: The output holds ``num_inference_steps + 1``
            entries along its first axis.
        overall_batch_size, height, width: Used only to sample the background
            latents when ``latents_bg`` is None.
        latents_bg: Optional initial background latents. When None they are
            sampled from ``bg_seed`` and scaled by ``scheduler.init_noise_sigma``.
        bg_seed: Seed passed to ``torch.manual_seed`` when ``latents_bg`` is None.
            NOTE(review): this reseeds the global torch RNG, and a None seed
            in that case will likely fail — confirm callers always pass one.
        compose_box_to_bg: If True, first paste each object's initial latents
            (index 0) into the background using the bounding box of its mask.

    Returns:
        ``(composed_latents, foreground_indices)``, both moved to ``torch_device``.
        ``composed_latents`` is the composed trajectory. ``foreground_indices``
        is a long map over the last two latent axes: ``k + 1`` where object
        ``k``'s mask was written last, and 0 for background.
    """
    unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
    if latents_bg is None:
        generator = torch.manual_seed(bg_seed) # Seed generator to create the initial latent noise
        latents_bg = get_scaled_latents(overall_batch_size, unet.config.in_channels, height, width, generator, dtype, scheduler)
    # Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
    composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
    composed_latents[0] = latents_bg
    # 0 = background; filled with (object index + 1) by the second loop below.
    foreground_indices = torch.zeros(latents_bg.shape[-2:], dtype=torch.long)
    mask_size = np.array([mask_tensor.sum().item() for mask_tensor in mask_tensor_list])
    # Compose the largest mask first, so smaller objects are written last and stay visible where masks overlap.
    mask_order = np.argsort(-mask_size)
    if compose_box_to_bg:
        # This has two functionalities:
        # 1. copies the right initial latents from the right place (for centered single-object generation),
        # 2. copies the right initial latents (since we have foreground blending) for centered/original single-object generation.
        for mask_idx in mask_order:
            latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
            # Note: need to be careful to not copy from zeros due to shifting.
            # Use the mask's bounding box instead of the mask itself for the initial latents.
            mask_tensor = utils.binary_mask_to_box_mask(mask_tensor, to_device=False)
            # Three leading singleton axes let a 2-D mask broadcast over the leading latent axes.
            mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
            # Only step 0 (the initial latents) is updated in this pass.
            composed_latents[0] = composed_latents[0] * (1. - mask_tensor_expanded) + latents_all[0] * mask_tensor_expanded
    # This is still needed with `compose_box_to_bg` to ensure the foreground latent is still visible and to compute foreground indices.
    for mask_idx in mask_order:
        latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
        # Overwrite the index map inside this mask with this object's 1-based index.
        foreground_indices = foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
        mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
        # Paste this object's whole trajectory (every step) inside its mask.
        composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all * mask_tensor_expanded
    composed_latents, foreground_indices = composed_latents.to(torch_device), foreground_indices.to(torch_device)
    return composed_latents, foreground_indices
def align_with_bboxes(latents_all_list, mask_tensor_list, bboxes, horizontal_shift_only=False):
    """Shift each object so its mask center lands on the center of its target box.

    Args:
        latents_all_list: Per-object latents, shifted with ``utils.shift_tensor``.
        mask_tensor_list: Per-object masks. Their normalized center
            (``utils.binary_mask_to_center(..., normalize=True)``) is the source point.
        bboxes: Per-object target boxes ``(x_min, y_min, x_max, y_max)`` in the
            same normalized coordinates as the mask center.
        horizontal_shift_only: If True, the vertical offset is forced to 0.

    Returns:
        ``(new_latents_all_list, new_mask_tensor_list, offset_list)``.
        Each offset in ``offset_list`` is ``(x_offset, y_offset)`` (normalized).
        The three inputs are zipped, so the outputs are as long as the shortest input.
    """
    shifted_latents, shifted_masks, offsets = [], [], []
    for latents_all, mask_tensor, bbox in zip(latents_all_list, mask_tensor_list, bboxes):
        src_x, src_y = utils.binary_mask_to_center(mask_tensor, normalize=True)
        box_x_min, box_y_min, box_x_max, box_y_max = bbox

        dx = (box_x_min + box_x_max) / 2 - src_x
        dy = (box_y_min + box_y_max) / 2 - src_y
        if horizontal_shift_only:
            dy = 0.

        shifted_latents.append(utils.shift_tensor(latents_all, dx, dy, offset_normalized=True))
        shifted_masks.append(utils.shift_tensor(mask_tensor, dx, dy, offset_normalized=True))
        offsets.append((dx, dy))
    return shifted_latents, shifted_masks, offsets
@torch.no_grad()
def compose_latents_with_alignment(
    model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width,
    align_with_overall_bboxes=True, overall_bboxes=None, horizontal_shift_only=False, **kwargs
):
    """Optionally move each object to its overall box, then compose the latents.

    When ``align_with_overall_bboxes`` is set and there is at least one object,
    the boxes are expanded with ``utils.expand_overall_bboxes`` and every object
    is shifted by ``align_with_bboxes``. Otherwise every offset is ``(0., 0.)``.
    Extra keyword arguments are forwarded to ``compose_latents``.

    Returns:
        ``(composed_latents, foreground_indices, offset_list)``.
    """
    if not (align_with_overall_bboxes and len(latents_all_list)):
        offset_list = [(0., 0.)] * len(latents_all_list)
    else:
        target_bboxes = utils.expand_overall_bboxes(overall_bboxes)
        latents_all_list, mask_tensor_list, offset_list = align_with_bboxes(
            latents_all_list, mask_tensor_list,
            bboxes=target_bboxes, horizontal_shift_only=horizontal_shift_only,
        )

    composed_latents, foreground_indices = compose_latents(
        model_dict, latents_all_list, mask_tensor_list, num_inference_steps,
        overall_batch_size, height, width, **kwargs,
    )
    return composed_latents, foreground_indices, offset_list
def get_input_latents_list(model_dict, bg_seed, fg_seed_start, fg_blending_ratio, height, width, so_prompt_phrase_box_list=None, so_boxes=None, verbose=False):
    """Build one initial-latent tensor per single-object box, plus the background latents.

    Background noise is drawn once from ``bg_seed``. For the object at position
    ``idx``, foreground noise is drawn from ``fg_seed_start + idx`` (shifted by
    12345 if that equals ``bg_seed``). It is blended into the background inside
    the object's box via ``blend_latents``.

    Note: the returned input latents are scaled by `scheduler.init_noise_sigma`.

    Args:
        model_dict: Object exposing ``unet``, ``scheduler`` and ``dtype`` attributes.
        bg_seed: Seed for the background noise. Passed to ``torch.manual_seed``,
            so the global torch RNG is reseeded.
        fg_seed_start: Seed for the first object's foreground noise.
        fg_blending_ratio: Forwarded to ``blend_latents``.
        height, width: Image size in pixels.
        so_prompt_phrase_box_list: Used only when ``so_boxes`` is None. The last
            element of each item is taken as that object's box.
        so_boxes: Boxes passed to ``utils.proportion_to_mask`` (presumably in
            proportional coordinates — confirm against ``utils``).
        verbose: If True, display each foreground mask with matplotlib.

    Returns:
        ``(input_latents_list, latents_bg)``. Both are scaled by
        ``scheduler.init_noise_sigma`` and sampled with batch size 1.

    Raises:
        TypeError: If neither ``so_boxes`` nor ``so_prompt_phrase_box_list`` is given.
    """
    if so_boxes is None and so_prompt_phrase_box_list is None:
        raise TypeError("get_input_latents_list() requires either `so_boxes` or `so_prompt_phrase_box_list`")

    unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
    in_channels = unet.config.in_channels

    generator_bg = torch.manual_seed(bg_seed)  # Seed generator to create the initial latent noise
    latents_bg = get_unscaled_latents(batch_size=1, in_channels=in_channels, height=height, width=width, generator=generator_bg, dtype=dtype)

    if so_boxes is None:
        # For compatibility: take the box from the last element of each item.
        so_boxes = [item[-1] for item in so_prompt_phrase_box_list]

    # Latent resolution; the // 8 factor matches get_unscaled_latents. Loop-invariant, so computed once.
    H, W = height // 8, width // 8
    input_latents_list = []
    # Changing fg_seed_start changes the foreground initial noise.
    for idx, obj_box in enumerate(so_boxes):
        fg_mask = utils.proportion_to_mask(obj_box, H, W)
        if verbose:
            plt.imshow(fg_mask.cpu().numpy())
            plt.show()

        fg_seed = fg_seed_start + idx
        if fg_seed == bg_seed:
            # We should have different seeds for foreground and background
            fg_seed += 12345

        generator_fg = torch.manual_seed(fg_seed)
        latents_fg = get_unscaled_latents(batch_size=1, in_channels=in_channels, height=height, width=width, generator=generator_fg, dtype=dtype)
        # Blend with the *unscaled* background; scaling is applied afterwards.
        input_latents = blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=fg_blending_ratio)
        input_latents_list.append(input_latents * scheduler.init_noise_sigma)

    latents_bg = latents_bg * scheduler.init_noise_sigma
    return input_latents_list, latents_bg
|