import numpy as np
import cv2
from basicsr.utils import img2tensor
import torch
import torch.nn.functional as F


def resize_numpy_image(image, max_resolution=768 * 768, resize_short_edge=None):
    """Resize *image* so that both sides become multiples of 64.

    Args:
        image: Array whose first two axes are (height, width).
        max_resolution: Target pixel count (height * width) used when
            ``resize_short_edge`` is None.
        resize_short_edge: If given, the scale factor is chosen so that the
            shorter side approaches this length (before rounding to 64).

    Returns:
        ``(resized_image, scale)`` where ``scale`` is the new width divided by
        the original width.
    """
    h, w = image.shape[:2]
    w_org = image.shape[1]
    if resize_short_edge is not None:
        k = resize_short_edge / min(h, w)
    else:
        # max_resolution is an area, so the per-side factor is its square root.
        k = max_resolution / (h * w)
        k = k ** 0.5
    # Snap both sides to the nearest multiple of 64.
    h = int(np.round(h * k / 64)) * 64
    w = int(np.round(w * k / 64)) * 64
    # cv2.resize takes dsize as (width, height).
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
    scale = w / w_org
    return image, scale


def split_ldm(ldm):
    """Split a sequence of points into separate lists of x and y values.

    Args:
        ldm: Iterable of indexable points; ``p[0]`` is taken as x and
            ``p[1]`` as y.

    Returns:
        ``(x, y)`` as two lists.
    """
    x = []
    y = []
    for p in ldm:
        x.append(p[0])
        y.append(p[1])
    return x, y


def _load_mask(mask, h, w):
    """Read *mask* from disk when it is a path, then resize it with cv2.

    ``(h, w)`` is forwarded unchanged as ``dsize``; OpenCV interprets that
    tuple as (width, height).
    NOTE(review): confirm callers pass ``h``/``w`` in the order this implies.
    """
    if isinstance(mask, str):
        # cv2.imread returns None (it does not raise) for an unreadable file.
        mask = cv2.imread(mask)
    return cv2.resize(mask, (h, w))


def process_move(path_mask, h, w, dx, dy, scale, input_scale, resize_scale,
                 up_scale, up_ft_index, w_edit, w_content, w_contrast,
                 w_inpaint, precision, path_mask_ref=None):
    """Build the masks used for an object-move edit.

    Args:
        path_mask: Mask image path or already-loaded array.
        h, w: Size passed to ``cv2.resize`` as ``(h, w)``.
        dx, dy: Displacement; multiplied by ``input_scale`` before use.
        scale: Divisor taking mask resolution to the working resolution.
        input_scale: Factor applied to ``dx`` / ``dy``.
        resize_scale: Object resize factor applied to the target mask.
        up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint:
            Passed through unchanged into the returned dict.
        precision: Torch dtype for the CUDA mask tensors.
        path_mask_ref: Optional reference mask (path or array).

    Returns:
        Dict with the binarised masks and the pass-through settings.

    Raises:
        ValueError: If ``resize_scale > 1`` and cropping the enlarged,
            shifted mask back to the original size drops mask pixels.
    """
    dx, dy = dx * input_scale, dy * input_scale
    mask_x0 = _load_mask(path_mask, h, w)
    if path_mask_ref is not None:
        mask_x0_ref = _load_mask(path_mask_ref, h, w)
    else:
        mask_x0_ref = None

    # img2tensor(...)[0]: presumably the first channel - TODO confirm.
    mask_x0 = img2tensor(mask_x0)[0]
    mask_x0 = (mask_x0 > 0.5).float().to('cuda', dtype=precision)
    if mask_x0_ref is not None:
        mask_x0_ref = img2tensor(mask_x0_ref)[0]
        mask_x0_ref = (mask_x0_ref > 0.5).float().to('cuda', dtype=precision)

    # Original object mask at working resolution.
    mask_org = F.interpolate(
        mask_x0[None, None],
        (int(mask_x0.shape[-2] // scale), int(mask_x0.shape[-1] // scale)),
    ) > 0.5
    # Resized object mask, then shifted by (dy, dx) along (rows, cols).
    mask_tar = F.interpolate(
        mask_x0[None, None],
        (int(mask_x0.shape[-2] // scale * resize_scale),
         int(mask_x0.shape[-1] // scale * resize_scale)),
    ) > 0.5
    mask_cur = torch.roll(
        mask_tar,
        (int(dy // scale * resize_scale), int(dx // scale * resize_scale)),
        (-2, -1),
    )

    pad_size_x = abs(mask_tar.shape[-1] - mask_org.shape[-1]) // 2
    pad_size_y = abs(mask_tar.shape[-2] - mask_org.shape[-2]) // 2
    if resize_scale > 1:
        # Centre-crop back to the original size; losing pixels means the
        # enlarged object no longer fits.
        sum_before = torch.sum(mask_cur)
        mask_cur = mask_cur[:, :,
                            pad_size_y:pad_size_y + mask_org.shape[-2],
                            pad_size_x:pad_size_x + mask_org.shape[-1]]
        sum_after = torch.sum(mask_cur)
        if sum_after != sum_before:
            raise ValueError('Resize out of bounds, exiting.')
    else:
        # Centre-pad the (same-size or smaller) mask into a full-size canvas.
        temp = torch.zeros(1, 1, mask_org.shape[-2], mask_org.shape[-1]).to(mask_org.device)
        temp[:, :,
             pad_size_y:pad_size_y + mask_cur.shape[-2],
             pad_size_x:pad_size_x + mask_cur.shape[-1]] = mask_cur
        mask_cur = temp > 0.5

    # Pixels in neither mask, in both masks, and only in the original mask.
    mask_other = (1 - ((mask_cur + mask_org) > 0.5).float()) > 0.5
    mask_overlap = ((mask_cur.float() + mask_org.float()) > 1.5).float()
    mask_non_overlap = (mask_org.float() - mask_overlap) > 0.5

    return {
        "mask_x0": mask_x0,
        "mask_x0_ref": mask_x0_ref,
        "mask_tar": mask_tar,
        "mask_cur": mask_cur,
        "mask_other": mask_other,
        "mask_overlap": mask_overlap,
        "mask_non_overlap": mask_non_overlap,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "resize_scale": resize_scale,
        "w_edit": w_edit,
        "w_content": w_content,
        "w_contrast": w_contrast,
        "w_inpaint": w_inpaint,
    }


def process_drag_face(h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale,
                      up_ft_index, w_edit, w_inpaint, precision):
    """Build 3x3 point masks for landmark-based face dragging.

    Note: ``x``, ``y``, ``x_cur`` and ``y_cur`` are rescaled by
    ``input_scale`` **in place**, so the caller's lists are modified.
    ``precision`` is accepted but not used by this function.

    Args:
        h, w: Mask size before division by ``scale``; masks are created with
            ``h // scale`` rows and ``w // scale`` columns.
        x, y: Mutable sequences of point coordinates for ``mask_tar``.
        x_cur, y_cur: Mutable sequences of point coordinates for ``mask_cur``.
        scale: Divisor taking input coordinates to mask resolution.
        input_scale: Factor applied to every coordinate.
        up_scale, up_ft_index, w_edit, w_inpaint: Passed through unchanged.

    Returns:
        Dict with ``mask_tar`` / ``mask_cur`` (lists of boolean CUDA tensors)
        and the pass-through settings.
    """
    for i in range(len(x)):
        x[i] = int(x[i] * input_scale)
        y[i] = int(y[i] * input_scale)
        x_cur[i] = int(x_cur[i] * input_scale)
        y_cur[i] = int(y_cur[i] * input_scale)

    mask_tar = []
    for p_idx in range(len(x)):
        mask_i = torch.zeros(int(h // scale), int(w // scale)).cuda()
        # Clamp to [1, size-2] so the 3x3 patch stays inside the mask.
        y_clip = int(np.clip(y[p_idx] // scale, 1, mask_i.shape[0] - 2))
        x_clip = int(np.clip(x[p_idx] // scale, 1, mask_i.shape[1] - 2))
        mask_i[y_clip - 1:y_clip + 2, x_clip - 1:x_clip + 2] = 1
        mask_i = mask_i > 0.5
        mask_tar.append(mask_i)

    mask_cur = []
    for p_idx in range(len(x_cur)):
        mask_i = torch.zeros(int(h // scale), int(w // scale)).cuda()
        y_clip = int(np.clip(y_cur[p_idx] // scale, 1, mask_i.shape[0] - 2))
        x_clip = int(np.clip(x_cur[p_idx] // scale, 1, mask_i.shape[1] - 2))
        mask_i[y_clip - 1:y_clip + 2, x_clip - 1:x_clip + 2] = 1
        mask_i = mask_i > 0.5
        mask_cur.append(mask_i)

    return {
        "mask_tar": mask_tar,
        "mask_cur": mask_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_inpaint": w_inpaint,
    }


def process_drag(path_mask, h, w, x, y, x_cur, y_cur, scale, input_scale,
                 up_scale, up_ft_index, w_edit, w_inpaint, w_content,
                 precision, latent_in):
    """Build masks for point-based dragging and copy latent patches.

    For every point pair, a 3x3 boolean mask is created around the ``(x, y)``
    point and around the ``(x_cur, y_cur)`` point, and the 3x3 patch of
    ``latent_in`` at the ``(x, y)`` location is written over the patch at the
    ``(x_cur, y_cur)`` location. ``latent_in`` is modified **in place** and
    also returned. ``input_scale`` is accepted but not used here.

    Args:
        path_mask: Mask image path or already-loaded array.
        h, w: Size passed to ``cv2.resize`` as ``(h, w)``.
        x, y: Point coordinates used for ``mask_tar``.
        x_cur, y_cur: Point coordinates used for ``mask_cur``.
        scale: Divisor taking mask resolution to the working resolution.
        up_scale: Divisor from working resolution to latent indices; also
            passed through.
        up_ft_index, w_edit, w_inpaint, w_content: Passed through unchanged.
        precision: Torch dtype for the CUDA mask tensors.
        latent_in: 4-D tensor indexed as ``[:, :, row, col]``.

    Returns:
        Dict with the masks, the edited ``latent_in`` and the settings.
    """
    mask_x0 = _load_mask(path_mask, h, w)
    # img2tensor(...)[0]: presumably the first channel - TODO confirm.
    mask_x0 = img2tensor(mask_x0)[0]
    dict_mask = {}
    dict_mask['base'] = mask_x0
    mask_x0 = (mask_x0 > 0.5).float().to('cuda', dtype=precision)

    rows = int(mask_x0.shape[-2] // scale)
    cols = int(mask_x0.shape[-1] // scale)
    # Everything outside the user mask at working resolution.
    mask_other = F.interpolate(mask_x0[None, None], (rows, cols)) < 0.5

    mask_tar = []
    mask_cur = []
    for p_idx in range(len(x)):
        mask_tar_i = torch.zeros(rows, cols).to('cuda', dtype=precision)
        mask_cur_i = torch.zeros(rows, cols).to('cuda', dtype=precision)
        # Clamp to [1, size-2] so the 3x3 patch stays inside the mask.
        # x is bounded by the column count (shape[1]), y by the row count
        # (shape[0]); using shape[0] for x is wrong on non-square masks.
        y_tar_clip = int(np.clip(y[p_idx] // scale, 1, mask_tar_i.shape[0] - 2))
        x_tar_clip = int(np.clip(x[p_idx] // scale, 1, mask_tar_i.shape[1] - 2))
        y_cur_clip = int(np.clip(y_cur[p_idx] // scale, 1, mask_cur_i.shape[0] - 2))
        x_cur_clip = int(np.clip(x_cur[p_idx] // scale, 1, mask_cur_i.shape[1] - 2))
        mask_tar_i[y_tar_clip - 1:y_tar_clip + 2, x_tar_clip - 1:x_tar_clip + 2] = 1
        mask_cur_i[y_cur_clip - 1:y_cur_clip + 2, x_cur_clip - 1:x_cur_clip + 2] = 1
        mask_tar_i = mask_tar_i > 0.5
        mask_cur_i = mask_cur_i > 0.5
        mask_tar.append(mask_tar_i)
        mask_cur.append(mask_cur_i)

        # Copy the latent patch from the (x, y) point to the (x_cur, y_cur) point.
        yc, xc = y_cur_clip // up_scale, x_cur_clip // up_scale
        yt, xt = y_tar_clip // up_scale, x_tar_clip // up_scale
        latent_in[:, :, yc - 1:yc + 2, xc - 1:xc + 2] = \
            latent_in[:, :, yt - 1:yt + 2, xt - 1:xt + 2]

    return {
        "dict_mask": dict_mask,
        "mask_x0": mask_x0,
        "mask_tar": mask_tar,
        "mask_cur": mask_cur,
        "mask_other": mask_other,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_inpaint": w_inpaint,
        "w_content": w_content,
        "latent_in": latent_in,
    }


def process_appearance(path_mask, path_mask_replace, h, w, scale, input_scale,
                       up_scale, up_ft_index, w_edit, w_content, precision):
    """Build base / replacement masks for an appearance edit.

    ``input_scale`` is accepted but not used by this function.

    Args:
        path_mask: Base mask (path or array).
        path_mask_replace: Replacement mask (path or array).
        h, w: Size passed to ``cv2.resize`` as ``(h, w)``.
        scale: Divisor taking mask resolution to the working resolution.
        up_scale, up_ft_index, w_edit, w_content: Passed through unchanged.
        precision: Torch dtype for the CUDA mask tensors.

    Returns:
        Dict with ``dict_mask`` (un-thresholded ``img2tensor`` outputs), the
        two boolean working-resolution masks and the pass-through settings.
    """
    mask_base = _load_mask(path_mask, h, w)
    mask_replace = _load_mask(path_mask_replace, h, w)

    dict_mask = {}
    mask_base = img2tensor(mask_base)[0]
    dict_mask['base'] = mask_base
    mask_base = (mask_base > 0.5).to('cuda', dtype=precision)
    mask_replace = img2tensor(mask_replace)[0]
    dict_mask['replace'] = mask_replace
    mask_replace = (mask_replace > 0.5).to('cuda', dtype=precision)

    mask_base_cur = F.interpolate(
        mask_base[None, None],
        (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)),
    ) > 0.5
    mask_replace_cur = F.interpolate(
        mask_replace[None, None],
        (int(mask_replace.shape[-2] // scale), int(mask_replace.shape[-1] // scale)),
    ) > 0.5

    return {
        "dict_mask": dict_mask,
        "mask_base_cur": mask_base_cur,
        "mask_replace_cur": mask_replace_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_content": w_content,
    }


def process_paste(path_mask, h, w, dx, dy, scale, input_scale, up_scale,
                  up_ft_index, w_edit, w_content, precision, resize_scale=None):
    """Build masks for pasting an object at a shifted (optionally resized) spot.

    Args:
        path_mask: Object mask (path or array).
        h, w: Size passed to ``cv2.resize`` as ``(h, w)``.
        dx, dy: Displacement; multiplied by ``input_scale`` before use.
        scale: Divisor taking mask resolution to the working resolution.
        input_scale: Factor applied to ``dx`` / ``dy``.
        up_scale, up_ft_index, w_edit, w_content: Passed through unchanged.
        precision: Torch dtype for the CUDA mask tensor.
        resize_scale: Optional resize factor; ``None`` or ``1`` skips resizing.

    Returns:
        Dict with ``dict_mask`` (full-resolution shifted and unshifted masks),
        the working-resolution boolean masks and the pass-through settings.
    """
    dx, dy = dx * input_scale, dy * input_scale
    mask_base = _load_mask(path_mask, h, w)

    dict_mask = {}
    mask_base = img2tensor(mask_base)[0][None, None]
    mask_base = (mask_base > 0.5).to('cuda', dtype=precision)
    if resize_scale is not None and resize_scale != 1:
        hi, wi = mask_base.shape[-2], mask_base.shape[-1]
        mask_base = F.interpolate(mask_base, (int(hi * resize_scale), int(wi * resize_scale)))
        pad_size_x = np.abs(mask_base.shape[-1] - wi) // 2
        pad_size_y = np.abs(mask_base.shape[-2] - hi) // 2
        if resize_scale > 1:
            # Enlarged: centre-crop back to the original size.
            mask_base = mask_base[:, :, pad_size_y:pad_size_y + hi, pad_size_x:pad_size_x + wi]
        else:
            # Shrunk: centre-pad into a zero canvas of the original size.
            temp = torch.zeros(1, 1, hi, wi).to(mask_base.device)
            temp[:, :,
                 pad_size_y:pad_size_y + mask_base.shape[-2],
                 pad_size_x:pad_size_x + mask_base.shape[-1]] = mask_base
            mask_base = temp

    # Keep the unshifted mask, then shift the base mask by (dy, dx).
    mask_replace = mask_base.clone()
    mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2, -1))
    dict_mask['base'] = mask_base[0, 0]
    dict_mask['replace'] = mask_replace[0, 0]

    mask_base_cur = F.interpolate(
        mask_base,
        (int(mask_base.shape[-2] // scale), int(mask_base.shape[-1] // scale)),
    ) > 0.5
    # Undo the shift at working resolution to get the source-location mask.
    mask_replace_cur = torch.roll(
        mask_base_cur, (-int(dy / scale), -int(dx / scale)), (-2, -1)
    )

    return {
        "dict_mask": dict_mask,
        "mask_base_cur": mask_base_cur,
        "mask_replace_cur": mask_replace_cur,
        "up_scale": up_scale,
        "up_ft_index": up_ft_index,
        "w_edit": w_edit,
        "w_content": w_content,
    }