from __future__ import division

import random
import sys

from PIL import Image

try:
    import accimage
except ImportError:
    accimage = None

import collections
import numbers

from torchvision.transforms import functional as F

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable

_pil_interpolation_to_str = {
    Image.NEAREST: "PIL.Image.NEAREST",
    Image.BILINEAR: "PIL.Image.BILINEAR",
    Image.BICUBIC: "PIL.Image.BICUBIC",
    Image.LANCZOS: "PIL.Image.LANCZOS",
    Image.HAMMING: "PIL.Image.HAMMING",
    Image.BOX: "PIL.Image.BOX",
}


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, tgt):
        for t in self.transforms:
            img, tgt = t(img, tgt)
        return img, tgt

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string


class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation), F.resize(
            tgt, self.size, Image.NEAREST
        )

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + "(size={0}, interpolation={1})".format(
            self.size, interpolate_str
        )


class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size), F.center_crop(tgt, self.size)

    def __repr__(self):
        return self.__class__.__name__ + "(size={0})".format(self.size)


class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None, i.e no padding. If a sequence of
            length 4 is provided, it is used to pad left, top, right, bottom
            borders respectively. If a sequence of length 2 is provided, it is
            used to pad left/right, top/bottom borders, respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or
            symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last
              value on the edge)

              padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value
              on the edge)

              padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(
        self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"
    ):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)
            tgt = F.pad(tgt, self.padding, self.fill, self.padding_mode)

        # pad the width if needed; compute the amount before padding so the
        # image and the target receive identical padding
        if self.pad_if_needed and img.size[0] < self.size[1]:
            pad_w = self.size[1] - img.size[0]
            img = F.pad(img, (pad_w, 0), self.fill, self.padding_mode)
            tgt = F.pad(tgt, (pad_w, 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            pad_h = self.size[0] - img.size[1]
            img = F.pad(img, (0, pad_h), self.fill, self.padding_mode)
            tgt = F.pad(tgt, (0, pad_h), self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w), F.crop(tgt, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + "(size={0}, padding={1})".format(
            self.size, self.padding
        )


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return F.hflip(img), F.hflip(tgt)
        return img, tgt

    def __repr__(self):
        return self.__class__.__name__ + "(p={})".format(self.p)


class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
""" if random.random() < self.p: return F.vflip(img), F.vflip(tgt) return img, tgt def __repr__(self): return self.__class__.__name__ + "(p={})".format(self.p) class Lambda(object): """Apply a user-defined lambda as a transform. Args: lambd (function): Lambda/function to be used for transform. """ def __init__(self, lambd): assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" self.lambd = lambd def __call__(self, img, tgt): return self.lambd(img, tgt) def __repr__(self): return self.__class__.__name__ + "()" class ColorJitter(object): """Randomly change the brightness, contrast and saturation of an image. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness. brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] or the given [min, max]. Should be non negative numbers. contrast (float or tuple of float (min, max)): How much to jitter contrast. contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] or the given [min, max]. Should be non negative numbers. saturation (float or tuple of float (min, max)): How much to jitter saturation. saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] or the given [min, max]. Should be non negative numbers. hue (float or tuple of float (min, max)): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. """ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): self.brightness = self._check_input(brightness, "brightness") self.contrast = self._check_input(contrast, "contrast") self.saturation = self._check_input(saturation, "saturation") self.hue = self._check_input( hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False ) def _check_input( self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True ): if isinstance(value, numbers.Number): if value < 0: raise ValueError( "If {} is a single number, it must be non negative.".format(name) ) value = [center - value, center + value] if clip_first_on_zero: value[0] = max(value[0], 0) elif isinstance(value, (tuple, list)) and len(value) == 2: if not bound[0] <= value[0] <= value[1] <= bound[1]: raise ValueError("{} values should be between {}".format(name, bound)) else: raise TypeError( "{} should be a single number or a list/tuple with lenght 2.".format( name ) ) # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing if value[0] == value[1] == center: value = None return value @staticmethod def get_params(brightness, contrast, saturation, hue): """Get a randomized transform to be applied on image. Arguments are same as that of __init__. Returns: Transform which randomly adjusts brightness, contrast and saturation in a random order. 
""" transforms = [] if brightness is not None: brightness_factor = random.uniform(brightness[0], brightness[1]) transforms.append( Lambda( lambda img, tgt: (F.adjust_brightness(img, brightness_factor), tgt) ) ) if contrast is not None: contrast_factor = random.uniform(contrast[0], contrast[1]) transforms.append( Lambda(lambda img, tgt: (F.adjust_contrast(img, contrast_factor), tgt)) ) if saturation is not None: saturation_factor = random.uniform(saturation[0], saturation[1]) transforms.append( Lambda( lambda img, tgt: (F.adjust_saturation(img, saturation_factor), tgt) ) ) if hue is not None: hue_factor = random.uniform(hue[0], hue[1]) transforms.append( Lambda(lambda img, tgt: (F.adjust_hue(img, hue_factor), tgt)) ) random.shuffle(transforms) transform = Compose(transforms) return transform def __call__(self, img, tgt): """ Args: img (PIL Image): Input image. Returns: PIL Image: Color jittered image. """ transform = self.get_params( self.brightness, self.contrast, self.saturation, self.hue ) return transform(img, tgt) def __repr__(self): format_string = self.__class__.__name__ + "(" format_string += "brightness={0}".format(self.brightness) format_string += ", contrast={0}".format(self.contrast) format_string += ", saturation={0}".format(self.saturation) format_string += ", hue={0})".format(self.hue) return format_string class Normalize(object): """Normalize a tensor image with mean and standard deviation. Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform will normalize each channel of the input ``torch.*Tensor`` i.e. ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` .. note:: This transform acts out of place, i.e., it does not mutates the input tensor. Args: mean (sequence): Sequence of means for each channel. std (sequence): Sequence of standard deviations for each channel. """ def __init__(self, mean, std, inplace=False): self.mean = mean self.std = std self.inplace = inplace def __call__(self, img, tgt): """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. Returns: Tensor: Normalized Tensor image. """ # return F.normalize(img, self.mean, self.std, self.inplace), tgt return F.normalize(img, self.mean, self.std), tgt def __repr__(self): return self.__class__.__name__ + "(mean={0}, std={1})".format( self.mean, self.std ) class ToTensor(object): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8 In the other cases, tensors are returned without scaling. """ def __call__(self, img, tgt): """ Args: pic (PIL Image or numpy.ndarray): Image to be converted to tensor. Returns: Tensor: Converted image. """ return F.to_tensor(img), tgt def __repr__(self): return self.__class__.__name__ + "()"