MultiMAE / utils /dataset_regression.py
Bachmann Roman Christian
Initial commit
3b49518
raw
history blame
No virus
5.19 kB
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO, DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
import numpy as np
import torch
try:
import albumentations as A
from albumentations.pytorch import ToTensorV2
except:
print('albumentations not installed')
# import cv2
import torch.nn.functional as F
from utils import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, NYU_MEAN,
NYU_STD, PAD_MASK_VALUE)
from utils.dataset_folder import ImageFolder, MultiTaskImageFolder
def nyu_transform(train, additional_targets, input_size=512, color_aug=False):
    """Return the albumentations pipeline used for NYU-style samples.

    Images are resized so their shortest side equals ``input_size`` and then
    cropped to ``input_size x input_size`` (random crop when training, center
    crop otherwise), followed by ImageNet normalization and tensor conversion.

    Args:
        train: Build the randomized training pipeline when True, the
            deterministic evaluation pipeline otherwise.
        additional_targets: Forwarded unchanged to ``A.Compose``.
        input_size: Side length of the square output crop.
        color_aug: Training only; when True, color jitter and random
            grayscale are inserted between the flip and the crop.

    Returns:
        An ``A.Compose`` instance.
    """
    if not train:
        # Evaluation: deterministic resize + center crop.
        return A.Compose(
            [
                A.SmallestMaxSize(max_size=input_size, p=1),
                A.CenterCrop(height=input_size, width=input_size),
                A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
                ToTensorV2(),
            ],
            additional_targets=additional_targets,
        )

    geometric = [
        A.SmallestMaxSize(max_size=input_size, p=1),
        A.HorizontalFlip(p=0.5),
    ]
    photometric = (
        [
            # Color jittering from BYOL https://arxiv.org/pdf/2006.07733.pdf
            A.ColorJitter(
                brightness=0.1255,
                contrast=0.4,
                saturation=[0.5, 1.5],
                hue=[-0.2, 0.2],
                p=0.5,
            ),
            A.ToGray(p=0.3),
        ]
        if color_aug
        else []
    )
    crop_and_convert = [
        A.RandomCrop(height=input_size, width=input_size, p=1),
        A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ToTensorV2(),
    ]
    return A.Compose(
        geometric + photometric + crop_and_convert,
        additional_targets=additional_targets,
    )
def simple_regression_transform(train, additional_targets, input_size=512, pad_value=(128, 128, 128), pad_mask_value=PAD_MASK_VALUE):
    """Build an albumentations pipeline that letterboxes inputs to a square.

    The longest side is resized to ``input_size`` and the result is padded,
    anchored at the top-left corner, up to ``input_size x input_size``.

    The training pipeline additionally applies:
      * a horizontal flip;
      * color jitter;
      * large-scale jitter (a random rescale by a factor in [0.1, 2.0]);
      * a random ``input_size`` crop.

    Args:
        train: Build the randomized training pipeline when True, the
            deterministic evaluation pipeline otherwise.
        additional_targets: Forwarded unchanged to ``A.Compose``.
        input_size: Side length of the square output.
        pad_value: Constant used for padded pixels of the image.
        pad_mask_value: Constant used for padded pixels of mask targets.

    Returns:
        An ``A.Compose`` instance.
    """
    # The module-level ``import cv2`` is commented out in this file, which made
    # every call to this function fail with NameError on ``cv2.BORDER_CONSTANT``.
    # albumentations is built on OpenCV, so cv2 is available whenever A is.
    import cv2

    def _pad():
        # Pad up to a square canvas with constant values, keeping the content
        # in the top-left corner.
        return A.PadIfNeeded(
            min_height=input_size,
            min_width=input_size,
            position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT,
            border_mode=cv2.BORDER_CONSTANT,
            value=pad_value,
            mask_value=pad_mask_value,
        )

    if train:
        augs = [
            A.HorizontalFlip(p=0.5),
            A.LongestMaxSize(max_size=input_size, p=1),
            # Color jittering from MoCo-v3 / DINO
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.5),
            # This is LSJ (0.1, 2.0): RandomScale's scale_limit is offset by 1,
            # hence the "- 1" on both bounds.
            A.RandomScale(scale_limit=(0.1 - 1, 2.0 - 1), p=1),
            _pad(),
            A.RandomCrop(height=input_size, width=input_size, p=1),
        ]
    else:
        # After LongestMaxSize + padding the output is already
        # input_size x input_size, so no crop is needed.
        augs = [
            A.LongestMaxSize(max_size=input_size, p=1),
            _pad(),
        ]
    augs += [
        A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ToTensorV2(),
    ]
    return A.Compose(augs, additional_targets=additional_targets)
class DataAugmentationForRegression(object):
    """Run a joint augmentation pipeline on a multi-task sample and
    post-process its ``rgb`` / ``depth`` / ``mask_valid`` entries.

    Args:
        transform: Callable invoked as ``transform(**sample)`` with the RGB
            image under the key ``'image'`` and every other domain under its
            own name. It must return a dict whose ``'image'`` (and, if present,
            ``'depth'``) entries are torch tensors.
        mask_value: Value written into depth pixels that ``mask_valid`` marks
            as invalid.
    """

    def __init__(self, transform, mask_value=0.0):
        self.transform = transform
        self.mask_value = mask_value

    def __call__(self, task_dict):
        """Transform one sample.

        Args:
            task_dict: Mapping from domain name to an image-like value
                (anything ``np.array`` accepts). Must contain ``'rgb'``;
                ``'depth'`` and ``'mask_valid'`` are optional. The mapping
                itself is not modified.

        Returns:
            A new dict containing:
              * ``'rgb'``: the transformed image cast to ``torch.float``.
              * ``'mask_valid'`` (if present): a boolean tensor that is True
                where the transformed mask equals 255, with a new leading
                dimension of size 1.
              * ``'depth'`` (if present): standardized with NYU_MEAN / NYU_STD,
                with invalid pixels set to ``self.mask_value``, and with a new
                leading dimension of size 1.
              * every other domain, exactly as returned by ``self.transform``.

        Raises:
            KeyError: if ``task_dict`` has no ``'rgb'`` entry.
        """
        # The transform expects the primary image under 'image'. Build a fresh
        # dict rather than popping from task_dict so the caller's dict is left intact.
        sample = {k: np.array(v) for k, v in task_dict.items() if k != 'rgb'}
        sample['image'] = np.array(task_dict['rgb'])
        sample = self.transform(**sample)

        # And then rename it back to rgb.
        sample['rgb'] = sample.pop('image').to(torch.float)

        # mask_valid is optional. Previously it was indexed unconditionally,
        # which made the later "'mask_valid' in ..." guard dead code.
        mask_valid = None
        if 'mask_valid' in sample:
            mask_valid = (sample['mask_valid'] == 255)[None]
            sample['mask_valid'] = mask_valid

        if 'depth' in sample:
            depth = (sample['depth'].float() - NYU_MEAN) / NYU_STD
            if mask_valid is not None:
                depth[~mask_valid.squeeze()] = self.mask_value
            sample['depth'] = depth.unsqueeze(0)

        return sample
def build_regression_dataset(args, data_path, transform, max_images=None):
    """Create a multi-task image-folder dataset for regression tasks.

    Args:
        args: Config object; its ``all_domains`` attribute lists the domains
            to load.
        data_path: Root directory handed to ``MultiTaskImageFolder``.
        transform: Joint augmentation pipeline. It is wrapped in
            ``DataAugmentationForRegression`` before being handed to the dataset.
        max_images: Optional cap forwarded to ``MultiTaskImageFolder``.

    Returns:
        A ``MultiTaskImageFolder`` instance.
    """
    wrapped_transform = DataAugmentationForRegression(transform=transform)
    return MultiTaskImageFolder(
        data_path,
        args.all_domains,
        transform=wrapped_transform,
        prefixes=None,
        max_images=max_images,
    )