# Copyright (c) EPFL VILAB.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO, DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
from typing import Dict, Tuple

import numpy as np
import torch

try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
except ImportError:
    # Only a missing package is tolerated here; any other failure propagates.
    # If this branch is taken, `A` / `ToTensorV2` stay undefined and
    # `simple_transform` raises NameError when called.
    print('albumentations not installed')
import cv2
import torch.nn.functional as F

from utils import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PAD_MASK_VALUE,
                   SEG_IGNORE_INDEX)

from .dataset_folder import ImageFolder, MultiTaskImageFolder


def simple_transform(train: bool,
                     additional_targets: Dict[str, str],
                     input_size: int = 512,
                     pad_value: Tuple[int, int, int] = (128, 128, 128),
                     pad_mask_value: int = PAD_MASK_VALUE):
    """Default transform for semantic segmentation, applied on all modalities

    During training:
        1. Random horizontal Flip
        2. Rescaling so that longest side matches input size
        3. Color jitter (for RGB-modality only)
        4. Large scale jitter (LSJ)
        5. Padding
        6. Random crop to given size
        7. Normalization with ImageNet mean and std dev

    During validation / test:
        1. Rescaling so that longest side matches given size
        2. Padding
        3. Normalization with ImageNet mean and std dev

    Args:
        train: Build the training pipeline if True, otherwise the
            validation / test pipeline.
        additional_targets: Forwarded unchanged to ``A.Compose``.
        input_size: Target size for the longest-side resize, the padding and
            (during training) the random crop.
        pad_value: Fill value used for padded image pixels.
        pad_mask_value: Fill value used for padded mask pixels.

    Returns:
        The composed albumentations transform.
    """
    # Steps shared by the training and the evaluation pipelines.
    resize = A.LongestMaxSize(max_size=input_size, p=1)
    pad = A.PadIfNeeded(min_height=input_size, min_width=input_size,
                        position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT,
                        border_mode=cv2.BORDER_CONSTANT,
                        value=pad_value, mask_value=pad_mask_value)
    normalize = A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

    if train:
        steps = [
            A.HorizontalFlip(p=0.5),
            resize,
            # Color jittering from MoCo-v3 / DINO
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.5),
            # This is LSJ (0.1, 2.0)
            A.RandomScale(scale_limit=(0.1 - 1, 2.0 - 1), p=1),
            pad,
            A.RandomCrop(height=input_size, width=input_size, p=1),
            normalize,
            ToTensorV2(),
        ]
    else:
        steps = [resize, pad, normalize, ToTensorV2()]

    return A.Compose(steps, additional_targets=additional_targets)


class DataAugmentationForSemSeg(object):
    """Data transform / augmentation for semantic segmentation downstream tasks.

    Wraps a joint transform (called with one keyword argument per task, the
    RGB input being passed as ``image``) and post-processes each task's output:
    depth standardization / masking, label remapping for ``semseg`` and
    downscaling of ``pseudo_semseg``.
    """

    def __init__(self, transform, seg_num_classes, seg_ignore_index=SEG_IGNORE_INDEX,
                 standardize_depth=True, seg_reduce_zero_label=False,
                 seg_use_void_label=False):
        """Store the joint transform and the label / depth handling options.

        Args:
            transform: Callable invoked as ``transform(**task_dict)`` that
                returns a dict with the same keys.
            seg_num_classes: Number of segmentation classes; used to derive the
                void label when ``seg_use_void_label`` is set.
            seg_ignore_index: Label value assigned to ignored pixels.
            standardize_depth: Whether depth maps are standardized.
            seg_reduce_zero_label: Whether label 0 is mapped to the ignore
                index and all other labels are shifted down by one.
            seg_use_void_label: Whether padded pixels get a dedicated void
                label instead of the ignore index.
        """
        self.transform = transform
        self.seg_num_classes = seg_num_classes
        self.seg_ignore_index = seg_ignore_index
        self.standardize_depth = standardize_depth
        self.seg_reduce_zero_label = seg_reduce_zero_label
        self.seg_use_void_label = seg_use_void_label

    @staticmethod
    def standardize_depth_map(img, mask_valid=None, trunc_value=0.1):
        """Standardize a depth map with truncated mean / variance statistics.

        Pixels equal to ``PAD_MASK_VALUE`` (and, if ``mask_valid`` is given,
        pixels where it is False) are treated as missing: they are excluded
        from the statistics and replaced by the truncated mean, so they end up
        at 0 after standardization.

        Args:
            img: Depth tensor. Assumed to be floating point, since NaN is
                written into it -- TODO confirm against callers.
            mask_valid: Optional boolean mask used to index ``img``; False
                entries are treated as missing.
            trunc_value: Fraction of the sorted valid values dropped at each
                end before computing mean and variance.

        Returns:
            A new standardized tensor; the input tensor is left untouched.
        """
        # Work on a copy so the caller's tensor is not filled with NaNs.
        img = img.clone()
        img[img == PAD_MASK_VALUE] = torch.nan
        if mask_valid is not None:
            # This is if we want to apply masking before standardization
            img[~mask_valid] = torch.nan

        # Keep only the valid (non-NaN) values, in ascending order
        flat = torch.flatten(img)
        sorted_valid = torch.sort(flat[~flat.isnan()])[0]

        # Remove outliers
        num_valid = len(sorted_valid)
        trunc_img = sorted_valid[int(trunc_value * num_valid): int((1 - trunc_value) * num_valid)]
        trunc_mean = trunc_img.mean()
        trunc_var = trunc_img.var()
        eps = 1e-6

        # Replace nan by mean
        img = torch.nan_to_num(img, nan=trunc_mean)
        # Standardize
        img = (img - trunc_mean) / torch.sqrt(trunc_var + eps)
        return img

    def seg_adapt_labels(self, img):
        """Remap padded and zero labels of a segmentation map.

        Padded pixels (``PAD_MASK_VALUE``) become either the void label or the
        ignore index. With ``seg_reduce_zero_label``, label 0 becomes the
        ignore index and every other label is shifted down by one.

        Note that the padding replacement (and the zero-label replacement)
        are written into ``img`` in place before a new tensor is returned.

        Args:
            img: Integer label tensor.

        Returns:
            The remapped label tensor.
        """
        if self.seg_use_void_label:
            # Set void label to num_classes
            if self.seg_reduce_zero_label:
                # One higher, because all labels are shifted down by one below
                pad_replace = self.seg_num_classes + 1
            else:
                pad_replace = self.seg_num_classes
        else:
            pad_replace = self.seg_ignore_index
        img[img == PAD_MASK_VALUE] = pad_replace

        if self.seg_reduce_zero_label:
            img[img == 0] = self.seg_ignore_index
            img = img - 1
            # Undo the shift for pixels that were already set to the ignore index
            img[img == self.seg_ignore_index - 1] = self.seg_ignore_index

        return img

    def __call__(self, task_dict):
        """Apply the joint transform and the per-task post-processing.

        Args:
            task_dict: Mapping from task name to an array-like input. Must
                contain the key ``'rgb'``.

        Returns:
            A new dict mapping task names to tensors; the ``task_dict`` passed
            in is not modified.
        """
        # Shallow copy so that renaming keys does not alter the caller's dict
        task_dict = dict(task_dict)

        # Need to replace rgb key to image
        task_dict['image'] = task_dict.pop('rgb')
        # Convert to np.array
        task_dict = {k: np.array(v) for k, v in task_dict.items()}

        task_dict = self.transform(**task_dict)

        # And then replace it back to rgb
        task_dict['rgb'] = task_dict.pop('image')

        # Only values of existing keys are reassigned below, so iterating over
        # the dict while updating it is safe.
        for task in task_dict:
            if task == 'depth':
                img = task_dict[task].to(torch.float)
                if self.standardize_depth:
                    # Mask valid set to None here, as masking is applied after standardization
                    img = self.standardize_depth_map(img, mask_valid=None)
                if 'mask_valid' in task_dict:
                    # Pixels whose mask value is not 255 are zeroed out
                    mask_valid = (task_dict['mask_valid'] == 255).squeeze()
                    img[~mask_valid] = 0.0
                # Add a leading dimension
                task_dict[task] = img.unsqueeze(0)
            elif task == 'rgb':
                task_dict[task] = task_dict[task].to(torch.float)
            elif task == 'semseg':
                img = task_dict[task].to(torch.long)
                img = self.seg_adapt_labels(img)
                task_dict[task] = img
            elif task == 'pseudo_semseg':
                # If it's pseudo-semseg, then it's an input modality and should therefore be resized
                img = task_dict[task]
                # The indexing below expects a 2D map; it is downscaled by a
                # factor of 4 with nearest-neighbour interpolation.
                img = F.interpolate(img[None, None, :, :], scale_factor=0.25,
                                    mode='nearest').long()[0, 0]
                task_dict[task] = img

        return task_dict


def build_semseg_dataset(args, data_path, transform, max_images=None):
    """Build a multi-task image folder dataset for semantic segmentation.

    Args:
        args: Object providing ``num_classes``, ``standardize_depth``,
            ``seg_reduce_zero_label``, ``seg_use_void_label``,
            ``load_pseudo_depth`` and ``all_domains``.
        data_path: Passed to ``MultiTaskImageFolder`` as its first argument.
        transform: Joint transform wrapped in ``DataAugmentationForSemSeg``.
        max_images: Forwarded to ``MultiTaskImageFolder``.

    Returns:
        The constructed ``MultiTaskImageFolder``.
    """
    transform = DataAugmentationForSemSeg(transform=transform,
                                          seg_num_classes=args.num_classes,
                                          standardize_depth=args.standardize_depth,
                                          seg_reduce_zero_label=args.seg_reduce_zero_label,
                                          seg_use_void_label=args.seg_use_void_label)
    # Depth gets the 'pseudo_' prefix when pseudo depth is requested
    prefixes = {'depth': 'pseudo_'} if args.load_pseudo_depth else None
    return MultiTaskImageFolder(data_path, args.all_domains, transform=transform,
                                prefixes=prefixes, max_images=max_images)


def ade_classes():
    """ADE20K class names for external use."""
    # NOTE(review): 'bed ' carries a trailing space; it is kept byte-for-byte
    # in case external code matches on the exact string -- confirm whether it
    # can be stripped.
    return [
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag'
    ]


def hypersim_classes():
    """Hypersim class names for external use.

    The list is identical to the one returned by ``nyu_v2_40_classes``, so it
    is obtained from there; a fresh list is returned on every call.
    """
    return nyu_v2_40_classes()


def nyu_v2_40_classes():
    """NYUv2 40 class names for external use."""
    return [
        'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
        'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk',
        'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor-mat',
        'clothes', 'ceiling', 'books', 'fridge', 'TV', 'paper', 'towel',
        'shower-curtain', 'box', 'white-board', 'person', 'night-stand',
        'toilet', 'sink', 'lamp', 'bathtub', 'bag', 'other-struct',
        'other-furntr', 'other-prop'
    ]