| from torchvision.transforms import ( | |
| Normalize, | |
| Compose, | |
| RandomResizedCrop, | |
| InterpolationMode, | |
| ToTensor, | |
| Resize, | |
| CenterCrop, | |
| ) | |
| def _convert_to_rgb(image): | |
| return image.convert("RGB") | |
def image_transform(
    image_size: int,
    is_train: bool,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Build the torchvision preprocessing pipeline for images.

    Args:
        image_size: Target spatial size passed to the resize/crop transforms.
        is_train: When true, use randomized cropping for augmentation;
            otherwise use a deterministic resize followed by a center crop.
        mean: Per-channel mean for ``Normalize``. The defaults appear to be
            the OpenAI CLIP normalization statistics (NOTE(review): confirm).
        std: Per-channel standard deviation for ``Normalize``.

    Returns:
        A ``Compose`` that crops/resizes the image with bicubic interpolation,
        converts it to RGB, turns it into a tensor via ``ToTensor`` and
        normalizes it with ``mean``/``std``.
    """
    bicubic = InterpolationMode.BICUBIC

    # Geometric stage differs between training (random crop covering 90-100%
    # of the image area) and evaluation (resize, then center crop).
    if not is_train:
        spatial_steps = [
            Resize(image_size, interpolation=bicubic),
            CenterCrop(image_size),
        ]
    else:
        spatial_steps = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=bicubic,
            ),
        ]

    # Shared tail: RGB conversion happens after resizing/cropping, matching
    # the original ordering, then tensor conversion and normalization.
    common_steps = [
        _convert_to_rgb,
        ToTensor(),
        Normalize(mean=mean, std=std),
    ]
    return Compose(spatial_steps + common_steps)
