Spaces:
Sleeping
Sleeping
import torch | |
from torchvision import transforms | |
import numpy as np | |
import skfmm | |
from PIL import Image | |
import torch.nn as nn | |
import cv2 | |
import scipy | |
from scipy.ndimage.filters import gaussian_filter | |
import kornia | |
import warnings | |
warnings.filterwarnings("ignore", message="PyTorch version 1.7.1 or higher is recommended") | |
import alpha_clip | |
from augmentations import ImageAugmentations | |
from constants import Const, N | |
def get_dist_field(dist_from, device, as_squeezed_np=False):
    """Compute the signed distance field of a binary mask with ``skfmm``.

    Args:
        dist_from: mask as a NumPy array or a torch tensor (a tensor is moved to
            CPU and converted to NumPy first). Nonzero entries count as "inside";
            no entry may exceed 1.
        device: torch device for the returned tensor.
        as_squeezed_np: if True, return the raw NumPy field instead of a tensor.
            Despite the name, no squeeze is performed: the field keeps the
            input's shape.

    Returns:
        Distance field with the same shape as ``dist_from``. It is negative on
        mask pixels and positive elsewhere, because the level-set function handed
        to ``skfmm.distance`` is -0.5 inside and +0.5 outside. The result is a
        tensor on ``device``, or an ndarray when ``as_squeezed_np`` is True.

    NOTE(review): ``skfmm.distance`` is expected to fail when the mask is empty
    or full (no zero contour); confirm callers never pass such masks.
    """
    mask = dist_from if isinstance(dist_from, np.ndarray) else dist_from.cpu().numpy()
    assert np.max(mask) <= 1
    # Level set: -0.5 on mask pixels, +0.5 elsewhere -> zero contour on the boundary.
    phi = -(np.where(mask, 0, -1) + 0.5)
    field = skfmm.distance(phi, dx=1)
    return field if as_squeezed_np else torch.tensor(field).to(device)
def get_surround(surround_from, surround_width, device, as_squeezed_np=False):
    """Grow (or, for a negative width, shrink) a binary mask by a distance.

    Selects every pixel whose signed distance to the mask boundary is at most
    ``surround_width``. The distance comes from ``get_dist_field``: negative
    inside the mask, positive outside. The selection is cast to the input's dtype.

    Args:
        surround_from: binary mask, as a torch tensor or a NumPy array.
        surround_width: distance threshold in pixels. Positive values dilate,
            0 keeps (approximately) the mask itself, negative values erode.
        device: torch device for the result.
        as_squeezed_np: if True, return a NumPy array. Despite the name, the
            array is not squeezed.

    Returns:
        Mask with the same shape and (torch-equivalent) dtype as the input: a
        tensor on ``device``, or an ndarray when ``as_squeezed_np`` is True.
    """
    if isinstance(surround_from, np.ndarray):
        # get_dist_field accepts ndarrays, but `.to(surround_from.dtype)` below
        # needs a torch dtype; a NumPy dtype there raises TypeError.
        surround_from = torch.as_tensor(surround_from)
    dists = get_dist_field(surround_from, device)
    surround = (dists <= surround_width).to(surround_from.dtype)
    if as_squeezed_np:
        return surround.cpu().numpy()
    return surround
class DynMask:
    """Latent-space mask evolved step by step with Alpha-CLIP guidance.

    The mask lives on the latent grid as a continuous ``potential``. The binary
    ``latent_mask`` is derived from it by ``get_threshed_mask``. Each call to
    ``evolve_mask`` does the following:

    * decodes a latent blend (predicted latents inside the mask, source latents
      outside);
    * scores the decoded image against ``args.prompt`` with Alpha-CLIP;
    * adds the smoothed, normalized gradient magnitude of that loss w.r.t. the
      mask, scaled by ``Const.MASK_LR``, to the potential inside a ring around
      the mask boundary.

    NOTE(review): this class was recovered from a listing that lost all
    indentation. Two nesting choices could not be read off that listing and
    should be confirmed against upstream:

    1. The Gaussian blur in ``set_cur_masks`` is applied on every call, not
       only after a gradient update.
    2. ``plain_dilation_ws[-10:] = 0`` in ``get_plain_dilated_latent_mask``
       applies to both dilation schedules.
    """

    def __init__(self, click_pil, args, init_image_tensor, device, total_steps):
        self.args = args
        self.device = device
        self.init_image = init_image_tensor
        self.total_steps = total_steps
        self.ac_size = (self.args.alpha_clip_scale, self.args.alpha_clip_scale)
        # Pick the Alpha-CLIP checkpoint matching the requested input resolution.
        # Checkpoint paths are relative to the current working directory.
        if self.args.alpha_clip_scale == 336:
            self.ac_model, self.ac_preprocess = alpha_clip.load(
                "ViT-L/14@336px",
                alpha_vision_ckpt_pth="./checkpoints/clip_l14_336_grit1m_fultune_8xe.pth",
                device=self.device,
            )
        else:
            self.ac_model, self.ac_preprocess = alpha_clip.load(
                "ViT-L/14",
                alpha_vision_ckpt_pth="./checkpoints/clip_l14_grit20m_fultune_2xe.pth",
                device=self.device,
            )
        self.image_augmentations = ImageAugmentations(
            self.args.alpha_clip_scale, Const.AUG_NUM
        )
        self.text_features = self.get_text_features([self.args.prompt])
        self.latent_size = Const.LATENT_SIZE
        self.decoded_size = (Const.H, Const.W)
        self.thresh_val = Const.THRESH_VAL
        self.base_potential = None
        self.potential = None
        self.latent_mask = None
        self.set_init_masks(click_pil)
        # Named snapshots of (potential, latent_mask); see make_cached_masks_clones.
        self.cached_masks_clones = {}
        # Detached Alpha-CLIP loss per step index (written by calc_grads).
        self.closs_hist = {}
        # Not written inside this class; presumably filled by the caller.
        self.latents_hist = {}
        self.latent_masks_hist = {}

    def normalize_point_size(self, click, radius_for64=1.367):
        """Replace a click blob by a disc whose radius scales with resolution.

        Args:
            click: 2-D array; pixels > 0.5 form the click.
            radius_for64: disc radius for a 64-pixel-high grid. It is rescaled
                by ``click.shape[0] / 64``, and 0.3 is subtracted (an undocumented
                empirical offset).

        Returns:
            float64 ndarray shaped like ``click``, 1 inside a disc centred on the
            rounded centroid of the click pixels and 0 elsewhere.

        Raises:
            ValueError: if ``click`` contains no pixel above 0.5.
        """
        rows, cols = np.nonzero(click > 0.5)
        if rows.size == 0:
            raise ValueError("click mask is empty; cannot locate the click point")
        center = int(rows.mean().round()), int(cols.mean().round())
        norm_threshed = np.zeros(click.shape, dtype=float)
        norm_threshed[center] = 1
        return get_surround(
            torch.tensor(norm_threshed).to(self.device),
            click.shape[0] / 64 * radius_for64 - 0.3,
            self.device,
            as_squeezed_np=True,
        )

    def calc_potential(self, click_pil, sigma_for_shape64):
        """Build the initial potential map from a click image.

        The steps are:

        1. Resize the click to the latent grid (nearest neighbour) and
           binarize it at 125.
        2. Replace it by a normalized disc (``normalize_point_size``).
        3. Gaussian-blur it with a sigma scaled by ``shape / 64``.
        4. Min-max scale it to [0, 1].

        Returns:
            half tensor of shape (1, 1, H, W) on ``self.device``.
        """
        dest_size = self.latent_size
        click = click_pil.convert("L").resize(dest_size, Image.NEAREST)
        click = (np.array(click) > 125).astype(float)
        click = self.normalize_point_size(
            click, radius_for64=Const.POINT_ON_LATENT_RADIUS
        )
        potential = gaussian_filter(
            click, sigma=sigma_for_shape64 * (click.shape[0]) / 64
        )
        # The 1e-8 floor avoids dividing by zero on a perfectly flat map.
        potential = (potential - np.min(potential)) / max(
            np.max(potential) - np.min(potential), 1e-8
        )
        potential = potential[np.newaxis, np.newaxis, ...]
        return torch.from_numpy(potential).half().to(self.device)

    def set_init_masks(self, click_pil, stretch_factor=1.0):
        """Initialize ``base_potential`` from the click and derive step-0 masks.

        The [0, 1] potential from ``calc_potential`` is mapped affinely to
        [-1, Const.POTENTIAL_PEAK] (in float64) and multiplied by
        ``stretch_factor``.
        """
        potential = self.calc_potential(
            click_pil, sigma_for_shape64=Const.SIGMA_FOR_SHAPE64
        )
        self.base_potential = potential.detach().to(torch.float64)
        if self.base_potential.ndim == 2:
            self.base_potential = self.base_potential.unsqueeze(0).unsqueeze(0)
        self.base_potential = self.base_potential * (Const.POTENTIAL_PEAK - (-1)) - 1
        self.base_potential = stretch_factor * self.base_potential
        self.set_cur_masks(step_i=0)

    def set_cur_masks(
        self, step_i, grads_to_update=None, surround_ring=None, return_only=None
    ):
        """Recompute ``self.potential`` and ``self.latent_mask`` for ``step_i``.

        The computation is:

            potential = GaussianBlur(base_potential + get_bias(step_i)
                                     [+ surround_ring * MASK_LR * grads_to_update])

        If the result is entirely <= 0 it is shifted up by
        ``Const.ADDITION_IN_COLLAPSE``. If it is entirely >= 0 it is shifted
        down by the same amount. Either case prints a message.

        Args:
            step_i: step index; controls the decay of the dilation bias.
            grads_to_update: optional update, applied only where ``surround_ring``
                is set.
            surround_ring: ring mask (see ``get_ring``); required with
                ``grads_to_update``.
            return_only: forwarded to ``get_curr_masks``.
        """
        potential = self.base_potential + self.get_bias(step_i)
        if grads_to_update is not None:
            potential = potential + (surround_ring * Const.MASK_LR * grads_to_update)
        # NOTE(review): nesting of this blur was ambiguous in the recovered
        # source; applied unconditionally here.
        potential = transforms.GaussianBlur(
            Const.GAUSS_K_MASK, sigma=Const.GAUSS_SIGMA_MASK
        )(potential)
        if torch.all(potential <= 0):
            potential += Const.ADDITION_IN_COLLAPSE
            print(
                f"{'*' * 10} Mask shrunk entirely, added {Const.ADDITION_IN_COLLAPSE}"
            )
        elif torch.all(potential >= 0):
            potential -= Const.ADDITION_IN_COLLAPSE
            print(
                f"{'*' * 10} Mask expanded entirely, reduced {Const.ADDITION_IN_COLLAPSE}"
            )
        self.potential = potential.half()
        self.latent_mask = self.get_threshed_mask(self.potential)
        return self.get_curr_masks(return_only=return_only)

    def get_curr_masks(self, return_only=None):
        """Return ``(potential, latent_mask)``, or just one of them.

        Raises:
            ValueError: if ``return_only`` is neither None, ``N.POTENTIAL`` nor
                ``N.LATENT_MASK``.
        """
        if return_only is None:
            return self.potential, self.latent_mask
        if return_only == N.POTENTIAL:
            return self.potential
        if return_only == N.LATENT_MASK:
            return self.latent_mask
        raise ValueError(f"return_only should be in ('{N.POTENTIAL}', '{N.LATENT_MASK}')")

    def make_cached_masks_clones(self, name):
        """Store detached clones of the current potential and mask under ``name``."""
        self.cached_masks_clones[name] = {
            N.POTENTIAL: self.potential.detach().clone(),
            N.LATENT_MASK: self.latent_mask.detach().clone(),
        }

    def set_masks_from_cached_masks_clones(self, name):
        """Restore the potential and mask saved under ``name``.

        Clones are restored, not the cached tensors themselves. An in-place
        change to the restored masks therefore cannot corrupt the snapshot for
        later restores.
        """
        cached = self.cached_masks_clones[name]
        self.potential = cached[N.POTENTIAL].clone()
        self.latent_mask = cached[N.LATENT_MASK].clone()

    def evolve_mask(
        self, step_i, decoder, latent_pred_z0, source_latents, return_only=None
    ):
        """Run one guided mask-evolution step and return the new masks.

        The Alpha-CLIP gradient w.r.t. the current latent mask is processed as
        follows:

        * converted to a magnitude;
        * Gaussian-blurred;
        * z-scored (std floored at 1e-6);
        * clipped at 0, so only above-average responses contribute.

        The result is then applied inside the boundary ring via
        ``set_cur_masks``.
        """
        potential, latent_mask = self.get_curr_masks()
        surround_ring = self.get_ring(latent_mask)
        grads_latent = self.calc_grads(
            latent_pred_z0=latent_pred_z0,
            source_latents=source_latents,
            potential=potential,
            step_i=step_i,
            decoder=decoder,
        )
        grads_latent = transforms.GaussianBlur(
            Const.GAUSS_K_GRADS, sigma=Const.GAUSS_SIGMA_GRADS
        )(grads_latent.abs())
        grads_latent = (grads_latent - grads_latent.mean()) / grads_latent.std().clamp_min(
            1e-6
        )
        grads_latent = grads_latent.clamp_min(0.0)
        self.set_cur_masks(
            step_i=step_i, grads_to_update=grads_latent, surround_ring=surround_ring
        )
        return self.get_curr_masks(return_only=return_only)

    def calc_grads(self, latent_pred_z0, source_latents, potential, step_i, decoder):
        """Gradient of the Alpha-CLIP loss w.r.t. the thresholded latent mask.

        The forward pass works as follows:

        1. The mask blends ``latent_pred_z0`` (inside) with ``source_latents``
           (outside).
        2. The blend is decoded.
        3. The result is scored against the prompt. The alpha channel is the
           mask upsampled to the decoded size, binarized and dilated.

        The detached loss is recorded in ``closs_hist[step_i - 1]``, because the
        mask comes from the previous step.

        Returns:
            float64 tensor shaped like the latent mask, detached.
        """
        with torch.enable_grad():
            latent_mask = self.get_threshed_mask(potential).detach().requires_grad_()
            blend_predz0_origz0 = latent_pred_z0 * latent_mask + (
                source_latents * (1 - latent_mask)
            )
            # 0.18215 is presumably the Stable Diffusion VAE latent scale factor.
            decoded_blend_predz0_origz0 = decoder(
                1 / 0.18215 * blend_predz0_origz0
            ).sample.to(torch.float32)
            alpha_mask = transforms.Resize(
                self.decoded_size, interpolation=transforms.InterpolationMode.NEAREST
            )(latent_mask)
            # A comparison result carries no autograd history, so gradients only
            # reach latent_mask through the latent blend above.
            alpha_mask = (alpha_mask > 0.5).half()
            alpha_mask = get_surround(
                alpha_mask,
                Const.ALPHA_MASK_DILATION_ON_512 * (Const.HW / 512.0),
                self.device,
            )
            alpha_loss = self.alpha_clip_loss(
                decoded_blend_predz0_origz0,
                alpha_mask,
                self.text_features,
                self.image_augmentations,
                augs_with_orig=True,
            )
            # The mask used for the loss is the previous step's mask.
            self.closs_hist[step_i - 1] = alpha_loss.detach()
            grads_latent = torch.autograd.grad(alpha_loss, latent_mask)[0].to(
                torch.float64
            )
        return grads_latent.detach()

    def alpha_clip_loss(
        self,
        image,
        mask,
        text_features,
        image_augmentations,
        augs_with_orig=True,
        return_as_similarity=False,
    ):
        """Alpha-CLIP loss (or similarity) between a masked image and text.

        Args:
            image: RGB image tensor in [-1, 1], with or without a batch dim. It
                is mapped to [0, 1] via ``(x + 1) / 2`` before normalization.
            mask: alpha mask in [0, 1], with or without a batch dim.
            text_features: text embeddings, one row per prompt. ``get_text_features``
                returns them L2-normalized.
            image_augmentations: callable ``(image, alpha, with_orig=...)`` that
                returns augmented ``(image, alpha)`` batches, or None. With None,
                the image is bicubically resized and the alpha average-pooled to
                ``self.ac_size``. No resize is applied on the augmentation path,
                so the augmentations presumably output that size already.
            augs_with_orig: forwarded as ``with_orig`` to the augmentations.
            return_as_similarity: return the cosine similarity instead of
                ``1 - similarity``.

        Returns:
            Tensor with one value per prompt, averaged over the batch
            (augmentation) dimension.
        """
        assert mask.min() >= 0 and mask.max() <= 1
        # CLIP's RGB normalization statistics; the alpha channel uses its own.
        image_normalize = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        )
        mask_normalize = transforms.Normalize(0.5, 0.26)

        image = image.add(1).div(2)
        if image.ndim == 3:
            image = image.unsqueeze(0)
        alpha = mask.unsqueeze(dim=0) if mask.ndim == 3 else mask

        if image_augmentations is not None:
            image, alpha = image_augmentations(image, alpha, with_orig=augs_with_orig)
        else:
            image = transforms.Resize(
                self.ac_size, interpolation=transforms.InterpolationMode.BICUBIC
            )(image)
            alpha = nn.AdaptiveAvgPool2d(self.ac_size)(alpha)
        image = image_normalize(image).half()
        alpha = mask_normalize(alpha).half()

        image_features = self.ac_model.visual(image, alpha)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        similarity = image_features @ text_features.T
        alpha_loss = similarity if return_as_similarity else 1 - similarity
        return alpha_loss.mean(dim=0)

    def get_text_features(self, prompt):
        """Encode a list/tuple of prompts into L2-normalized Alpha-CLIP features.

        Encoding runs under ``torch.no_grad()``. The features are only a fixed
        target for ``alpha_clip_loss``, so there is no reason to keep the text
        encoder's autograd graph (and its activations) alive via
        ``self.text_features``.
        """
        assert isinstance(prompt, (list, tuple))
        tokens = alpha_clip.tokenize(prompt).to(self.device)
        with torch.no_grad():
            text_features = self.ac_model.encode_text(tokens)
        return text_features / text_features.norm(dim=-1, keepdim=True)

    def get_bias(self, step_i):
        """Return the dilation bias added to ``base_potential`` at ``step_i``.

        The bias starts at ``BIAS_DILATION_VAL * BIAS_DILATION_DEC_FACTOR ** step_i``.
        It is then multiplied by 0.9 for as long as ``base_potential + bias``
        would be strictly positive everywhere and the bias is still above 1e-8.
        """
        bias = Const.BIAS_DILATION_VAL * (Const.BIAS_DILATION_DEC_FACTOR**step_i)
        while torch.all(self.base_potential + bias > 0) and bias > 1e-8:
            bias *= 0.9
        return bias

    def get_threshed_mask(self, potential):
        """Turn a potential into a clean binary latent mask.

        The steps are:

        1. threshold at ``self.thresh_val``;
        2. fill holes;
        3. bridge or drop secondary components (``close_gaps_with_connection``);
        4. morphological closing;
        5. fill holes again;
        6. Gaussian blur and re-threshold at ``Const.THRESH_POST_GAUSS``.

        Returns:
            half tensor of shape (1, 1, H, W) with values in {0, 1}.
        """
        t_m = self._fill_holes((potential > self.thresh_val).half())
        t_m = self.close_gaps_with_connection(
            t_m, thickness=Const.CLOSE_GAPS_WITH_CONNECTION_THICKNESS
        )
        t_m = kornia.morphology.closing(
            t_m, torch.ones(Const.CLOSING_K, Const.CLOSING_K).to(self.device)
        )
        t_m = self._fill_holes(t_m)
        t_m = transforms.GaussianBlur(
            Const.GAUSS_K_THRESHED, sigma=Const.GAUSS_SIGMA_THRESHED
        )(t_m)
        return (t_m > Const.THRESH_POST_GAUSS).half()

    def _fill_holes(self, mask):
        """Fill holes of a single-image mask; returns a (1, 1, H, W) half tensor."""
        filled = scipy.ndimage.binary_fill_holes(
            mask.cpu().numpy().squeeze().astype(np.uint8)
        )
        return torch.tensor(filled).to(self.device).unsqueeze(0).unsqueeze(0).half()

    def close_gaps_with_connection(self, threshed_mask, thickness):
        """Bridge secondary mask components to the largest one.

        This also drops small components. If exactly one contour is found, the
        input tensor is returned unchanged. Otherwise:

        * contours with area <= 0.1% of the grid are discarded;
        * the largest remaining contour is filled;
        * every other remaining contour is filled and then joined to the largest
          one. The joint is a polygon made of the contour's simplified convex
          hull plus, for each hull vertex, the nearest point(s) on the largest
          contour. The polygon is drawn both with ``thickness`` and filled.

        Returns:
            half tensor of shape (1, 1, H, W) on ``self.device`` with values in
            {0, 1}.
        """
        mask_np = threshed_mask.cpu().numpy().squeeze().astype(np.uint8)
        contours, _ = cv2.findContours(mask_np, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
        if len(contours) == 1:
            return threshed_mask
        min_area = mask_np.shape[-1] * mask_np.shape[-2] * 0.001
        contours = [
            cnt
            for cnt in sorted(contours, key=cv2.contourArea, reverse=True)
            if cv2.contourArea(cnt) > min_area
        ]
        connected_mask = np.zeros_like(mask_np)
        if contours:
            cv2.drawContours(connected_mask, contours, 0, 255, -1)
        for idx in range(1, len(contours)):
            main_contour = contours[0]
            cv2.drawContours(connected_mask, contours, idx, 255, -1)
            hull = cv2.convexHull(contours[idx])
            hull = cv2.approxPolyDP(hull, 0.1 * cv2.arcLength(hull, True), True)
            bridge_parts = [hull]
            for vertex in hull:
                dists = np.linalg.norm(main_contour - vertex, axis=2).squeeze(axis=1)
                # Keep every point tied for the minimum, in contour order.
                bridge_parts.append(main_contour[dists == dists.min()])
            connect = np.concatenate(bridge_parts, axis=0)
            cv2.drawContours(connected_mask, [connect], -1, color=255, thickness=thickness)
            cv2.drawContours(connected_mask, [connect], -1, color=255, thickness=-1)
        return (
            (torch.tensor(connected_mask).to(self.device) > 125)
            .unsqueeze(0)
            .unsqueeze(0)
            .half()
        )

    def get_plain_dilated_latent_mask(
        self,
        last_step_latent_mask,
        step_i,
        total_steps,
        max_area_ratio_for_dilation=None,
        rerun_dyn_start_step_i=None,
    ):
        """Dilate the latent mask by a width that decays over the steps.

        A mask already covering more than ``max_area_ratio_for_dilation`` of the
        grid (default ``Const.MAX_AREA_RATIO_FOR_DILATION``) is returned as is.

        Otherwise, the starting width ``first_k`` is the largest value in
        ``[0, latent_size[-1] // 2]`` whose dilation covers at most
        ``0.75 * latent_size[-1] ** 2`` pixels. It never goes below 0, since a
        negative width would erode instead of dilate.

        Per-step widths then follow one of two schedules:

        * With ``rerun_dyn_start_step_i``: linear from ``first_k`` to 0 over
          ``rerun_dyn_start_step_i + 2 - Const.RERUN_STOP_DILATION`` steps, then 0.
        * Otherwise: ``round(first_k / max(1, i / 3))``.

        The last 10 steps get width 0.

        Returns:
            half tensor: the mask dilated by the width scheduled for ``step_i``.
        """
        if max_area_ratio_for_dilation is None:
            max_area_ratio_for_dilation = Const.MAX_AREA_RATIO_FOR_DILATION
        if (
            last_step_latent_mask.sum()
            > max_area_ratio_for_dilation * last_step_latent_mask.nelement()
        ):
            return last_step_latent_mask

        # The distance field depends only on the mask; compute it once instead
        # of re-running the FMM solve for every candidate width.
        dists = get_dist_field(last_step_latent_mask, self.device)
        mask_dtype = last_step_latent_mask.dtype
        max_surround_area = 0.75 * self.latent_size[-1] ** 2
        first_k = self.latent_size[-1] // 2
        while (
            first_k > 0
            and (dists <= first_k).to(mask_dtype).sum() > max_surround_area
        ):
            first_k -= 1

        if rerun_dyn_start_step_i:
            plain_dilation_ws = np.linspace(
                first_k, 0, rerun_dyn_start_step_i + 2 - Const.RERUN_STOP_DILATION
            ).round()
            plain_dilation_ws = np.pad(
                plain_dilation_ws, (0, total_steps - len(plain_dilation_ws))
            )
        else:
            plain_dilation_ws = np.array(
                [first_k / max(1, (i / 3)) for i in range(0, total_steps)]
            ).round()
        # NOTE(review): nesting ambiguous in the recovered source; applied to
        # both schedules here.
        plain_dilation_ws[-10:] = 0
        return (dists <= plain_dilation_ws[step_i]).to(mask_dtype).half()

    def get_ring(self, latent_mask):
        """Return a uint8 band (values 0/1) around the boundary of a mask.

        Let ``d`` be the signed distance to the boundary of ``latent_mask >= 0.5``
        (negative inside). The band is the union of three parts:

        * the inner ring ``-Const.IN_ON_RING_WIDTH < d <= -1``;
        * mask pixels with ``d >= -1`` (the boundary layer itself);
        * the outer ring ``0 < d <= Const.OUT_RING_WIDTH``.

        Returns:
            uint8 tensor on ``self.device``, shaped like ``latent_mask``.
        """
        assert (latent_mask.min() >= 0) and (latent_mask.max() <= 1)
        binary_mask = (latent_mask.cpu().numpy() >= 0.5).astype(np.float16)
        dists = get_dist_field(binary_mask, self.device, as_squeezed_np=True)
        in_ring_width = Const.IN_ON_RING_WIDTH - 1
        in_ring = (dists <= -1) & (dists > -in_ring_width - 1)
        on_ring = (binary_mask != 0) & (dists >= -1)
        out_ring = (dists > 0) & (dists <= Const.OUT_RING_WIDTH)
        surround_ring = (in_ring | on_ring | out_ring).astype(np.uint8)
        return torch.tensor(surround_ring).to(self.device)