# Author: hca97 — "initial commit" (9093750)
import torchvision.transforms as T
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# Fixed-resolution preprocessing pipeline: 224x224 bicubic resize followed by
# normalization with the OpenAI CLIP per-channel mean/std.
# NOTE(review): the input is presumably an HxWxC ndarray or CxHxW tensor (the
# forms ToPILImage accepts) — confirm against callers.
pre_process = T.Compose(
    transforms=[
        # Round-trip through PIL so the resize below operates on a PIL image.
        T.ToPILImage(),
        # Force an exact 224x224 output (aspect ratio is not preserved).
        T.Resize(
            (224, 224),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        # PIL image -> float CxHxW tensor scaled to [0, 1].
        T.ToTensor(),
        # Per-channel (R, G, B) mean and std used by OpenAI CLIP.
        T.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
def pre_process_foo(img_size: tuple, dataset: str = "laion") -> T.Compose:
    """Build a resize-and-normalize pipeline for a configurable image size.

    The pipeline converts the input to a PIL image, resizes it to
    ``img_size`` with bicubic interpolation, converts it to a tensor and
    normalizes it per channel.

    Args:
        img_size: Target size, forwarded unchanged to ``T.Resize(size=...)``.
        dataset: Selects the normalization statistics. Exactly ``"imagenet"``
            uses timm's ``IMAGENET_DEFAULT_MEAN`` / ``IMAGENET_DEFAULT_STD``.
            Any other value, including the default ``"laion"``, uses the
            OpenAI CLIP mean/std.

    Returns:
        A ``T.Compose`` applying the four transforms in order.
    """
    # Choose the normalization pair once rather than testing `dataset` twice.
    if dataset == "imagenet":
        channel_mean, channel_std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        # OpenAI CLIP statistics; also the fallback for unrecognized names.
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        T.ToPILImage(),
        T.Resize(
            size=img_size,
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)