# Source snapshot: author jwyang, commit 4121bec ("first commit"), 3.29 kB.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# from . import transforms as T
import torchvision.transforms as T
from PIL import Image
from timm.data import create_transform
from .torchvision_transforms.transforms import Resize as New_Resize
def build_clip_transforms(cfg, is_train=True):
if cfg.AUG.USE_TIMM and is_train:
print('=> use timm transform for training')
timm_cfg = cfg.AUG.TIMM_AUG
transforms = create_transform(
input_size=cfg.TRAIN.IMAGE_SIZE[0],
is_training=True,
use_prefetcher=False,
no_aug=False,
re_prob=timm_cfg.RE_PROB,
re_mode=timm_cfg.RE_MODE,
re_count=timm_cfg.RE_COUNT,
scale=cfg.AUG.SCALE,
ratio=cfg.AUG.RATIO,
hflip=timm_cfg.HFLIP,
vflip=timm_cfg.VFLIP,
color_jitter=timm_cfg.COLOR_JITTER,
auto_augment=timm_cfg.AUTO_AUGMENT,
interpolation=timm_cfg.INTERPOLATION,
mean=cfg.MODEL.PIXEL_MEAN,
std=cfg.MODEL.PIXEL_STD,
)
return transforms
# normalize_transform = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
# assert isinstance(cfg.DATASET.OUTPUT_SIZE, (list, tuple)), 'DATASET.OUTPUT_SIZE should be list or tuple'
# NOTE: normalization is applied in rcnn.py, to keep consistent as Detectron2
# normalize = T.Normalize(mean=cfg.MODEL.PIXEL_MEAN, std=cfg.MODEL.PIXEL_STD) # T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
transforms = None
if is_train:
aug = cfg.AUG
scale = aug.SCALE
ratio = aug.RATIO
if len(cfg.AUG.TRAIN.IMAGE_SIZE) == 2: # Data Augmentation from MSR-CLIP
ts = [
T.RandomResizedCrop(
cfg.AUG.TRAIN.IMAGE_SIZE[0], scale=scale, ratio=ratio,
interpolation=cfg.AUG.INTERPOLATION
),
T.RandomHorizontalFlip(),
]
elif len(cfg.AUG.TRAIN.IMAGE_SIZE) == 1 and cfg.AUG.TRAIN.MAX_SIZE is not None: # designed for pretraining fastrcnn
ts = [
New_Resize(
cfg.AUG.TRAIN.IMAGE_SIZE[0], max_size=cfg.AUG.TRAIN.MAX_SIZE,
interpolation=cfg.AUG.INTERPOLATION
),
T.RandomHorizontalFlip(),
]
cj = aug.COLOR_JITTER
if cj[-1] > 0.0:
ts.append(T.RandomApply([T.ColorJitter(*cj[:-1])], p=cj[-1]))
gs = aug.GRAY_SCALE
if gs > 0.0:
ts.append(T.RandomGrayscale(gs))
gb = aug.GAUSSIAN_BLUR
if gb > 0.0:
ts.append(T.RandomApply([GaussianBlur([.1, 2.])], p=gb))
ts.append(T.ToTensor())
# NOTE: normalization is applied in rcnn.py, to keep consistent as Detectron2
#ts.append(normalize)
transforms = T.Compose(ts)
else:
# for zeroshot inference of grounding evaluation
transforms = T.Compose([
T.Resize(
cfg.AUG.TEST.IMAGE_SIZE[0],
interpolation=cfg.AUG.TEST.INTERPOLATION
),
T.ToTensor(),
])
return transforms
return transforms