import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision.transforms import ToPILImage, ToTensor

# Shared converter instances, reused by the helpers below.
totensor = ToTensor()
topil = ToPILImage()


def resize_and_crop(img, size, crop_type="center"):
    """Resize and crop the image to the given size.

    Args:
        img: PIL image to fit.
        size: ``(width, height)`` pair. A ``None`` entry means "keep the
            corresponding dimension of ``img``".
        crop_type: ``"top"`` anchors the crop at centering ``(0, 0)``;
            ``"center"`` anchors it at ``(0.5, 0.5)``.

    Returns:
        The image produced by ``ImageOps.fit`` at the requested size.

    Raises:
        ValueError: If ``crop_type`` is neither ``"top"`` nor ``"center"``.
    """
    if crop_type == "top":
        centering = (0, 0)
    elif crop_type == "center":
        centering = (0.5, 0.5)
    else:
        raise ValueError(
            f"Unsupported crop_type {crop_type!r}; expected 'top' or 'center'."
        )

    # Fall back to the image's own dimension wherever the caller passed None.
    width = img.size[0] if size[0] is None else size[0]
    height = img.size[1] if size[1] is None else size[1]
    return ImageOps.fit(img, (width, height), centering=centering)


def recover_image(image, init_image, mask, background=False):
    """Blend ``image`` and ``init_image`` according to ``mask``.

    All three inputs are converted with the module-level ``totensor``; only
    the first channel of the converted mask is used as the blend weight.

    Args:
        image: Image whose content is used where the mask weight is high
            (or low, when ``background`` is true).
        init_image: The other image of the blend.
        mask: Mask image supplying the per-pixel blend weight.
        background: If false (default), the result is
            ``mask * image + (1 - mask) * init_image``. If true, the roles
            of ``image`` and ``init_image`` are swapped.

    Returns:
        The blended result converted back with the module-level ``topil``.
    """
    image_t = totensor(image)
    weight = totensor(mask)[0]  # first channel only
    init_t = totensor(init_image)

    if background:
        blended = weight * init_t + (1 - weight) * image_t
    else:
        blended = weight * image_t + (1 - weight) * init_t
    return topil(blended)


def preprocess(image, multiple=32):
    """Convert a PIL image to a batched tensor scaled to ``[-1, 1]``.

    The image is first resized (LANCZOS) so that both sides are rounded down
    to an integer multiple of ``multiple``, then divided by 255, given a
    leading batch axis, and transposed from ``(1, H, W, C)`` to
    ``(1, C, H, W)``.

    Args:
        image: PIL image. NOTE(review): the transpose requires the array to
            have a channel axis, so single-channel modes (e.g. ``"L"``) are
            presumably expected to be converted by the caller - confirm.
        multiple: Each side is rounded down to a multiple of this value.
            Defaults to 32, the original hard-coded value.

    Returns:
        A float32 ``torch.Tensor`` equal to ``2 * (pixels / 255) - 1``.

    Raises:
        ValueError: If ``multiple`` is smaller than 1, or if either side of
            ``image`` is smaller than ``multiple`` (rounding down would give
            an empty image).
    """
    if multiple < 1:
        raise ValueError(f"multiple must be >= 1, got {multiple}.")

    width, height = image.size
    # resize to integer multiple of `multiple`
    width -= width % multiple
    height -= height % multiple
    if width == 0 or height == 0:
        raise ValueError(
            f"Image of size {image.size} is smaller than the required "
            f"multiple of {multiple} in at least one dimension."
        )

    resized = image.resize((width, height), resample=Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    pixels = pixels[None].transpose(0, 3, 1, 2)  # (1, H, W, C) -> (1, C, H, W)
    return 2.0 * torch.from_numpy(pixels) - 1.0


def prepare_mask_and_masked_image(image, mask):
    """Build a binary mask tensor and the image with the masked area zeroed.

    Args:
        image: PIL image; converted to RGB, batched to ``(1, 3, H, W)`` and
            scaled with ``value / 127.5 - 1.0``.
        mask: PIL image; converted to mode ``"L"``, scaled by ``1 / 255`` and
            thresholded at 0.5 into ``{0, 1}`` with shape ``(1, 1, H, W)``.

    Returns:
        A ``(mask, masked_image)`` tuple of tensors, where ``masked_image``
        is the scaled image multiplied by ``mask < 0.5`` (i.e. pixels where
        the mask is 1 are set to 0).
    """
    rgb = np.array(image.convert("RGB"))
    rgb = rgb[None].transpose(0, 3, 1, 2)  # (1, H, W, 3) -> (1, 3, H, W)
    image_t = torch.from_numpy(rgb).to(dtype=torch.float32) / 127.5 - 1.0

    mask_arr = np.array(mask.convert("L")).astype(np.float32) / 255.0
    mask_arr = mask_arr[None, None]  # add batch and channel axes
    # Binarise in this order: values set to 0 stay below 0.5 on the 2nd pass.
    mask_arr[mask_arr < 0.5] = 0
    mask_arr[mask_arr >= 0.5] = 1
    mask_t = torch.from_numpy(mask_arr)

    masked_image = image_t * (mask_t < 0.5)
    return mask_t, masked_image