from torchvision.transforms import ( Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop, ) def _convert_to_rgb(image): return image.convert("RGB") def image_transform( image_size: int, is_train: bool, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), ): normalize = Normalize(mean=mean, std=std) if is_train: return Compose( [ RandomResizedCrop( image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC, ), _convert_to_rgb, ToTensor(), normalize, ] ) else: return Compose( [ Resize(image_size, interpolation=InterpolationMode.BICUBIC), CenterCrop(image_size), _convert_to_rgb, ToTensor(), normalize, ] )