import torchvision.transforms as T
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Per-channel normalisation statistics used for every dataset name other than
# "imagenet".
# NOTE(review): these look like the statistics commonly shipped with CLIP-style
# models (the default dataset name here is "laion") — confirm the provenance.
_FALLBACK_MEAN = (0.48145466, 0.4578275, 0.40821073)
_FALLBACK_STD = (0.26862954, 0.26130258, 0.27577711)


def pre_process_foo(img_size: tuple, dataset: str = "laion") -> T.Compose:
    """Build the image preprocessing pipeline for a given target size.

    The returned pipeline applies, in order: ``ToPILImage``, a bicubic
    ``Resize`` to ``img_size``, ``ToTensor``, and ``Normalize``.

    Args:
        img_size: Target size handed to ``T.Resize`` as its ``size`` argument.
        dataset: Selects the normalisation statistics. The exact string
            ``"imagenet"`` picks ``IMAGENET_DEFAULT_MEAN`` /
            ``IMAGENET_DEFAULT_STD`` from timm; any other value (including
            the default ``"laion"``) picks the fallback statistics defined
            at module level.

    Returns:
        A ``T.Compose`` wrapping the four transforms listed above.
    """
    if dataset == "imagenet":
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        # Any non-"imagenet" name falls through to the same statistics;
        # the value of ``dataset`` is not validated further.
        mean, std = _FALLBACK_MEAN, _FALLBACK_STD

    steps = [
        T.ToPILImage(),
        T.Resize(
            size=img_size,
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
    return T.Compose(steps)


# Ready-made pipeline at 224x224 using the fallback (non-"imagenet")
# statistics, i.e. the factory's defaults.
pre_process = pre_process_foo((224, 224))