|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Sequence |
|
|
|
import torch |
|
from torchvision import transforms |
|
|
|
|
|
class GaussianBlur(transforms.RandomApply):
    """
    Randomly apply a Gaussian blur to the input image.

    The blur uses a fixed kernel size of 9; its sigma range is
    ``(radius_min, radius_max)`` and is handed to
    ``torchvision.transforms.GaussianBlur``.

    Args:
        p: Probability of applying the blur (1.0 always blurs, 0.0 never does).
        radius_min: Lower bound of the sigma range passed to the blur.
        radius_max: Upper bound of the sigma range passed to the blur.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        # ``RandomApply`` runs its transforms with probability ``p`` and returns
        # the input untouched otherwise, so ``p`` must be forwarded as-is.
        # Forwarding ``1 - p`` would invert the requested blur probability.
        super().__init__(transforms=[transform], p=p)
|
|
|
|
|
class MaybeToTensor(transforms.ToTensor):
    """
    ``ToTensor`` variant that lets inputs which are already tensors pass through.
    """

    def __call__(self, pic):
        """
        Convert ``pic`` to a tensor unless it already is one.

        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to convert.

        Returns:
            Tensor: ``pic`` itself when it is a ``torch.Tensor``; otherwise the
            result of the parent ``ToTensor`` conversion.
        """
        # Tensors skip the parent conversion entirely and are returned as-is.
        return pic if isinstance(pic, torch.Tensor) else super().__call__(pic)
|
|
|
|
|
|
|
# Default per-channel normalization statistics (three values each), used as the
# default ``mean`` / ``std`` arguments of the transform factories in this file.
IMAGENET_DEFAULT_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DEFAULT_STD = (
    0.229,
    0.224,
    0.225,
)
|
|
|
|
|
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """
    Build a ``transforms.Normalize`` from per-channel statistics.

    Args:
        mean: Per-channel means; defaults to ``IMAGENET_DEFAULT_MEAN``.
        std: Per-channel standard deviations; defaults to ``IMAGENET_DEFAULT_STD``.

    Returns:
        The configured ``transforms.Normalize`` instance.
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
|
|
|
|
|
|
|
|
|
def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
    """
    Build the training-time preprocessing pipeline for classification.

    The pipeline is, in order: ``RandomResizedCrop`` -> optional
    ``RandomHorizontalFlip`` -> ``MaybeToTensor`` -> normalization.

    Args:
        crop_size: Size passed to ``RandomResizedCrop``.
        interpolation: Interpolation mode passed to ``RandomResizedCrop``.
        hflip_prob: Probability passed to ``RandomHorizontalFlip``; the flip
            step is left out of the pipeline when this is not greater than 0.
        mean: Per-channel means for the normalization step.
        std: Per-channel standard deviations for the normalization step.

    Returns:
        A ``transforms.Compose`` wrapping the steps above.
    """
    random_crop = transforms.RandomResizedCrop(crop_size, interpolation=interpolation)
    # The flip is only part of the pipeline when it has a chance of firing.
    optional_flip = [transforms.RandomHorizontalFlip(hflip_prob)] if hflip_prob > 0.0 else []
    return transforms.Compose(
        [
            random_crop,
            *optional_flip,
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
|
|
|
|
|
|
|
|
|
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the evaluation-time preprocessing pipeline for classification.

    The pipeline is, in order: ``Resize`` -> ``CenterCrop`` ->
    ``MaybeToTensor`` -> normalization. No random step is added here.

    Args:
        resize_size: Size passed to ``Resize``.
        interpolation: Interpolation mode passed to ``Resize``.
        crop_size: Size passed to ``CenterCrop``.
        mean: Per-channel means for the normalization step.
        std: Per-channel standard deviations for the normalization step.

    Returns:
        A ``transforms.Compose`` wrapping the steps above.
    """
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    center_crop = transforms.CenterCrop(crop_size)
    to_tensor = MaybeToTensor()
    normalize = make_normalize_transform(mean=mean, std=std)
    return transforms.Compose([resize, center_crop, to_tensor, normalize])
|
|