photoguard / utils.py
hadisalman's picture
Add demo
50b15cd
raw
history blame contribute delete
No virus
1.79 kB
import torch
import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision.transforms import ToPILImage, ToTensor
# Shared torchvision converter instances, created once at import time.
# `totensor` is applied to the inputs of `recover_image` below.
totensor = ToTensor()
# `topil` is applied to the blended result that `recover_image` returns.
topil = ToPILImage()
def resize_and_crop(img, size, crop_type="center"):
    """Resize and crop ``img`` to ``size`` using ``ImageOps.fit``.

    Args:
        img: Image to fit. It must expose ``size`` as ``(width, height)``
            and be accepted by ``ImageOps.fit``.
        size: Two-item ``(width, height)`` target. A ``None`` entry means
            "keep the corresponding dimension of ``img``".
        crop_type: ``"center"`` crops around the middle of the image
            (centering ``(0.5, 0.5)``); ``"top"`` uses centering ``(0, 0)``.

    Returns:
        The image produced by ``ImageOps.fit``.

    Raises:
        ValueError: If ``crop_type`` is neither ``"top"`` nor ``"center"``.
    """
    if crop_type == "top":
        centering = (0, 0)
    elif crop_type == "center":
        centering = (0.5, 0.5)
    else:
        # Name the rejected value so the caller can see what went wrong.
        raise ValueError(
            f"Unsupported crop_type {crop_type!r}; expected 'top' or 'center'."
        )

    # Fill any unspecified target dimension from the source image.
    target = list(size)
    for axis in (0, 1):
        if target[axis] is None:
            target[axis] = img.size[axis]

    return ImageOps.fit(img, tuple(target), centering=centering)
def recover_image(image, init_image, mask, background=False):
    """Blend ``image`` and ``init_image`` using ``mask`` as a per-pixel weight.

    All three inputs are converted with the module-level ``totensor``; only
    channel 0 of the converted mask is used as the weight.

    Args:
        image: First image to blend.
        init_image: Second image to blend.
        mask: Mask image supplying the blend weight.
        background: When false (default) the result is
            ``mask * image + (1 - mask) * init_image``. When true the two
            images swap roles.

    Returns:
        The blended tensor converted back with the module-level ``topil``.
    """
    image_t = totensor(image)
    weight = totensor(mask)[0]
    init_t = totensor(init_image)

    # Pick which image is weighted by the mask and which by its complement.
    if background:
        weighted, remainder = init_t, image_t
    else:
        weighted, remainder = image_t, init_t

    return topil(weight * weighted + (1 - weight) * remainder)
def preprocess(image):
    """Convert an image into a batched float tensor.

    The image is resized with Lanczos resampling so that each side is rounded
    down to a multiple of 32, divided by 255, given a leading batch axis,
    reordered from ``(1, H, W, C)`` to ``(1, C, H, W)``, and finally mapped
    through ``2 * x - 1``.

    Args:
        image: Image exposing ``size`` as ``(width, height)`` and a
            ``resize`` method. The axis reorder requires the converted array
            to be 3-D, i.e. the image needs a channel axis.

    Returns:
        A ``torch`` tensor of dtype float32. For 8-bit pixel data the values
        lie in ``[-1, 1]``.
    """
    width, height = image.size
    # Round each side down to the nearest multiple of 32.
    width -= width % 32
    height -= height % 32

    resized = image.resize((width, height), resample=Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    batch = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(batch)
    return 2.0 * tensor - 1.0
def prepare_mask_and_masked_image(image, mask):
    """Build a binary mask tensor and the image with masked pixels zeroed.

    Args:
        image: Object whose ``convert("RGB")`` result becomes an ``(H, W, 3)``
            array; it is rescaled with ``x / 127.5 - 1.0``.
        mask: Object whose ``convert("L")`` result becomes an ``(H, W)``
            array; it is divided by 255 and thresholded at 0.5.

    Returns:
        A ``(mask, masked_image)`` tuple of float32 ``torch`` tensors.
        ``mask`` has shape ``(1, 1, H, W)`` and holds only 0.0 and 1.0.
        ``masked_image`` has shape ``(1, 3, H, W)`` and equals the rescaled
        image multiplied by ``mask < 0.5``, so it is zero wherever the mask
        is 1.
    """
    rgb = np.array(image.convert("RGB"))
    channels_first = rgb[None].transpose(0, 3, 1, 2)
    image_t = torch.from_numpy(channels_first).to(dtype=torch.float32) / 127.5 - 1.0

    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    # Binarise in one step: values >= 0.5 become 1.0, everything else 0.0.
    binary = (gray[None, None] >= 0.5).astype(np.float32)
    mask_t = torch.from_numpy(binary)

    masked_image = image_t * (mask_t < 0.5)
    return mask_t, masked_image