""" Transforms Factory Factory methods for building image transforms for use with TIMM (PyTorch Image Models) Hacked together by / Copyright 2019, Ross Wightman """ import math from typing import Optional, Tuple, Union import torch from torchvision import transforms from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\ ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy from timm.data.random_erasing import RandomErasing def transforms_noaug_train( img_size: Union[int, Tuple[int, int]] = 224, interpolation: str = 'bilinear', use_prefetcher: bool = False, mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, ): """ No-augmentation image transforms for training. Args: img_size: Target image size. interpolation: Image interpolation mode. mean: Image normalization mean. std: Image normalization standard deviation. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize. Returns: """ if interpolation == 'random': # random interpolation not supported with no-aug interpolation = 'bilinear' tfl = [ transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)), transforms.CenterCrop(img_size) ] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm tfl += [ToNumpy()] else: tfl += [ transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std) ) ] return transforms.Compose(tfl) def transforms_imagenet_train( img_size: Union[int, Tuple[int, int]] = 224, scale: Optional[Tuple[float, float]] = None, ratio: Optional[Tuple[float, float]] = None, train_crop_mode: Optional[str] = None, hflip: float = 0.5, vflip: float = 0., color_jitter: Union[float, Tuple[float, ...]] = 0.4, color_jitter_prob: Optional[float] = None, force_color_jitter: bool = False, grayscale_prob: float = 0., gaussian_blur_prob: float = 0., auto_augment: Optional[str] = None, interpolation: str = 'random', mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, re_prob: float = 0., re_mode: str = 'const', re_count: int = 1, re_num_splits: int = 0, use_prefetcher: bool = False, separate: bool = False, ): """ ImageNet-oriented image transforms for training. Args: img_size: Target image size. train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr'). scale: Random resize scale range (crop area, < 1.0 => zoom in). ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR). hflip: Horizontal flip probability. vflip: Vertical flip probability. color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue). Scalar is applied as (scalar,) * 3 (no hue). color_jitter_prob: Apply color jitter with this probability if not None (for SimlCLR-like aug). force_color_jitter: Force color jitter where it is normally disabled (ie with RandAugment on). grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug). gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug). auto_augment: Auto augment configuration string (see auto_augment.py). interpolation: Image interpolation mode. mean: Image normalization mean. std: Image normalization standard deviation. re_prob: Random erasing probability. 


def transforms_imagenet_train(
        img_size: Union[int, Tuple[int, int]] = 224,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        train_crop_mode: Optional[str] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: Union[float, Tuple[float, ...]] = 0.4,
        color_jitter_prob: Optional[float] = None,
        force_color_jitter: bool = False,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        interpolation: str = 'random',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_num_splits: int = 0,
        use_prefetcher: bool = False,
        separate: bool = False,
):
    """ ImageNet-oriented image transforms for training.

    Args:
        img_size: Target image size.
        train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        force_color_jitter: Force color jitter where it is normally disabled (i.e. with RandAugment on).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_num_splits: Control split of random erasing across batch size.
        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
        separate: Output transforms in 3-stage tuple.

    Returns:
        If separate == True, the transforms are returned as a tuple of 3 separate transforms
        for use in a mixing dataset that:
          * passes all data through the first (primary) transform, called the 'clean' data
          * passes a portion of the data through the secondary transform
          * normalizes and converts the branches above with the third, final transform
        If separate == False, a single composed transform is returned.
    """
    train_crop_mode = train_crop_mode or 'rrc'
    assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
    if train_crop_mode in ('rkrc', 'rkrr'):
        # FIXME integration of RKR is a WIP
        scale = tuple(scale or (0.8, 1.00))
        ratio = tuple(ratio or (0.9, 1 / .9))
        primary_tfl = [
            ResizeKeepRatio(
                img_size,
                interpolation=interpolation,
                random_scale_prob=0.5,
                random_scale_range=scale,
                random_scale_area=True,  # scale compatible with RRC
                random_aspect_prob=0.5,
                random_aspect_range=ratio,
            ),
            CenterCropOrPad(img_size, padding_mode='reflect')
            if train_crop_mode == 'rkrc' else
            RandomCropOrPad(img_size, padding_mode='reflect')
        ]
    else:
        scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
        ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size,
                scale=scale,
                ratio=ratio,
                interpolation=interpolation,
            )
        ]
    if hflip > 0.:
        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
    if vflip > 0.:
        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]

    secondary_tfl = []
    disable_color_jitter = False
    if auto_augment:
        assert isinstance(auto_augment, str)
        # color jitter is typically disabled if AA/RA is on,
        # this allows an override without breaking old hparam cfgs
        disable_color_jitter = not (force_color_jitter or '3a' in auto_augment)
        if isinstance(img_size, (tuple, list)):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = str_to_pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]

    if color_jitter is not None and not disable_color_jitter:
        # color jitter is enabled when not using AA or when forced
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if specifying brightness/contrast/saturation
            # or a 4-tuple/list if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        if color_jitter_prob is not None:
            secondary_tfl += [
                transforms.RandomApply(
                    [transforms.ColorJitter(*color_jitter)],
                    p=color_jitter_prob,
                )
            ]
        else:
            secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    if grayscale_prob:
        secondary_tfl += [transforms.RandomGrayscale(p=grayscale_prob)]

    if gaussian_blur_prob:
        secondary_tfl += [
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=23)],  # kernel size hardcoded for now
                p=gaussian_blur_prob,
            )
        ]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and normalization
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std),
            ),
        ]
        if re_prob > 0.:
            final_tfl += [
                RandomErasing(
                    re_prob,
                    mode=re_mode,
                    max_count=re_count,
                    num_splits=re_num_splits,
                    device='cpu',
                )
            ]

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
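
# Example usage (an illustrative sketch, not part of the module API; the
# RandAugment config string and probabilities below are arbitrary choices):
#
#   train_tf = transforms_imagenet_train(
#       img_size=224,
#       auto_augment='rand-m9-mstd0.5',
#       interpolation='random',
#       re_prob=0.25,
#   )
#   tensor = train_tf(pil_img)
#
# With separate=True the same call returns a (primary, secondary, final) tuple
# of transforms for use in a mixing/splitting dataset.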


def transforms_imagenet_eval(
        img_size: Union[int, Tuple[int, int]] = 224,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        use_prefetcher: bool = False,
):
    """ ImageNet-oriented image transform for evaluation and inference.

    Args:
        img_size: Target image size.
        crop_pct: Crop percentage. Defaults to 0.875 when None.
        crop_mode: Crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Trim a border of specified # pixels around edge of original image.
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.

    Returns:
        Composed transform pipeline.
    """
    crop_pct = crop_pct or DEFAULT_CROP_PCT
    if isinstance(img_size, (tuple, list)):
        assert len(img_size) == 2
        scale_size = tuple([math.floor(x / crop_pct) for x in img_size])
    else:
        scale_size = math.floor(img_size / crop_pct)
        scale_size = (scale_size, scale_size)

    tfl = []

    if crop_border_pixels:
        tfl += [TrimBorder(crop_border_pixels)]

    if crop_mode == 'squash':
        # squash mode scales each edge to 1/pct of target, then crops
        # aspect ratio is not preserved, no image lost if crop_pct == 1.0
        tfl += [
            transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
            transforms.CenterCrop(img_size),
        ]
    elif crop_mode == 'border':
        # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
        # no image lost if crop_pct == 1.0
        fill = [round(255 * v) for v in mean]
        tfl += [
            ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
            CenterCropOrPad(img_size, fill=fill),
        ]
    else:
        # default crop mode is 'center'
        # aspect ratio is preserved, crop is centered within the image, no borders are added,
        # some of the image may be lost
        if scale_size[0] == scale_size[1]:
            # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
            tfl += [
                transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
            ]
        else:
            # resize the shortest edge to matching target dim for non-square target
            tfl += [ResizeKeepRatio(scale_size)]
        tfl += [transforms.CenterCrop(img_size)]

    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and normalization
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std),
            ),
        ]

    return transforms.Compose(tfl)
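
# Example usage (an illustrative sketch; hyper-parameter values are arbitrary):
#
#   eval_tf = transforms_imagenet_eval(img_size=224, crop_pct=0.95, interpolation='bicubic')
#   tensor = eval_tf(pil_img)  # resize short edge to floor(224 / 0.95) = 235, center crop to 224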


def create_transform(
        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224,
        is_training: bool = False,
        no_aug: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: Union[float, Tuple[float, ...]] = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_num_splits: int = 0,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        tf_preprocessing: bool = False,
        use_prefetcher: bool = False,
        separate: bool = False,
):
    """ Create a composed transform pipeline for training or evaluation.

    Args:
        input_size: Target input size (channels, height, width) tuple or size scalar.
        is_training: Return training (random) transforms.
        no_aug: Disable augmentation for training (useful for debug).
        train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_num_splits: Control split of random erasing across batch size.
        crop_pct: Inference crop percentage (output size / resize size).
        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
        use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
        separate: Output transforms in 3-stage tuple.

    Returns:
        Composed transforms or tuple thereof.
    """
    if isinstance(input_size, (tuple, list)):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        assert not separate, "Separate transforms not supported for TF preprocessing"
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(
            is_training=is_training,
            size=img_size,
            interpolation=interpolation,
        )
    else:
        if is_training and no_aug:
            assert not separate, "Cannot perform split augmentation with no_aug"
            transform = transforms_noaug_train(
                img_size,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
            )
        elif is_training:
            transform = transforms_imagenet_train(
                img_size,
                train_crop_mode=train_crop_mode,
                scale=scale,
                ratio=ratio,
                hflip=hflip,
                vflip=vflip,
                color_jitter=color_jitter,
                color_jitter_prob=color_jitter_prob,
                grayscale_prob=grayscale_prob,
                gaussian_blur_prob=gaussian_blur_prob,
                auto_augment=auto_augment,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                re_prob=re_prob,
                re_mode=re_mode,
                re_count=re_count,
                re_num_splits=re_num_splits,
                separate=separate,
            )
        else:
            assert not separate, "Separate transforms not supported for validation preprocessing"
            transform = transforms_imagenet_eval(
                img_size,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                crop_pct=crop_pct,
                crop_mode=crop_mode,
                crop_border_pixels=crop_border_pixels,
            )

    return transform
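
# Example usage (an illustrative sketch, not part of the module API; the
# RandAugment string, probabilities, and crop_pct below are arbitrary choices):
#
#   # training pipeline with RandAugment and random erasing
#   train_tf = create_transform(
#       input_size=(3, 224, 224),
#       is_training=True,
#       auto_augment='rand-m9-mstd0.5',
#       re_prob=0.25,
#   )
#   # matching eval pipeline
#   eval_tf = create_transform(input_size=(3, 224, 224), crop_pct=0.9)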