# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO, DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
import numpy as np | |
import torch | |
try: | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
except: | |
print('albumentations not installed') | |
# import cv2 | |
import torch.nn.functional as F | |
from utils import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, NYU_MEAN, | |
NYU_STD, PAD_MASK_VALUE) | |
from utils.dataset_folder import ImageFolder, MultiTaskImageFolder | |
def nyu_transform(train, additional_targets, input_size=512, color_aug=False): | |
if train: | |
augs = [ | |
A.SmallestMaxSize(max_size=input_size, p=1), | |
A.HorizontalFlip(p=0.5), | |
] | |
if color_aug: augs += [ | |
# Color jittering from BYOL https://arxiv.org/pdf/2006.07733.pdf | |
A.ColorJitter( | |
brightness=0.1255, | |
contrast=0.4, | |
saturation=[0.5, 1.5], | |
hue=[-0.2, 0.2], | |
p=0.5 | |
), | |
A.ToGray(p=0.3), | |
] | |
augs += [ | |
A.RandomCrop(height=input_size, width=input_size, p=1), | |
A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), | |
ToTensorV2(), | |
] | |
transform = A.Compose(augs, additional_targets=additional_targets) | |
else: | |
transform = A.Compose([ | |
A.SmallestMaxSize(max_size=input_size, p=1), | |
A.CenterCrop(height=input_size, width=input_size), | |
A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), | |
ToTensorV2(), | |
], additional_targets=additional_targets) | |
return transform | |
def simple_regression_transform(train, additional_targets, input_size=512, pad_value=(128, 128, 128), pad_mask_value=PAD_MASK_VALUE): | |
if train: | |
transform = A.Compose([ | |
A.HorizontalFlip(p=0.5), | |
A.LongestMaxSize(max_size=input_size, p=1), | |
A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.5), # Color jittering from MoCo-v3 / DINO | |
A.RandomScale(scale_limit=(0.1 - 1, 2.0 - 1), p=1), # This is LSJ (0.1, 2.0) | |
A.PadIfNeeded(min_height=input_size, min_width=input_size, | |
position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT, | |
border_mode=cv2.BORDER_CONSTANT, | |
value=pad_value, mask_value=pad_mask_value), | |
A.RandomCrop(height=input_size, width=input_size, p=1), | |
A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), | |
ToTensorV2(), | |
], additional_targets=additional_targets) | |
else: | |
transform = A.Compose([ | |
A.LongestMaxSize(max_size=input_size, p=1), | |
A.PadIfNeeded(min_height=input_size, min_width=input_size, | |
position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT, | |
border_mode=cv2.BORDER_CONSTANT, | |
value=pad_value, mask_value=pad_mask_value), | |
A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), | |
ToTensorV2(), | |
], additional_targets=additional_targets) | |
return transform | |
class DataAugmentationForRegression(object): | |
def __init__(self, transform, mask_value=0.0): | |
self.transform = transform | |
self.mask_value = mask_value | |
def __call__(self, task_dict): | |
# Need to replace rgb key to image | |
task_dict['image'] = task_dict.pop('rgb') | |
# Convert to np.array | |
task_dict = {k: np.array(v) for k, v in task_dict.items()} | |
task_dict = self.transform(**task_dict) | |
task_dict['depth'] = (task_dict['depth'].float() - NYU_MEAN)/NYU_STD | |
# And then replace it back to rgb | |
task_dict['rgb'] = task_dict.pop('image') | |
task_dict['mask_valid'] = (task_dict['mask_valid'] == 255)[None] | |
for task in task_dict: | |
if task in ['depth']: | |
img = task_dict[task] | |
if 'mask_valid' in task_dict: | |
mask_valid = task_dict['mask_valid'].squeeze() | |
img[~mask_valid] = self.mask_value | |
task_dict[task] = img.unsqueeze(0) | |
elif task in ['rgb']: | |
task_dict[task] = task_dict[task].to(torch.float) | |
return task_dict | |
def build_regression_dataset(args, data_path, transform, max_images=None): | |
transform = DataAugmentationForRegression(transform=transform) | |
return MultiTaskImageFolder(data_path, args.all_domains, transform=transform, prefixes=None, max_images=max_images) | |