# Spaces: Running on Zero
import numpy as np | |
from PIL import Image | |
import torch | |
import random | |
from torchvision import transforms | |
import torchvision.transforms.functional as TF | |
def apply_joint_transforms(rgb, mask, img_size, img_aug=True, test=True):
    """Apply the same geometric pipeline to an image and its mask.

    Both inputs are padded to a square, optionally rotated and cropped
    (training-time augmentation only), resized to ``img_size``, surrounded by
    an extra border, resized to ``img_size`` again and converted to tensors.
    The image is always filled with 255 and resampled bilinearly; the mask is
    always filled with 0 and resampled with nearest-neighbour.

    Args:
        rgb: Image to transform; its ``size`` is read as ``(width, height)``.
        mask: Mask aligned with ``rgb``; receives identical geometry.
        img_size: Side length of the square output, in pixels.
        img_aug: Enables the random rotate / crop steps (ignored when
            ``test`` is true).
        test: When true the extra border is fixed at 16 px and no random
            augmentation is applied; otherwise the border is random in
            ``[0, 32]``.

    Returns:
        Tuple ``(rgb_tensor, mask_tensor)`` produced by ``TF.to_tensor``.
    """
    # Drawn first so the sequence of `random` calls stays stable.
    border = 16 if test else random.randint(0, 32)

    width, height = rgb.size[:2]
    side = max(height, width)
    pad_top = (side - height) // 2
    pad_left = (side - width) // 2
    # (left, top, right, bottom); the odd pixel, if any, goes right / bottom.
    square_padding = (
        pad_left,
        pad_top,
        side - width - pad_left,
        side - height - pad_top,
    )

    # 1. pad the shorter dimension so the canvas becomes square
    rgb = TF.pad(rgb, square_padding, fill=255)
    mask = TF.pad(mask, square_padding, fill=0)

    if img_aug and not test:
        # 2. occasional small rotation, same angle for image and mask
        if random.random() < 0.1:
            degrees = random.uniform(-10, 10)
            rgb = TF.rotate(rgb, degrees, fill=255)
            mask = TF.rotate(mask, degrees, fill=0)

        # 3. occasional square crop keeping 90-100% of the side length
        if random.random() < 0.1:
            cropped_side = int(side * random.uniform(0.9, 1.0))
            top, left, crop_h, crop_w = transforms.RandomCrop.get_params(
                rgb, (cropped_side, cropped_side)
            )
            rgb = TF.crop(rgb, top, left, crop_h, crop_w)
            mask = TF.crop(mask, top, left, crop_h, crop_w)

    out_size = (img_size, img_size)
    bilinear = TF.InterpolationMode.BILINEAR
    nearest = TF.InterpolationMode.NEAREST

    # 4. resize to the target resolution
    rgb = TF.resize(rgb, out_size, interpolation=bilinear)
    mask = TF.resize(mask, out_size, interpolation=nearest)

    # 5. add a border, then resize back so the content shrinks inside the frame
    rgb = TF.pad(rgb, border, fill=255)
    mask = TF.pad(mask, border, fill=0)
    rgb = TF.resize(rgb, out_size, interpolation=bilinear)
    mask = TF.resize(mask, out_size, interpolation=nearest)

    return TF.to_tensor(rgb), TF.to_tensor(mask)
def crop_recenter(image_no_bg, thereshold=100):
    """Crop an image to the bounding box of its foreground pixels.

    The last channel is cast to ``uint8`` and a pixel counts as foreground
    when that value is strictly greater than ``thereshold``. If fewer than
    0.1% of the pixels are foreground, the image is returned uncropped.
    Despite the name, this function only crops; it performs no re-centering.

    Args:
        image_no_bg: Array-like image whose last channel is the alpha / mask.
        thereshold: Foreground cut-off applied to the last channel.

    Returns:
        A NumPy array containing the cropped region (all channels kept).
    """
    pixels = np.array(image_no_bg)
    height, width = pixels.shape[:2]
    foreground = pixels[..., -1].astype(np.uint8) > thereshold

    if np.count_nonzero(foreground) < (height * width) * 0.001:
        # Too little foreground to trust a bounding box: keep the full frame.
        top, bottom = 0, height - 1
        left, right = 0, width - 1
    else:
        # Indices of the rows / columns that contain any foreground pixel.
        rows = np.flatnonzero(foreground.any(axis=1))
        cols = np.flatnonzero(foreground.any(axis=0))
        top, bottom = rows.min(), rows.max()
        left, right = cols.min(), cols.max()

    # Bounds are inclusive, hence the +1 on the slice ends.
    return pixels[top:bottom + 1, left:right + 1]
def _pil_to_array(pil_img):
    """Return a PIL image as an H x W x 3 (RGB) or H x W x 4 (RGBA) array."""
    # Grayscale, palette, LA, CMYK, ... would otherwise produce arrays whose
    # last axis is not colour(+alpha); normalise them to RGBA first.
    if pil_img.mode not in ("RGB", "RGBA"):
        pil_img = pil_img.convert("RGBA")
    return np.array(pil_img)


def preprocess_image(img):
    """Convert an input image into a 4-channel (RGB + mask) tensor.

    The image is cropped to its non-transparent region, composited onto a
    white background, then passed with its alpha mask through
    ``apply_joint_transforms`` in test mode at ``img_size=518``.

    Args:
        img: A file path (``str``), a PIL image, or an array with 0-255
            values whose last axis holds 3 (RGB) or 4 (RGBA) channels. RGB
            input is treated as fully opaque.

    Returns:
        A tensor with the transformed RGB channels followed by the mask
        channel, concatenated along dimension 0.
    """
    if isinstance(img, str):
        # Close the file handle as soon as the pixel data has been copied out.
        with Image.open(img) as opened:
            img = _pil_to_array(opened)
    elif isinstance(img, Image.Image):
        img = _pil_to_array(img)

    if img.shape[-1] == 3:
        # Alpha is on the same 0-255 scale as the colour channels and is
        # divided by 255 below, so "fully opaque" must be 255. A value of 1
        # would make the image almost transparent and wash it out to white.
        alpha = np.full_like(img[..., 0:1], 255)
        img = np.concatenate([img, alpha], axis=-1)

    # Crop to any pixel with non-zero alpha, then scale everything to [0, 1].
    img = crop_recenter(img, thereshold=0) / 255.
    mask = img[..., 3]
    # Alpha-composite onto a white (1.0) background.
    img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])

    img = Image.fromarray((img * 255).astype(np.uint8))
    mask = Image.fromarray((mask * 255).astype(np.uint8))
    img, mask = apply_joint_transforms(img, mask, img_size=518,
                                       img_aug=False, test=True)
    img = torch.cat([img, mask], dim=0)
    return img