# Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Transforms used in the Augmentation Policies.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import random import numpy as np # pylint:disable=g-multiple-import from PIL import ImageOps, ImageEnhance, ImageFilter, Image # pylint:enable=g-multiple-import IMAGE_SIZE = 32 # What is the dataset mean and std of the images on the training set MEANS = [0.49139968, 0.48215841, 0.44653091] STDS = [0.24703223, 0.24348513, 0.26158784] PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted def random_flip(x): """Flip the input x horizontally with 50% probability.""" if np.random.rand(1)[0] > 0.5: return np.fliplr(x) return x def zero_pad_and_crop(img, amount=4): """Zero pad by `amount` zero pixels on each side then take a random crop. Args: img: numpy image that will be zero padded and cropped. amount: amount of zeros to pad `img` with horizontally and verically. Returns: The cropped zero padded img. The returned numpy array will be of the same shape as `img`. """ padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2, img.shape[2])) padded_img[amount:img.shape[0] + amount, amount: img.shape[1] + amount, :] = img top = np.random.randint(low=0, high=2 * amount) left = np.random.randint(low=0, high=2 * amount) new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :] return new_img def create_cutout_mask(img_height, img_width, num_channels, size): """Creates a zero mask used for cutout of shape `img_height` x `img_width`. Args: img_height: Height of image cutout mask will be applied to. img_width: Width of image cutout mask will be applied to. num_channels: Number of channels in the image. size: Size of the zeros mask. Returns: A mask of shape `img_height` x `img_width` with all ones except for a square of zeros of shape `size` x `size`. This mask is meant to be elementwise multiplied with the original image. Additionally returns the `upper_coord` and `lower_coord` which specify where the cutout mask will be applied. """ assert img_height == img_width # Sample center where cutout mask will be applied height_loc = np.random.randint(low=0, high=img_height) width_loc = np.random.randint(low=0, high=img_width) # Determine upper right and lower left corners of patch upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) lower_coord = (min(img_height, height_loc + size // 2), min(img_width, width_loc + size // 2)) mask_height = lower_coord[0] - upper_coord[0] mask_width = lower_coord[1] - upper_coord[1] assert mask_height > 0 assert mask_width > 0 mask = np.ones((img_height, img_width, num_channels)) zeros = np.zeros((mask_height, mask_width, num_channels)) mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = ( zeros) return mask, upper_coord, lower_coord def cutout_numpy(img, size=16): """Apply cutout with mask of shape `size` x `size` to `img`. The cutout operation is from the paper https://arxiv.org/abs/1708.04552. This operation applies a `size`x`size` mask of zeros to a random location within `img`. Args: img: Numpy image that cutout will be applied to. size: Height/width of the cutout mask that will be Returns: A numpy tensor that is the result of applying the cutout mask to `img`. """ img_height, img_width, num_channels = (img.shape[0], img.shape[1], img.shape[2]) assert len(img.shape) == 3 mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size) return img * mask def float_parameter(level, maxval): """Helper function to scale `val` between 0 and maxval . Args: level: Level of the operation that will be between [0, `PARAMETER_MAX`]. maxval: Maximum value that the operation can have. This will be scaled to level/PARAMETER_MAX. Returns: A float that results from scaling `maxval` according to `level`. """ return float(level) * maxval / PARAMETER_MAX def int_parameter(level, maxval): """Helper function to scale `val` between 0 and maxval . Args: level: Level of the operation that will be between [0, `PARAMETER_MAX`]. maxval: Maximum value that the operation can have. This will be scaled to level/PARAMETER_MAX. Returns: An int that results from scaling `maxval` according to `level`. """ return int(level * maxval / PARAMETER_MAX) def pil_wrap(img): """Convert the `img` numpy tensor to a PIL Image.""" return Image.fromarray( np.uint8((img * STDS + MEANS) * 255.0)).convert('RGBA') def pil_unwrap(pil_img): """Converts the PIL img to a numpy array.""" pic_array = (np.array(pil_img.getdata()).reshape((32, 32, 4)) / 255.0) i1, i2 = np.where(pic_array[:, :, 3] == 0) pic_array = (pic_array[:, :, :3] - MEANS) / STDS pic_array[i1, i2] = [0, 0, 0] return pic_array def apply_policy(policy, img): """Apply the `policy` to the numpy `img`. Args: policy: A list of tuples with the form (name, probability, level) where `name` is the name of the augmentation operation to apply, `probability` is the probability of applying the operation and `level` is what strength the operation to apply. img: Numpy image that will have `policy` applied to it. Returns: The result of applying `policy` to `img`. """ pil_img = pil_wrap(img) for xform in policy: assert len(xform) == 3 name, probability, level = xform xform_fn = NAME_TO_TRANSFORM[name].pil_transformer(probability, level) pil_img = xform_fn(pil_img) return pil_unwrap(pil_img) class TransformFunction(object): """Wraps the Transform function for pretty printing options.""" def __init__(self, func, name): self.f = func self.name = name def __repr__(self): return '<' + self.name + '>' def __call__(self, pil_img): return self.f(pil_img) class TransformT(object): """Each instance of this class represents a specific transform.""" def __init__(self, name, xform_fn): self.name = name self.xform = xform_fn def pil_transformer(self, probability, level): def return_function(im): if random.random() < probability: im = self.xform(im, level) return im name = self.name + '({:.1f},{})'.format(probability, level) return TransformFunction(return_function, name) def do_transform(self, image, level): f = self.pil_transformer(PARAMETER_MAX, level) return pil_unwrap(f(pil_wrap(image))) ################## Transform Functions ################## identity = TransformT('identity', lambda pil_img, level: pil_img) flip_lr = TransformT( 'FlipLR', lambda pil_img, level: pil_img.transpose(Image.FLIP_LEFT_RIGHT)) flip_ud = TransformT( 'FlipUD', lambda pil_img, level: pil_img.transpose(Image.FLIP_TOP_BOTTOM)) # pylint:disable=g-long-lambda auto_contrast = TransformT( 'AutoContrast', lambda pil_img, level: ImageOps.autocontrast( pil_img.convert('RGB')).convert('RGBA')) equalize = TransformT( 'Equalize', lambda pil_img, level: ImageOps.equalize( pil_img.convert('RGB')).convert('RGBA')) invert = TransformT( 'Invert', lambda pil_img, level: ImageOps.invert( pil_img.convert('RGB')).convert('RGBA')) # pylint:enable=g-long-lambda blur = TransformT( 'Blur', lambda pil_img, level: pil_img.filter(ImageFilter.BLUR)) smooth = TransformT( 'Smooth', lambda pil_img, level: pil_img.filter(ImageFilter.SMOOTH)) def _rotate_impl(pil_img, level): """Rotates `pil_img` from -30 to 30 degrees depending on `level`.""" degrees = int_parameter(level, 30) if random.random() > 0.5: degrees = -degrees return pil_img.rotate(degrees) rotate = TransformT('Rotate', _rotate_impl) def _posterize_impl(pil_img, level): """Applies PIL Posterize to `pil_img`.""" level = int_parameter(level, 4) return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA') posterize = TransformT('Posterize', _posterize_impl) def _shear_x_impl(pil_img, level): """Applies PIL ShearX to `pil_img`. The ShearX operation shears the image along the horizontal axis with `level` magnitude. Args: pil_img: Image in PIL object. level: Strength of the operation specified as an Integer from [0, `PARAMETER_MAX`]. Returns: A PIL Image that has had ShearX applied to it. """ level = float_parameter(level, 0.3) if random.random() > 0.5: level = -level return pil_img.transform((32, 32), Image.AFFINE, (1, level, 0, 0, 1, 0)) shear_x = TransformT('ShearX', _shear_x_impl) def _shear_y_impl(pil_img, level): """Applies PIL ShearY to `pil_img`. The ShearY operation shears the image along the vertical axis with `level` magnitude. Args: pil_img: Image in PIL object. level: Strength of the operation specified as an Integer from [0, `PARAMETER_MAX`]. Returns: A PIL Image that has had ShearX applied to it. """ level = float_parameter(level, 0.3) if random.random() > 0.5: level = -level return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, level, 1, 0)) shear_y = TransformT('ShearY', _shear_y_impl) def _translate_x_impl(pil_img, level): """Applies PIL TranslateX to `pil_img`. Translate the image in the horizontal direction by `level` number of pixels. Args: pil_img: Image in PIL object. level: Strength of the operation specified as an Integer from [0, `PARAMETER_MAX`]. Returns: A PIL Image that has had TranslateX applied to it. """ level = int_parameter(level, 10) if random.random() > 0.5: level = -level return pil_img.transform((32, 32), Image.AFFINE, (1, 0, level, 0, 1, 0)) translate_x = TransformT('TranslateX', _translate_x_impl) def _translate_y_impl(pil_img, level): """Applies PIL TranslateY to `pil_img`. Translate the image in the vertical direction by `level` number of pixels. Args: pil_img: Image in PIL object. level: Strength of the operation specified as an Integer from [0, `PARAMETER_MAX`]. Returns: A PIL Image that has had TranslateY applied to it. """ level = int_parameter(level, 10) if random.random() > 0.5: level = -level return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, 0, 1, level)) translate_y = TransformT('TranslateY', _translate_y_impl) def _crop_impl(pil_img, level, interpolation=Image.BILINEAR): """Applies a crop to `pil_img` with the size depending on the `level`.""" cropped = pil_img.crop((level, level, IMAGE_SIZE - level, IMAGE_SIZE - level)) resized = cropped.resize((IMAGE_SIZE, IMAGE_SIZE), interpolation) return resized crop_bilinear = TransformT('CropBilinear', _crop_impl) def _solarize_impl(pil_img, level): """Applies PIL Solarize to `pil_img`. Translate the image in the vertical direction by `level` number of pixels. Args: pil_img: Image in PIL object. level: Strength of the operation specified as an Integer from [0, `PARAMETER_MAX`]. Returns: A PIL Image that has had Solarize applied to it. """ level = int_parameter(level, 256) return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA') solarize = TransformT('Solarize', _solarize_impl) def _cutout_pil_impl(pil_img, level): """Apply cutout to pil_img at the specified level.""" size = int_parameter(level, 20) if size <= 0: return pil_img img_height, img_width, num_channels = (32, 32, 3) _, upper_coord, lower_coord = ( create_cutout_mask(img_height, img_width, num_channels, size)) pixels = pil_img.load() # create the pixel map for i in range(upper_coord[0], lower_coord[0]): # for every col: for j in range(upper_coord[1], lower_coord[1]): # For every row pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly return pil_img cutout = TransformT('Cutout', _cutout_pil_impl) def _enhancer_impl(enhancer): """Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL.""" def impl(pil_img, level): v = float_parameter(level, 1.8) + .1 # going to 0 just destroys it return enhancer(pil_img).enhance(v) return impl color = TransformT('Color', _enhancer_impl(ImageEnhance.Color)) contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast)) brightness = TransformT('Brightness', _enhancer_impl( ImageEnhance.Brightness)) sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness)) ALL_TRANSFORMS = [ flip_lr, flip_ud, auto_contrast, equalize, invert, rotate, posterize, crop_bilinear, solarize, color, contrast, brightness, sharpness, shear_x, shear_y, translate_x, translate_y, cutout, blur, smooth ] NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS} TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()