import math
import random
from PIL import Image
import blobfile as bf
#from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
import os
import torchvision.transforms as transforms
import torch as th
from functools import partial
import cv2

# Per-channel mean / std shared by the normalisation transforms below.
# (x - 0.5) / 0.5 maps values in [0, 1] (ToTensor output for 8-bit images)
# to [-1, 1]. Three entries each, so 3-channel input is assumed.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)


def get_params(size, resize_size, crop_size):
    """Sample random augmentation parameters for one image.

    The image is conceptually resized so its shorter side equals
    ``resize_size`` (aspect ratio kept, longer side truncated to int). The
    top-left corner of a ``crop_size`` x ``crop_size`` crop is then drawn
    uniformly so the crop fits inside that resized image. The range is
    clamped to 0 when the resized image is smaller than the crop.

    Args:
        size: ``(width, height)`` of the source image.
        resize_size: target length of the shorter side.
        crop_size: side length of the square crop.

    Returns:
        dict with ``'crop_pos'`` -> ``(x, y)`` crop corner in resized
        coordinates, and ``'flip'`` -> bool drawn with probability ~0.5.
    """
    w, h = size
    short_side, long_side = min(w, h), max(w, h)
    width_is_shorter = w == short_side

    # Scale the long side by the same factor that maps the short side
    # onto resize_size.
    scaled_long = int(resize_size * long_side / short_side)
    if width_is_shorter:
        new_w, new_h = resize_size, scaled_long
    else:
        new_w, new_h = scaled_long, resize_size

    # Draw order (x, y, flip) is kept so seeded runs reproduce the same values.
    x = random.randint(0, max(0, new_w - crop_size))
    y = random.randint(0, max(0, new_h - crop_size))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop=True):
    """Build the PIL-level augmentation pipeline for one image.

    The pipeline resizes the image to a ``crop_size`` x ``crop_size`` square
    using ``method``. When ``flip`` is true, it then mirrors the image left to
    right if ``params['flip']`` is set.

    NOTE(review): ``resize_size``, ``crop`` and ``params['crop_pos']`` are
    accepted but not used. No crop is performed; the image is resized
    straight to ``crop_size`` regardless of its aspect ratio. These
    parameters are kept so existing call sites keep working.

    Args:
        params: dict as returned by :func:`get_params`; only ``'flip'`` is read.
        resize_size: unused.
        crop_size: side length of the square output image.
        method: PIL resampling filter passed to ``Image.resize``.
        flip: whether to append the (conditional) horizontal flip step.
        crop: unused.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a PIL image.
    """
    steps = [transforms.Lambda(lambda img: __scale(img, crop_size, method))]
    if flip:
        # params['flip'] is looked up lazily, each time the transform runs.
        steps.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    return transforms.Compose(steps)


def get_tensor(normalize=True, toTensor=True):
    """Build the image -> tensor conversion pipeline.

    Args:
        normalize: append ``Normalize(_NORM_MEAN, _NORM_STD)``.
        toTensor: prepend ``transforms.ToTensor()``.

    Returns:
        A ``transforms.Compose`` over the selected steps. It is empty when both
        flags are false.
    """
    # The ``normalize`` argument shadows the module-level normalize() here,
    # so the Normalize transform is built inline from the shared constants.
    transform_list = []
    if toTensor:
        transform_list.append(transforms.ToTensor())
    if normalize:
        transform_list.append(transforms.Normalize(_NORM_MEAN, _NORM_STD))
    return transforms.Compose(transform_list)


def normalize():
    """Return a fresh ``Normalize`` transform using the shared mean/std constants."""
    return transforms.Normalize(_NORM_MEAN, _NORM_STD)


# NOTE: the two helpers below use double-underscore names. Python would
# name-mangle them if they were referenced from inside a class body, so they
# are only safe to call from module-level functions like the ones above.

def __scale(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to a ``target_width`` x ``target_width`` square (aspect ratio not kept)."""
    return img.resize((target_width, target_width), method)


def __flip(img, flip):
    """Return ``img`` mirrored left to right if ``flip`` is truthy, else ``img`` unchanged."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img