| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
|
|
| import timm |
| from timm.data import create_transform |
|
|
| from yacs.config import CfgNode as CN |
| from PIL import ImageFilter |
| import logging |
| import random |
|
|
| import torch |
| import torchvision.transforms as T |
|
|
|
|
| from .autoaugment import AutoAugmentPolicy |
| from .autoaugment import AutoAugment |
| from .autoaugment import RandAugment |
| from .autoaugment import TrivialAugmentWide |
| from .threeaugment import deitIII_Solarization |
| from .threeaugment import deitIII_gray_scale |
| from .threeaugment import deitIII_GaussianBlur |
|
|
| from PIL import ImageOps |
| from timm.data.transforms import RandomResizedCropAndInterpolation |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Args:
        sigma: Two-element sequence ``(low, high)``. On every call the blur
            radius is drawn uniformly from this closed range.
    """

    def __init__(self, sigma=(.1, 2.)):
        # The default is a tuple on purpose: a list default would be one
        # mutable object shared by every instance built without `sigma`.
        self.sigma = sigma

    def __call__(self, x):
        """Return `x` blurred with a freshly sampled radius.

        `x` must expose a PIL-style ``filter`` method.
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))

    def __repr__(self):
        # Makes the configured range visible when a pipeline holding this
        # transform is printed or logged.
        return '{}(sigma={})'.format(type(self).__name__, self.sigma)
|
|
|
|
def get_resolution(original_resolution):
    """Map an (H, W) resolution to a (precrop, crop) size pair.

    Inputs whose pixel area is strictly below 96*96 get (160, 128);
    everything else gets (512, 480).
    """
    height, width = original_resolution[0], original_resolution[1]
    small_area_limit = 96 * 96
    if height * width < small_area_limit:
        return (160, 128)
    return (512, 480)
|
|
|
|
# Lookup from the interpolation names used in cfg['AUG']['INTERPOLATION']
# to the corresponding torchvision InterpolationMode members.
INTERPOLATION_MODES = dict(
    bilinear=T.InterpolationMode.BILINEAR,
    bicubic=T.InterpolationMode.BICUBIC,
    nearest=T.InterpolationMode.NEAREST,
)
|
|
|
|
def _build_three_augment(img_size, remove_random_resized_crop, color_jitter):
    """Assemble the "3-Augment" training pipeline shared by two config paths.

    Args:
        img_size: Size forwarded unchanged to the resize/crop transforms.
        remove_random_resized_crop: When truthy, use a bicubic resize followed
            by a reflect-padded random crop instead of a random-resized crop.
        color_jitter: Strength applied to brightness, contrast and
            saturation. ``None`` or ``0`` skips the colour-jitter step.

    Returns:
        A ``T.Compose`` of crop/flip, one randomly chosen transform out of
        grayscale / solarization / Gaussian blur, optional colour jitter,
        tensor conversion and normalisation.
    """
    # This pipeline normalises with these fixed statistics; the
    # IMAGE_MEAN / IMAGE_STD values from the config are not used here.
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    if remove_random_resized_crop:
        primary_tfl = [
            # BICUBIC is the enum equivalent of the legacy integer code 3.
            T.Resize(img_size, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomCrop(img_size, padding=4, padding_mode='reflect'),
            T.RandomHorizontalFlip(),
        ]
    else:
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size, scale=(0.08, 1.0), interpolation='bicubic'),
            T.RandomHorizontalFlip(),
        ]

    # Exactly one of the three transforms is applied per sample.
    secondary_tfl = [T.RandomChoice([deitIII_gray_scale(p=1.0),
                                     deitIII_Solarization(p=1.0),
                                     deitIII_GaussianBlur(p=1.0)])]
    if color_jitter is not None and color_jitter != 0:
        secondary_tfl.append(
            T.ColorJitter(color_jitter, color_jitter, color_jitter))

    final_tfl = [
        T.ToTensor(),
        T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
    ]
    return T.Compose(primary_tfl + secondary_tfl + final_tfl)


def build_transforms(cfg, is_train=True):
    """Build the image transform pipeline described by `cfg`.

    Args:
        cfg: Nested mapping. Reads ``cfg['IMAGE_ENCODER']`` (``IMAGE_SIZE``,
            ``IMAGE_MEAN``, ``IMAGE_STD``), ``cfg['AUG']`` and, for
            evaluation, ``cfg['TEST']['CENTER_CROP']``.
        is_train: Build the training pipeline when true, otherwise the
            evaluation pipeline.

    Returns:
        The composed transform. For training, the first matching option in
        ``cfg['AUG']`` wins, checked in this order: ``THREE_AUG``,
        ``TIMM_AUG`` (with ``USE_TRANSFORM``), ``TORCHVISION_AUG``,
        ``RANDOM_CENTER_CROP``, ``MAE_FINETUNE_AUG``, ``MAE_PRETRAIN_AUG``,
        ``ThreeAugment``. If none matches, ``None`` is returned.

    Raises:
        KeyError: If a required config entry is missing, or the configured
            interpolation name is not in ``INTERPOLATION_MODES``.
    """
    image_cfg = cfg['IMAGE_ENCODER']
    aug_cfg = cfg['AUG']

    normalize = T.Normalize(
        mean=image_cfg['IMAGE_MEAN'],
        std=image_cfg['IMAGE_STD']
    )

    transforms = None
    if is_train:
        if 'THREE_AUG' in aug_cfg:
            three_cfg = aug_cfg['THREE_AUG']
            # NOTE(review): this path forwards the whole IMAGE_SIZE value,
            # while the 'ThreeAugment' path below uses IMAGE_SIZE[0] only.
            # Kept as-is; confirm which is intended for non-square inputs.
            transforms = _build_three_augment(
                image_cfg['IMAGE_SIZE'],
                three_cfg['SRC'],
                three_cfg['COLOR_JITTER'],
            )
        elif 'TIMM_AUG' in aug_cfg and aug_cfg['TIMM_AUG']['USE_TRANSFORM']:
            logger.info('=> use timm transform for training')
            timm_cfg = aug_cfg['TIMM_AUG']
            transforms = create_transform(
                input_size=image_cfg['IMAGE_SIZE'][0],
                is_training=True,
                use_prefetcher=False,
                no_aug=False,
                re_prob=timm_cfg.get('RE_PROB', 0.),
                re_mode=timm_cfg.get('RE_MODE', 'const'),
                re_count=timm_cfg.get('RE_COUNT', 1),
                # A missing or falsy RE_SPLITS means "no splits" (0).
                re_num_splits=timm_cfg.get('RE_SPLITS', False) or 0,
                scale=aug_cfg.get('SCALE', None),
                ratio=aug_cfg.get('RATIO', None),
                hflip=timm_cfg.get('HFLIP', 0.5),
                vflip=timm_cfg.get('VFLIP', 0.),
                color_jitter=timm_cfg.get('COLOR_JITTER', 0.4),
                auto_augment=timm_cfg.get('AUTO_AUGMENT', None),
                interpolation=aug_cfg['INTERPOLATION'],
                mean=image_cfg['IMAGE_MEAN'],
                std=image_cfg['IMAGE_STD'],
            )
        elif 'TORCHVISION_AUG' in aug_cfg:
            logger.info('=> use torchvision transform fro training')
            tv_cfg = aug_cfg['TORCHVISION_AUG']
            crop_size = image_cfg['IMAGE_SIZE'][0]
            interpolation = INTERPOLATION_MODES[aug_cfg['INTERPOLATION']]
            trans = [
                T.RandomResizedCrop(
                    crop_size, scale=aug_cfg['SCALE'], ratio=aug_cfg['RATIO'],
                    interpolation=interpolation
                )
            ]
            hflip_prob = tv_cfg['HFLIP']
            auto_augment_policy = tv_cfg.get('AUTO_AUGMENT', None)
            if hflip_prob > 0:
                trans.append(T.RandomHorizontalFlip(hflip_prob))
            if auto_augment_policy is not None:
                if auto_augment_policy == "ra":
                    trans.append(RandAugment(interpolation=interpolation))
                elif auto_augment_policy == "ta_wide":
                    trans.append(TrivialAugmentWide(interpolation=interpolation))
                else:
                    # Any other value is treated as an AutoAugmentPolicy name.
                    aa_policy = AutoAugmentPolicy(auto_augment_policy)
                    trans.append(AutoAugment(policy=aa_policy, interpolation=interpolation))
            trans.extend([T.ToTensor(), normalize])

            # Erasing and rotation are appended after ToTensor/normalize,
            # so they run on the normalised tensor.
            random_erase_prob = tv_cfg['RE_PROB']
            random_erase_scale = tv_cfg.get('RE_SCALE', 0.33)
            if random_erase_prob > 0:
                trans.append(T.RandomErasing(p=random_erase_prob, scale=(0.02, random_erase_scale)))

            rotation = tv_cfg.get('ROTATION', 0.0)
            if rotation > 0.0:
                trans.append(T.RandomRotation(rotation, interpolation=T.InterpolationMode.BILINEAR))
                logger.info(" TORCH AUG: Rotation: %s", rotation)

            transforms = T.Compose(trans)
        elif aug_cfg.get('RANDOM_CENTER_CROP', False):
            logger.info('=> use random center crop data augmenation')
            crop = image_cfg['IMAGE_SIZE'][0]
            padding = aug_cfg.get('RANDOM_CENTER_CROP_PADDING', 32)
            # Resize to a square `padding` pixels larger than the crop, then
            # take a random crop of the target size out of it.
            precrop = crop + padding
            mode = INTERPOLATION_MODES[aug_cfg['INTERPOLATION']]
            transforms = T.Compose([
                T.Resize((precrop, precrop), interpolation=mode),
                T.RandomCrop((crop, crop)),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize,
            ])
        elif aug_cfg.get('MAE_FINETUNE_AUG', False):
            transforms = create_transform(
                input_size=image_cfg['IMAGE_SIZE'][0],
                is_training=True,
                color_jitter=aug_cfg.get('COLOR_JITTER', None),
                auto_augment=aug_cfg.get('AUTO_AUGMENT', 'rand-m9-mstd0.5-inc1'),
                interpolation='bicubic',
                re_prob=aug_cfg.get('RE_PROB', 0.25),
                re_mode=aug_cfg.get('RE_MODE', "pixel"),
                re_count=aug_cfg.get('RE_COUNT', 1),
                mean=image_cfg['IMAGE_MEAN'],
                std=image_cfg['IMAGE_STD'],
            )
        elif aug_cfg.get('MAE_PRETRAIN_AUG', False):
            transforms = T.Compose([
                T.RandomResizedCrop(
                    image_cfg['IMAGE_SIZE'][0],
                    scale=tuple(aug_cfg['SCALE']),
                    interpolation=INTERPOLATION_MODES["bicubic"]),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                T.Normalize(mean=image_cfg['IMAGE_MEAN'], std=image_cfg['IMAGE_STD']),
            ])
        elif aug_cfg.get('ThreeAugment', False):
            # NOTE(review): the lowercase key 'src' differs from the 'SRC'
            # key used under THREE_AUG; left unchanged, confirm it is intended.
            transforms = _build_three_augment(
                image_cfg['IMAGE_SIZE'][0],
                aug_cfg.get('src', False),
                aug_cfg['COLOR_JITTER'],
            )
        logger.info('=> training transformers: %s', transforms)
    else:
        mode = INTERPOLATION_MODES[aug_cfg['INTERPOLATION']]
        if cfg['TEST']['CENTER_CROP']:
            crop = image_cfg['IMAGE_SIZE'][0]
            transforms = T.Compose([
                # Resize to crop / 0.875, then centre-crop back to `crop`.
                T.Resize(int(crop / 0.875), interpolation=mode),
                T.CenterCrop(crop),
                T.ToTensor(),
                normalize,
            ])
        else:
            size = image_cfg['IMAGE_SIZE']
            transforms = T.Compose([
                # Resize gets (IMAGE_SIZE[1], IMAGE_SIZE[0]); presumably
                # IMAGE_SIZE is stored as [W, H] -- TODO confirm.
                T.Resize((size[1], size[0]), interpolation=mode),
                T.ToTensor(),
                normalize,
            ])
        logger.info('=> testing transformers: %s', transforms)

    return transforms
|
|
|
|