import torch.utils.data as data import torch from PIL import Image, ImageFilter import os, cv2 import numpy as np import random from scipy.stats import norm from math import floor def random_translate(image, target): if random.random() > 0.5: image_height, image_width = image.size a = 1 b = 0 #c = 30 #left/right (i.e. 5/-5) c = int((random.random()-0.5) * 60) d = 0 e = 1 #f = 30 #up/down (i.e. 5/-5) f = int((random.random()-0.5) * 60) image = image.transform(image.size, Image.AFFINE, (a, b, c, d, e, f)) target_translate = target.copy() target_translate = target_translate.reshape(-1, 2) target_translate[:, 0] -= 1.*c/image_width target_translate[:, 1] -= 1.*f/image_height target_translate = target_translate.flatten() target_translate[target_translate < 0] = 0 target_translate[target_translate > 1] = 1 return image, target_translate else: return image, target def random_blur(image): if random.random() > 0.7: image = image.filter(ImageFilter.GaussianBlur(random.random()*5)) return image def random_occlusion(image): if random.random() > 0.5: image_np = np.array(image).astype(np.uint8) image_np = image_np[:,:,::-1] image_height, image_width, _ = image_np.shape occ_height = int(image_height*0.4*random.random()) occ_width = int(image_width*0.4*random.random()) occ_xmin = int((image_width - occ_width - 10) * random.random()) occ_ymin = int((image_height - occ_height - 10) * random.random()) image_np[occ_ymin:occ_ymin+occ_height, occ_xmin:occ_xmin+occ_width, 0] = int(random.random() * 255) image_np[occ_ymin:occ_ymin+occ_height, occ_xmin:occ_xmin+occ_width, 1] = int(random.random() * 255) image_np[occ_ymin:occ_ymin+occ_height, occ_xmin:occ_xmin+occ_width, 2] = int(random.random() * 255) image_pil = Image.fromarray(image_np[:,:,::-1].astype('uint8'), 'RGB') return image_pil else: return image def random_flip(image, target, points_flip): if random.random() > 0.5: image = image.transpose(Image.FLIP_LEFT_RIGHT) target = np.array(target).reshape(-1, 2) target = target[points_flip, :] target[:,0] = 1-target[:,0] target = target.flatten() return image, target else: return image, target def random_rotate(image, target, angle_max): if random.random() > 0.5: center_x = 0.5 center_y = 0.5 landmark_num= int(len(target) / 2) target_center = np.array(target) - np.array([center_x, center_y]*landmark_num) target_center = target_center.reshape(landmark_num, 2) theta_max = np.radians(angle_max) theta = random.uniform(-theta_max, theta_max) angle = np.degrees(theta) image = image.rotate(angle) c, s = np.cos(theta), np.sin(theta) rot = np.array(((c,-s), (s, c))) target_center_rot = np.matmul(target_center, rot) target_rot = target_center_rot.reshape(landmark_num*2) + np.array([center_x, center_y]*landmark_num) return image, target_rot else: return image, target def gen_target_pip(target, meanface_indices, target_map, target_local_x, target_local_y, target_nb_x, target_nb_y): num_nb = len(meanface_indices[0]) map_channel, map_height, map_width = target_map.shape target = target.reshape(-1, 2) assert map_channel == target.shape[0] for i in range(map_channel): mu_x = int(floor(target[i][0] * map_width)) mu_y = int(floor(target[i][1] * map_height)) mu_x = max(0, mu_x) mu_y = max(0, mu_y) mu_x = min(mu_x, map_width-1) mu_y = min(mu_y, map_height-1) target_map[i, mu_y, mu_x] = 1 shift_x = target[i][0] * map_width - mu_x shift_y = target[i][1] * map_height - mu_y target_local_x[i, mu_y, mu_x] = shift_x target_local_y[i, mu_y, mu_x] = shift_y for j in range(num_nb): nb_x = target[meanface_indices[i][j]][0] * map_width - mu_x nb_y = target[meanface_indices[i][j]][1] * map_height - mu_y target_nb_x[num_nb*i+j, mu_y, mu_x] = nb_x target_nb_y[num_nb*i+j, mu_y, mu_x] = nb_y return target_map, target_local_x, target_local_y, target_nb_x, target_nb_y class ImageFolder_pip(data.Dataset): def __init__(self, root, imgs, input_size, num_lms, net_stride, points_flip, meanface_indices, transform=None, target_transform=None): self.root = root self.imgs = imgs self.num_lms = num_lms self.net_stride = net_stride self.points_flip = points_flip self.meanface_indices = meanface_indices self.num_nb = len(meanface_indices[0]) self.transform = transform self.target_transform = target_transform self.input_size = input_size def __getitem__(self, index): img_name, target = self.imgs[index] img = Image.open(os.path.join(self.root, img_name)).convert('RGB') img, target = random_translate(img, target) img = random_occlusion(img) img, target = random_flip(img, target, self.points_flip) img, target = random_rotate(img, target, 30) img = random_blur(img) target_map = np.zeros((self.num_lms, int(self.input_size/self.net_stride), int(self.input_size/self.net_stride))) target_local_x = np.zeros((self.num_lms, int(self.input_size/self.net_stride), int(self.input_size/self.net_stride))) target_local_y = np.zeros((self.num_lms, int(self.input_size/self.net_stride), int(self.input_size/self.net_stride))) target_nb_x = np.zeros((self.num_nb*self.num_lms, int(self.input_size/self.net_stride), int(self.input_size/self.net_stride))) target_nb_y = np.zeros((self.num_nb*self.num_lms, int(self.input_size/self.net_stride), int(self.input_size/self.net_stride))) target_map, target_local_x, target_local_y, target_nb_x, target_nb_y = gen_target_pip(target, self.meanface_indices, target_map, target_local_x, target_local_y, target_nb_x, target_nb_y) target_map = torch.from_numpy(target_map).float() target_local_x = torch.from_numpy(target_local_x).float() target_local_y = torch.from_numpy(target_local_y).float() target_nb_x = torch.from_numpy(target_nb_x).float() target_nb_y = torch.from_numpy(target_nb_y).float() if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target_map = self.target_transform(target_map) target_local_x = self.target_transform(target_local_x) target_local_y = self.target_transform(target_local_y) target_nb_x = self.target_transform(target_nb_x) target_nb_y = self.target_transform(target_nb_y) return img, target_map, target_local_x, target_local_y, target_nb_x, target_nb_y def __len__(self): return len(self.imgs) if __name__ == '__main__': pass