# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Sequence import torch from torchvision import transforms class GaussianBlur(transforms.RandomApply): """ Apply Gaussian Blur to the PIL image. """ def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): # NOTE: torchvision is applying 1 - probability to return the original image keep_p = 1 - p transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) super().__init__(transforms=[transform], p=keep_p) class MaybeToTensor(transforms.ToTensor): """ Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. """ def __call__(self, pic): """ Args: pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. Returns: Tensor: Converted image. """ if isinstance(pic, torch.Tensor): return pic return super().__call__(pic) # Use timm's names IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) def make_normalize_transform( mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, std: Sequence[float] = IMAGENET_DEFAULT_STD, ) -> transforms.Normalize: return transforms.Normalize(mean=mean, std=std) # This roughly matches torchvision's preset for classification training: # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 def make_classification_train_transform( *, crop_size: int = 224, interpolation=transforms.InterpolationMode.BICUBIC, hflip_prob: float = 0.5, mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, std: Sequence[float] = IMAGENET_DEFAULT_STD, ): transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] if hflip_prob > 0.0: transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) transforms_list.extend( [ MaybeToTensor(), make_normalize_transform(mean=mean, std=std), ] ) return transforms.Compose(transforms_list) # This matches (roughly) torchvision's preset for classification evaluation: # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 def make_classification_eval_transform( *, resize_size: int = 256, interpolation=transforms.InterpolationMode.BICUBIC, crop_size: int = 224, mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, std: Sequence[float] = IMAGENET_DEFAULT_STD, ) -> transforms.Compose: transforms_list = [ transforms.Resize(resize_size, interpolation=interpolation), transforms.CenterCrop(crop_size), MaybeToTensor(), make_normalize_transform(mean=mean, std=std), ] return transforms.Compose(transforms_list)