# MultiMAE / utils/taskonomy/transforms.py
# Author: Bachmann Roman Christian -- "Initial commit" (3b49518)
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from .task_configs import task_parameters
# Normalizers built on transforms.Normalize, which computes (x - mean) / std per
# channel. Mapping a range [lo, hi] onto [-1, 1] therefore needs
# mean = (lo + hi) / 2 and std = (hi - lo) / 2.

# n_chan channels, [0, 1] -> [-1, 1].
MAKE_RESCALE_0_1_NEG1_POS1 = lambda n_chan: transforms.Normalize([0.5] * n_chan, [0.5] * n_chan)
# Single-channel version of the above; use MAKE_RESCALE_0_1_NEG1_POS1(n_chan)
# when the output has a different number of channels.
RESCALE_0_1_NEG1_POS1 = transforms.Normalize([0.5], [0.5])
# Single channel, [0, maxx] -> [-1, 1]. The std is half the range so both
# endpoints land exactly on -1 and +1.
MAKE_RESCALE_0_MAX_NEG1_POS1 = lambda maxx: transforms.Normalize([maxx / 2.], [maxx / 2.])
# Three channels, [0, 255] -> [-1, 1] (mean and std are both half of 255).
RESCALE_0_255_NEG1_POS1 = transforms.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
# Single channel, [0, maxx] -> [0, 1] (mean 0, divide by maxx).
MAKE_RESCALE_0_MAX_0_POS1 = lambda maxx: transforms.Normalize([0.0], [maxx * 1.0])
# ImageNet channel mean/std; expects an RGB tensor already scaled to [0, 1].
STD_IMAGENET = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
def transform_dense_labels(img):
    """Convert a dense label map (e.g. semantic segmentation) to an int64 tensor.

    Pixel values are the labels, so no normalization is applied.

    Args:
        img: Anything ``np.asarray`` accepts (PIL image or numpy array).

    Returns:
        A ``torch.long`` tensor with the same shape as ``np.asarray(img)``.
    """
    # Cast straight to int64. Routing through a float32 tensor (torch.Tensor)
    # would round label ids above 2**24 and may reject unsigned dtypes such
    # as uint16 on some torch versions.
    labels = np.asarray(img).astype(np.int64)
    return torch.from_numpy(labels)
# 8-bit image -> float tensor via ToTensor: channels-first layout, and (per
# torchvision's ToTensor contract) 8-bit input is scaled from [0, 255] to
# [0, 1]. The number of channels is left unchanged.
transform_8bit = transforms.Compose(transforms=[transforms.ToTensor()])
def transform_8bit_n_channel(n_channel=1, crop_channels=True):
    """Build a transform: 8-bit image -> float tensor with at most ``n_channel`` channels.

    ToTensor yields a channels-first tensor (torchvision scales 8-bit input to
    [0, 1]). When ``crop_channels`` is true, only the leading ``n_channel``
    channels are kept. Tensors that already have ``n_channel`` or fewer
    channels, or any tensor when ``crop_channels`` is false, pass through
    untouched.
    """
    def _select_channels(tensor):
        if crop_channels and tensor.shape[0] > n_channel:
            return tensor[:n_channel]
        return tensor

    steps = [transforms.ToTensor(), _select_channels]
    return transforms.Compose(steps)
def transform_16bit_single_channel(im):
    """Convert a 16-bit single-channel image into a float tensor in [0, 1].

    The image becomes a numpy array and goes through ToTensor, giving a
    channels-first tensor. It is then cast to float and divided by the 16-bit
    maximum (2**16 - 1).

    NOTE(review): torchvision's ToTensor already rescales uint8 input to
    [0, 1], so this assumes genuinely 16-bit (non-uint8) data.
    """
    as_tensor = transforms.ToTensor()(np.array(im))
    max_16bit = 2 ** 16 - 1.0
    return as_tensor.float().div(max_16bit)
def make_valid_mask(mask_float, max_pool_size=4):
    '''
    Creates a boolean mask indicating the valid parts of the image(s).

    The inverted mask is max-pooled with a ``max_pool_size`` window. Stride
    defaults to the window size, so this downsamples. The result is then
    upsampled back to the input resolution with nearest-neighbour
    interpolation. Any pooling window that touches an invalid pixel therefore
    becomes entirely invalid, which enlarges the masked area.

    Args:
        mask_float: A (b x c x h x w) or (c x h x w) mask as loaded from the
            Taskonomy loader. Pixels equal to 1 are treated as valid. Bool and
            integer masks are promoted to float first.
        max_pool_size: Parameter to choose how much to enlarge masked area.

    Returns:
        A bool tensor with the same shape as ``mask_float``; True where valid.
    '''
    squeeze = mask_float.dim() == 3
    if squeeze:
        # Add a batch dimension for max_pool2d / interpolate.
        mask_float = mask_float.unsqueeze(0)
    if not torch.is_floating_point(mask_float):
        # `1 - mask` is not supported for bool tensors; promote integer masks too.
        mask_float = mask_float.float()
    _, _, h, w = mask_float.shape
    invalid = 1 - mask_float
    invalid = F.max_pool2d(invalid, kernel_size=max_pool_size)
    invalid = F.interpolate(invalid, (h, w), mode='nearest')
    mask_valid = invalid == 0
    return mask_valid[0] if squeeze else mask_valid
def _compose(*fns):
    '''Compose the given callables in order, skipping any that are None.'''
    return transforms.Compose([fn for fn in fns if fn is not None])


def task_transform(file, task: str, image_size: Optional[int] = None):
    '''
    Apply the Taskonomy preprocessing for ``task`` to a loaded ``file``.

    The steps are applied in this order:
      1. A per-task base transform is chosen from the branches below.
      2. It is extended with thresholding and/or clamping + rescaling as
         configured in ``task_parameters[task]``.
      3. A resize to ``image_size`` is added if one is given. For
         'fragments' the resize runs after the base transform; for every
         other task it runs before.

    Args:
        file: The loaded modality. Presumably a PIL image for image tasks;
            the exact type depends on the loader (TODO confirm).
        task: Name of the task / modality, e.g. 'rgb', 'depth_euclidean'.
        image_size: Size passed to the resize step. ``None`` (the default)
            skips resizing. It is always ignored for 'class_object' and
            'class_scene'.

    Returns:
        The transformed ``file``, or ``file`` unchanged when no transform
        applies.

    Raises:
        KeyError: if ``task`` is missing from ``task_parameters``.
        NotImplementedError: if the task's ``clamp_to`` lower bound is > 0.
    '''
    transform = None
    if task == 'rgb':
        transform = transforms.Compose([transform_8bit, STD_IMAGENET])
    elif task == 'normal':
        transform = transform_8bit
    elif task == 'mask_valid':
        transform = transforms.Compose([transforms.ToTensor(), make_valid_mask])
    elif task in ('keypoints2d', 'keypoints3d', 'depth_euclidean', 'depth_zbuffer', 'edge_texture'):
        transform = transform_16bit_single_channel
    elif task == 'edge_occlusion':
        transform = transforms.Compose([
            transform_16bit_single_channel,
            transforms.GaussianBlur(3, sigma=1),
        ])
    elif task in ('principal_curvature', 'curvature'):
        transform = transform_8bit_n_channel(2)
    elif task == 'reshading':
        transform = transform_8bit_n_channel(1)
    elif task in ('segment_semantic', 'segment_instance', 'segment_panoptic', 'fragments', 'segment_unsup2d', 'segment_unsup25d'):
        # Label images: each pixel value is a class / instance id, so no
        # normalization. NOTE(review): 'fragments' appears to carry a trailing
        # channel dim, given the (H, W, C) permute in the resize step below.
        transform = transform_dense_labels
    elif task in ('class_object', 'class_scene'):
        transform = torch.Tensor
        # Class vectors are not images: never resize them.
        image_size = None

    params = task_parameters[task]
    if 'threshold_min' in params:
        threshold = params['threshold_min']
        # torch.threshold(x, t, 0.0) replaces every value <= t with 0.
        transform = _compose(transform, lambda x: torch.threshold(x, threshold, 0.0))
    if 'clamp_to' in params:
        minn, maxx = params['clamp_to']
        if minn > 0:
            raise NotImplementedError("Rescaling (min1, max1) -> (min2, max2) not implemented for min1, min2 != 0 (task {})".format(task))
        # Clamp to [minn, maxx], then divide by maxx.
        transform = _compose(
            transform,
            lambda x: torch.clamp(x, minn, maxx),
            MAKE_RESCALE_0_MAX_0_POS1(maxx),
        )

    if image_size is not None:
        if task == 'fragments':
            def resize_frag(frag):
                # Move channels first for F.interpolate, resize with
                # nearest-neighbour (ids must not be blended), then restore
                # the (H, W, C) integer layout.
                frag = frag.permute(2, 0, 1).unsqueeze(0).float()
                frag = F.interpolate(frag, image_size, mode='nearest')
                return frag.long()[0].permute(1, 2, 0)
            transform = _compose(transform, resize_frag)
        else:
            # Bilinear only for RGB; nearest elsewhere so labels/values are not blended.
            if task == 'rgb':
                resize_method = transforms.InterpolationMode.BILINEAR
            else:
                resize_method = transforms.InterpolationMode.NEAREST
            transform = _compose(transforms.Resize(image_size, resize_method), transform)

    if transform is not None:
        file = transform(file)
    return file