|
from __future__ import division

import sys
import random

from PIL import Image

try:
    import accimage
except ImportError:
    accimage = None

import numbers
import collections

from torchvision.transforms import functional as F

if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable


_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}
|
|
|
|
|
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, tgt):
        for t in self.transforms:
            img, tgt = t(img, tgt)
        return img, tgt

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
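
# A minimal usage sketch (the inputs below are hypothetical and created only
# for illustration): every transform in a Compose receives the (img, tgt) pair
# and returns a new pair, which is what keeps image and target in sync.
#
#   >>> from PIL import Image
#   >>> img = Image.new('RGB', (64, 64))
#   >>> tgt = Image.new('L', (64, 64))
#   >>> paired = Compose([CenterCrop(32), ToTensor()])
#   >>> img_t, tgt_out = paired(img, tgt)
#   >>> tuple(img_t.shape)
#   (3, 32, 32)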
|
|
|
|
|
class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``. The target is resized with
            ``PIL.Image.NEAREST`` regardless of this setting.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be scaled.
            tgt (PIL Image): Target to be scaled alongside the image.

        Returns:
            tuple of PIL Images: Rescaled image and target.
        """
        return F.resize(img, self.size, self.interpolation), F.resize(tgt, self.size, Image.NEAREST)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
|
|
|
|
|
class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be cropped.
            tgt (PIL Image): Target to be cropped alongside the image.

        Returns:
            tuple of PIL Images: Cropped image and target.
        """
        return F.center_crop(img, self.size), F.center_crop(tgt, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
|
|
|
|
|
class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None, i.e. no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively. If a sequence of length 2 is provided, it is used to
            pad left/right, top/bottom borders, respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be cropped.
            tgt (PIL Image): Target to be cropped with the same parameters.

        Returns:
            tuple of PIL Images: Cropped image and target.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)
            tgt = F.pad(tgt, self.padding, self.fill, self.padding_mode)

        # pad the width if needed; compute the amount before padding so the
        # same padding is applied to both image and target
        if self.pad_if_needed and img.size[0] < self.size[1]:
            pad_w = self.size[1] - img.size[0]
            img = F.pad(img, (pad_w, 0), self.fill, self.padding_mode)
            tgt = F.pad(tgt, (pad_w, 0), self.fill, self.padding_mode)

        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            pad_h = self.size[0] - img.size[1]
            img = F.pad(img, (0, pad_h), self.fill, self.padding_mode)
            tgt = F.pad(tgt, (0, pad_h), self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w), F.crop(tgt, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
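
# Why the pair stays aligned: the crop parameters are drawn once and applied
# to both the image and the target. The equivalent manual form (with a
# hypothetical (img, tgt) pair already loaded) would be:
#
#   >>> i, j, h, w = RandomCrop.get_params(img, (224, 224))
#   >>> img_c, tgt_c = F.crop(img, i, j, h, w), F.crop(tgt, i, j, h, w)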
|
|
|
|
|
class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be flipped.
            tgt (PIL Image): Target to be flipped together with the image.

        Returns:
            tuple of PIL Images: Randomly flipped image and target.
        """
        if random.random() < self.p:
            return F.hflip(img), F.hflip(tgt)
        return img, tgt

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
|
|
|
|
|
class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Image to be flipped.
            tgt (PIL Image): Target to be flipped together with the image.

        Returns:
            tuple of PIL Images: Randomly flipped image and target.
        """
        if random.random() < self.p:
            return F.vflip(img), F.vflip(tgt)
        return img, tgt

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
|
|
|
|
|
class Lambda(object):
    """Apply a user-defined lambda as a transform.

    The callable receives ``(img, tgt)`` and must return an ``(img, tgt)`` pair.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img, tgt):
        return self.lambd(img, tgt)

    def __repr__(self):
        return self.__class__.__name__ + '()'
|
|
|
|
|
class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # a degenerate range centered on `center` means "no jitter" for this parameter
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []

        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda(lambda img, tgt: (F.adjust_brightness(img, brightness_factor), tgt)))

        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda(lambda img, tgt: (F.adjust_contrast(img, contrast_factor), tgt)))

        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda(lambda img, tgt: (F.adjust_saturation(img, saturation_factor), tgt)))

        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda(lambda img, tgt: (F.adjust_hue(img, hue_factor), tgt)))

        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image): Input image.
            tgt (PIL Image): Target, returned unchanged.

        Returns:
            tuple: Color jittered image and the unchanged target.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img, tgt)

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string
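
# ColorJitter is photometric only: each Lambda built in get_params adjusts the
# image and passes the target through untouched, so label maps are never
# altered. Sketch with illustrative parameters:
#
#   >>> jitter = ColorJitter(brightness=0.3, contrast=0.3)
#   >>> img_j, tgt_out = jitter(img, tgt)   # tgt_out is the original target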
|
|
|
|
|
class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        With the default ``inplace=False`` this transform acts out of place,
        i.e. it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Whether to normalize the image tensor in place.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, img, tgt):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W) to be normalized.
            tgt: Target, returned unchanged.

        Returns:
            tuple: Normalized Tensor image and the unchanged target.
        """
        return F.normalize(img, self.mean, self.std, self.inplace), tgt

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
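
# Normalize operates on the tensor produced by ToTensor and leaves the target
# untouched, so it should come after ToTensor in a paired pipeline. Sketch
# (the statistics are the usual ImageNet values, used purely as an example):
#
#   >>> pipeline = Compose([ToTensor(),
#   ...                     Normalize(mean=[0.485, 0.456, 0.406],
#   ...                               std=[0.229, 0.224, 0.225])])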
|
|
|
|
|
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8.

    In the other cases, tensors are returned without scaling.

    Only the image is converted; the target is returned unchanged.
    """

    def __call__(self, img, tgt):
        """
        Args:
            img (PIL Image or numpy.ndarray): Image to be converted to tensor.
            tgt: Target, returned unchanged.

        Returns:
            tuple: Converted image tensor and the unchanged target.
        """
        return F.to_tensor(img), tgt

    def __repr__(self):
        return self.__class__.__name__ + '()'
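

# Minimal usage sketch. Everything below is illustrative only: the input sizes,
# the transform parameters and the normalization statistics (the usual ImageNet
# values) are assumptions, not part of the transforms themselves.
if __name__ == '__main__':
    # Synthetic image/mask pair standing in for a real segmentation sample.
    image = Image.new('RGB', (320, 240))
    mask = Image.new('L', (320, 240))

    pipeline = Compose([
        Resize(256),
        RandomCrop(224, pad_if_needed=True),
        RandomHorizontalFlip(p=0.5),
        ColorJitter(brightness=0.2, contrast=0.2),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img_tensor, mask_out = pipeline(image, mask)
    print(img_tensor.shape)  # torch.Size([3, 224, 224])
    print(mask_out.size)     # the mask is still a PIL Image here: (224, 224)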