| |
| |
| |
| |
|
|
| from typing import Sequence |
|
|
| import torch |
| from torchvision import transforms |
|
|
|
|
class GaussianBlur(transforms.RandomApply):
    """
    Randomly apply Gaussian blur to the input image.

    Wraps ``transforms.GaussianBlur`` (fixed ``kernel_size=9``) in a
    ``transforms.RandomApply`` so that the blur is applied with probability
    ``p`` and the input is returned unchanged otherwise.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        """
        Args:
            p: Probability of applying the blur.
            radius_min: Lower bound of the ``sigma`` range given to the blur.
            radius_max: Upper bound of the ``sigma`` range given to the blur.
        """
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        # RandomApply's ``p`` is the probability of *applying* the wrapped
        # transforms, so ``p`` is forwarded unchanged. Forwarding ``1 - p``
        # would invert the meaning (``p=1.0`` would never blur).
        super().__init__(transforms=[transform], p=p)
|
|
|
|
class MaybeToTensor(transforms.ToTensor):
    """
    ``ToTensor`` variant that passes ``torch.Tensor`` inputs through untouched.

    Any other input is handed to ``transforms.ToTensor.__call__`` for conversion.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to convert.
        Returns:
            Tensor: ``pic`` itself when it is already a tensor, otherwise the
            result of the parent ``ToTensor`` conversion.
        """
        # Only non-tensor inputs need the parent conversion.
        if not isinstance(pic, torch.Tensor):
            return super().__call__(pic)
        return pic
|
|
|
|
| |
# Default three-channel normalization statistics used by the transform
# factories in this file (presumably ImageNet RGB order -- confirm with callers).
IMAGENET_DEFAULT_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DEFAULT_STD = (
    0.229,
    0.224,
    0.225,
)
|
|
|
|
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """
    Build a ``transforms.Normalize`` from the given statistics.

    Args:
        mean: Mean values forwarded to ``Normalize``; defaults to
            ``IMAGENET_DEFAULT_MEAN``.
        std: Standard-deviation values forwarded to ``Normalize``; defaults to
            ``IMAGENET_DEFAULT_STD``.
    Returns:
        The configured ``transforms.Normalize`` instance.
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
|
|
|
|
| |
| |
def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the training-time classification transform pipeline.

    The pipeline is, in order: ``RandomResizedCrop`` -> optional
    ``RandomHorizontalFlip`` -> ``MaybeToTensor`` -> ``Normalize``.

    Args:
        crop_size: Output size passed to ``RandomResizedCrop``.
        interpolation: Interpolation mode passed to ``RandomResizedCrop``.
        hflip_prob: Probability passed to ``RandomHorizontalFlip``. The flip
            step is left out of the pipeline entirely when this is not
            greater than 0.
        mean: Mean values for the final normalization step.
        std: Standard-deviation values for the final normalization step.
    Returns:
        A ``transforms.Compose`` wrapping the steps above.
    """
    transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    # Skip the flip step altogether rather than adding a no-op transform.
    if hflip_prob > 0.0:
        transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
    transforms_list.extend(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
    return transforms.Compose(transforms_list)
|
|
|
|
| |
| |
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the evaluation-time classification transform pipeline.

    The pipeline is, in order: ``Resize`` -> ``CenterCrop`` ->
    ``MaybeToTensor`` -> ``Normalize``.

    Args:
        resize_size: Size passed to ``Resize``.
        interpolation: Interpolation mode passed to ``Resize``.
        crop_size: Size passed to ``CenterCrop``.
        mean: Mean values for the final normalization step.
        std: Standard-deviation values for the final normalization step.
    Returns:
        A ``transforms.Compose`` wrapping the steps above.
    """
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    center_crop = transforms.CenterCrop(crop_size)
    to_tensor = MaybeToTensor()
    normalize = make_normalize_transform(mean=mean, std=std)
    return transforms.Compose([resize, center_crop, to_tensor, normalize])
|
|