|
from torchvision.transforms import (
|
|
Normalize,
|
|
Compose,
|
|
RandomResizedCrop,
|
|
InterpolationMode,
|
|
ToTensor,
|
|
Resize,
|
|
CenterCrop,
|
|
)
|
|
|
|
|
|
def _convert_to_rgb(image):
|
|
return image.convert("RGB")
|
|
|
|
|
|
def image_transform(
    image_size: int,
    is_train: bool,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Build the preprocessing pipeline for training or evaluation.

    Args:
        image_size: Target size handed to the resize/crop transforms.
        is_train: When truthy, use a ``RandomResizedCrop`` (``scale`` of
            ``(0.9, 1.0)``, bicubic interpolation); otherwise use a bicubic
            ``Resize`` followed by a ``CenterCrop``.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.

    Returns:
        A ``Compose`` that applies the geometric step(s), then
        ``_convert_to_rgb``, ``ToTensor`` and ``Normalize``, in that order.
    """
    normalizer = Normalize(mean=mean, std=std)

    # Only the geometric prefix differs between the two modes.
    if is_train:
        geometric = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
        ]
    else:
        geometric = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    # Shared tail: RGB conversion happens after the geometric step(s).
    return Compose([*geometric, _convert_to_rgb, ToTensor(), normalizer])
|
|
|