import torch import torchvision import random from PIL import Image, ImageOps import numpy as np import numbers import math class GroupRandomCrop(object): def __init__(self, size): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size def __call__(self, img_group): w, h = img_group[0].size th, tw = self.size out_images = list() x1 = random.randint(0, w - tw) y1 = random.randint(0, h - th) for img in img_group: assert(img.size[0] == w and img.size[1] == h) if w == tw and h == th: out_images.append(img) else: out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) return out_images class MultiGroupRandomCrop(object): def __init__(self, size, groups=1): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size self.groups = groups def __call__(self, img_group): w, h = img_group[0].size th, tw = self.size out_images = list() for i in range(self.groups): x1 = random.randint(0, w - tw) y1 = random.randint(0, h - th) for img in img_group: assert(img.size[0] == w and img.size[1] == h) if w == tw and h == th: out_images.append(img) else: out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) return out_images class GroupCenterCrop(object): def __init__(self, size): self.worker = torchvision.transforms.CenterCrop(size) def __call__(self, img_group): return [self.worker(img) for img in img_group] class GroupRandomHorizontalFlip(object): """Randomly horizontally flips the given PIL.Image with a probability of 0.5 """ def __init__(self, is_flow=False): self.is_flow = is_flow def __call__(self, img_group, is_flow=False): v = random.random() if v < 0.5: ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] if self.is_flow: for i in range(0, len(ret), 2): # invert flow pixel values when flipping ret[i] = ImageOps.invert(ret[i]) return ret else: return img_group class GroupNormalize(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, tensor): rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) rep_std = self.std * (tensor.size()[0] // len(self.std)) # TODO: make efficient for t, m, s in zip(tensor, rep_mean, rep_std): t.sub_(m).div_(s) return tensor class GroupScale(object): """ Rescales the input PIL.Image to the given 'size'. 'size' will be the size of the smaller edge. For example, if height > width, then image will be rescaled to (size * height / width, size) size: size of the smaller edge interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, interpolation=Image.BILINEAR): self.worker = torchvision.transforms.Resize(size, interpolation) def __call__(self, img_group): return [self.worker(img) for img in img_group] class GroupOverSample(object): def __init__(self, crop_size, scale_size=None, flip=True): self.crop_size = crop_size if not isinstance( crop_size, int) else (crop_size, crop_size) if scale_size is not None: self.scale_worker = GroupScale(scale_size) else: self.scale_worker = None self.flip = flip def __call__(self, img_group): if self.scale_worker is not None: img_group = self.scale_worker(img_group) image_w, image_h = img_group[0].size crop_w, crop_h = self.crop_size offsets = GroupMultiScaleCrop.fill_fix_offset( False, image_w, image_h, crop_w, crop_h) oversample_group = list() for o_w, o_h in offsets: normal_group = list() flip_group = list() for i, img in enumerate(img_group): crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) normal_group.append(crop) flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) if img.mode == 'L' and i % 2 == 0: flip_group.append(ImageOps.invert(flip_crop)) else: flip_group.append(flip_crop) oversample_group.extend(normal_group) if self.flip: oversample_group.extend(flip_group) return oversample_group class GroupFullResSample(object): def __init__(self, crop_size, scale_size=None, flip=True): self.crop_size = crop_size if not isinstance( crop_size, int) else (crop_size, crop_size) if scale_size is not None: self.scale_worker = GroupScale(scale_size) else: self.scale_worker = None self.flip = flip def __call__(self, img_group): if self.scale_worker is not None: img_group = self.scale_worker(img_group) image_w, image_h = img_group[0].size crop_w, crop_h = self.crop_size w_step = (image_w - crop_w) // 4 h_step = (image_h - crop_h) // 4 offsets = list() offsets.append((0 * w_step, 2 * h_step)) # left offsets.append((4 * w_step, 2 * h_step)) # right offsets.append((2 * w_step, 2 * h_step)) # center oversample_group = list() for o_w, o_h in offsets: normal_group = list() flip_group = list() for i, img in enumerate(img_group): crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) normal_group.append(crop) if self.flip: flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) if img.mode == 'L' and i % 2 == 0: flip_group.append(ImageOps.invert(flip_crop)) else: flip_group.append(flip_crop) oversample_group.extend(normal_group) oversample_group.extend(flip_group) return oversample_group class GroupMultiScaleCrop(object): def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): self.scales = scales if scales is not None else [1, .875, .75, .66] self.max_distort = max_distort self.fix_crop = fix_crop self.more_fix_crop = more_fix_crop self.input_size = input_size if not isinstance(input_size, int) else [ input_size, input_size] self.interpolation = Image.BILINEAR def __call__(self, img_group): im_size = img_group[0].size crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) crop_img_group = [ img.crop( (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] return ret_img_group def _sample_crop_size(self, im_size): image_w, image_h = im_size[0], im_size[1] # find a crop size base_size = min(image_w, image_h) crop_sizes = [int(base_size * x) for x in self.scales] crop_h = [ self.input_size[1] if abs( x - self.input_size[1]) < 3 else x for x in crop_sizes] crop_w = [ self.input_size[0] if abs( x - self.input_size[0]) < 3 else x for x in crop_sizes] pairs = [] for i, h in enumerate(crop_h): for j, w in enumerate(crop_w): if abs(i - j) <= self.max_distort: pairs.append((w, h)) crop_pair = random.choice(pairs) if not self.fix_crop: w_offset = random.randint(0, image_w - crop_pair[0]) h_offset = random.randint(0, image_h - crop_pair[1]) else: w_offset, h_offset = self._sample_fix_offset( image_w, image_h, crop_pair[0], crop_pair[1]) return crop_pair[0], crop_pair[1], w_offset, h_offset def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): offsets = self.fill_fix_offset( self.more_fix_crop, image_w, image_h, crop_w, crop_h) return random.choice(offsets) @staticmethod def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): w_step = (image_w - crop_w) // 4 h_step = (image_h - crop_h) // 4 ret = list() ret.append((0, 0)) # upper left ret.append((4 * w_step, 0)) # upper right ret.append((0, 4 * h_step)) # lower left ret.append((4 * w_step, 4 * h_step)) # lower right ret.append((2 * w_step, 2 * h_step)) # center if more_fix_crop: ret.append((0, 2 * h_step)) # center left ret.append((4 * w_step, 2 * h_step)) # center right ret.append((2 * w_step, 4 * h_step)) # lower center ret.append((2 * w_step, 0 * h_step)) # upper center ret.append((1 * w_step, 1 * h_step)) # upper left quarter ret.append((3 * w_step, 1 * h_step)) # upper right quarter ret.append((1 * w_step, 3 * h_step)) # lower left quarter ret.append((3 * w_step, 3 * h_step)) # lower righ quarter return ret class GroupRandomSizedCrop(object): """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio This is popularly used to train the Inception networks size: size of the smaller edge interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, interpolation=Image.BILINEAR): self.size = size self.interpolation = interpolation def __call__(self, img_group): for attempt in range(10): area = img_group[0].size[0] * img_group[0].size[1] target_area = random.uniform(0.08, 1.0) * area aspect_ratio = random.uniform(3. / 4, 4. / 3) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) if random.random() < 0.5: w, h = h, w if w <= img_group[0].size[0] and h <= img_group[0].size[1]: x1 = random.randint(0, img_group[0].size[0] - w) y1 = random.randint(0, img_group[0].size[1] - h) found = True break else: found = False x1 = 0 y1 = 0 if found: out_group = list() for img in img_group: img = img.crop((x1, y1, x1 + w, y1 + h)) assert(img.size == (w, h)) out_group.append( img.resize( (self.size, self.size), self.interpolation)) return out_group else: # Fallback scale = GroupScale(self.size, interpolation=self.interpolation) crop = GroupRandomCrop(self.size) return crop(scale(img_group)) class ConvertDataFormat(object): def __init__(self, model_type): self.model_type = model_type def __call__(self, images): if self.model_type == '2D': return images tc, h, w = images.size() t = tc // 3 images = images.view(t, 3, h, w) images = images.permute(1, 0, 2, 3) return images class Stack(object): def __init__(self, roll=False): self.roll = roll def __call__(self, img_group): if img_group[0].mode == 'L': return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) elif img_group[0].mode == 'RGB': if self.roll: return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) else: #print(np.concatenate(img_group, axis=2).shape) # print(img_group[0].shape) return np.concatenate(img_group, axis=2) class ToTorchFormatTensor(object): """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ def __init__(self, div=True): self.div = div def __call__(self, pic): if isinstance(pic, np.ndarray): # handle numpy array img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() else: # handle PIL Image img = torch.ByteTensor( torch.ByteStorage.from_buffer( pic.tobytes())) img = img.view(pic.size[1], pic.size[0], len(pic.mode)) # put it from HWC to CHW format # yikes, this transpose takes 80% of the loading time/CPU img = img.transpose(0, 1).transpose(0, 2).contiguous() return img.float().div(255) if self.div else img.float() class IdentityTransform(object): def __call__(self, data): return data if __name__ == "__main__": trans = torchvision.transforms.Compose([ GroupScale(256), GroupRandomCrop(224), Stack(), ToTorchFormatTensor(), GroupNormalize( mean=[.485, .456, .406], std=[.229, .224, .225] )] ) im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') color_group = [im] * 3 rst = trans(color_group) gray_group = [im.convert('L')] * 9 gray_rst = trans(gray_group) trans2 = torchvision.transforms.Compose([ GroupRandomSizedCrop(256), Stack(), ToTorchFormatTensor(), GroupNormalize( mean=[.485, .456, .406], std=[.229, .224, .225]) ]) print(trans2(color_group))