| import PIL | |
| import torchvision.transforms as T | |
def pair(t):
    """Normalize ``t`` into a 2-element form.

    A tuple is passed through unchanged; any other value is duplicated
    into ``(t, t)``.  This lets callers give either a single size or an
    explicit ``(height, width)`` tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
def _build_transform(img_size, is_train, scale, hflip):
    """Build the shared resize -> crop -> (flip) -> tensor -> normalize pipeline.

    The image is first resized to a square of side ``int(img_size / scale)``
    and then cropped back to ``img_size``.  A ``scale`` below 1 therefore
    leaves a margin around the crop.

    Args:
        img_size: Side length of the square crop.
        is_train: If true, use a random crop (plus the optional flip);
            otherwise use a deterministic center crop.
        scale: Approximate ratio ``img_size / resized_side``; must be > 0.
        hflip: If true, add a random horizontal flip (p=0.5) to the
            training pipeline.  Has no effect when ``is_train`` is false.

    Returns:
        A ``torchvision.transforms.Compose`` applying the steps above.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")
    # pair() turns the int into (h, w), so Resize produces an exact square
    # (aspect ratio is NOT preserved) instead of matching the shorter edge.
    resize = pair(int(img_size / scale))
    # InterpolationMode is torchvision's documented enum; it avoids relying on
    # `import PIL` having exposed the `PIL.Image` submodule as a side effect.
    steps = [T.Resize(resize, interpolation=T.InterpolationMode.BICUBIC)]
    if is_train:
        # NOTE: with scale > 1 the resized side is smaller than img_size, and
        # RandomCrop (pad_if_needed=False) raises when the transform runs.
        steps.append(T.RandomCrop(img_size))
        if hflip:
            steps.append(T.RandomHorizontalFlip(p=0.5))
    else:
        steps.append(T.CenterCrop(img_size))
    steps.append(T.ToTensor())
    # Per-channel (x - 0.5) / 0.5 maps values in [0, 1] to [-1, 1].
    steps.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    return T.Compose(steps)


def stage1_transform(img_size=256, is_train=True, scale=0.8):
    """Return the stage-1 image transform.

    Resize to a square of side ``int(img_size / scale)``.  When ``is_train``
    is true, then random-crop to ``img_size`` and randomly flip horizontally
    (p=0.5); otherwise center-crop to ``img_size``.  Finally convert to a
    tensor normalized with mean=std=0.5 per channel.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    return _build_transform(img_size, is_train, scale, hflip=True)


def stage2_transform(img_size=256, is_train=True, scale=0.8):
    """Return the stage-2 image transform.

    Resize to a square of side ``int(img_size / scale)``, then random-crop
    (training) or center-crop (evaluation) to ``img_size``.  No horizontal
    flip is applied.  Finally convert to a tensor normalized with
    mean=std=0.5 per channel.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    return _build_transform(img_size, is_train, scale, hflip=False)