from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from .task_configs import task_parameters

MAKE_RESCALE_0_1_NEG1_POS1 = lambda n_chan: transforms.Normalize([0.5]*n_chan, [0.5]*n_chan)
RESCALE_0_1_NEG1_POS1 = transforms.Normalize([0.5], [0.5])  # This needs to be different depending on num out chans
MAKE_RESCALE_0_MAX_NEG1_POS1 = lambda maxx: transforms.Normalize([maxx / 2.], [maxx * 1.0])
RESCALE_0_255_NEG1_POS1 = transforms.Normalize([127.5, 127.5, 127.5], [255, 255, 255])
MAKE_RESCALE_0_MAX_0_POS1 = lambda maxx: transforms.Normalize([0.0], [maxx * 1.0])
STD_IMAGENET = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
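
# Illustrative sanity check (not part of the original file): these helpers are
# plain torchvision Normalize ops, i.e. (x - mean) / std per channel. So
# MAKE_RESCALE_0_MAX_0_POS1(maxx) maps [0, maxx] -> [0, 1], while
# MAKE_RESCALE_0_MAX_NEG1_POS1(maxx) computes (x - maxx/2) / maxx, which lands
# in [-0.5, 0.5] despite its name.
#
#     x = torch.tensor([[[0.0, 4.0, 8.0]]])        # toy 1 x 1 x 3 "image"
#     MAKE_RESCALE_0_MAX_0_POS1(8.0)(x)            # tensor([[[0.0000, 0.5000, 1.0000]]])
#     MAKE_RESCALE_0_MAX_NEG1_POS1(8.0)(x)         # tensor([[[-0.5000, 0.0000, 0.5000]]])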

# For semantic segmentation
transform_dense_labels = lambda img: torch.Tensor(np.array(img)).long()  # avoids normalizing

# Transforms to a 3-channel tensor and rescales [0, 255] -> [0, 1]
transform_8bit = transforms.Compose([
    transforms.ToTensor(),
])

# Transforms to an n-channel tensor and rescales [0, 255] -> [0, 1]. Keeps only the first n channels
def transform_8bit_n_channel(n_channel=1, crop_channels=True):
    if crop_channels:
        crop_channels_fn = lambda x: x[:n_channel] if x.shape[0] > n_channel else x
    else:
        crop_channels_fn = lambda x: x
    return transforms.Compose([
        transforms.ToTensor(),
        crop_channels_fn,
    ])
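
# Usage sketch (the RGBA input here is illustrative, not from the original
# file): ToTensor yields a 4 x H x W tensor in [0, 1], and with n_channel=2 the
# crop keeps only the first two channels.
#
#     from PIL import Image
#     rgba = Image.new('RGBA', (8, 8))
#     out = transform_8bit_n_channel(n_channel=2)(rgba)   # shape: (2, 8, 8)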

# Transforms a 16-bit image to a 1-channel tensor and rescales [0, 65535] -> [0, 1].
def transform_16bit_single_channel(im):
    im = transforms.ToTensor()(np.array(im))
    im = im.float() / (2 ** 16 - 1.0)
    return im
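
# Usage sketch (an int32 numpy array stands in for a 16-bit PIL image here;
# 16-bit PNGs typically load as PIL mode 'I'): ToTensor keeps the integer
# values, and the explicit division maps [0, 65535] -> [0, 1].
#
#     depth_raw = np.full((8, 8), 2 ** 15, dtype=np.int32)
#     depth = transform_16bit_single_channel(depth_raw)   # ~0.5 everywhere, shape (1, 8, 8)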

def make_valid_mask(mask_float, max_pool_size=4):
    '''
    Creates a mask indicating the valid parts of the image(s).
    Enlarges the masked area using a max pooling operation.
    Args:
        mask_float: A (b x c x h x w) mask as loaded from the Taskonomy loader.
        max_pool_size: Parameter to choose how much to enlarge the masked area.
    '''
    squeeze = False
    if len(mask_float.shape) == 3:
        mask_float = mask_float.unsqueeze(0)
        squeeze = True
    _, _, h, w = mask_float.shape
    # Invert so invalid pixels become 1, dilate them with max pooling, then
    # upsample back to the original resolution.
    mask_float = 1 - mask_float
    mask_float = F.max_pool2d(mask_float, kernel_size=max_pool_size)
    mask_float = F.interpolate(mask_float, (h, w), mode='nearest')
    mask_valid = mask_float == 0
    mask_valid = mask_valid[0] if squeeze else mask_valid
    return mask_valid
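
# Usage sketch (illustrative values, not from the original file): a zero in the
# input marks an invalid pixel; after inversion, max pooling, and nearest
# upsampling, every max_pool_size-aligned block containing an invalid pixel is
# excluded from the returned boolean mask.
#
#     mask = torch.ones(1, 8, 8)
#     mask[0, 0, 0] = 0                # one invalid pixel
#     valid = make_valid_mask(mask)    # bool tensor, shape (1, 8, 8)
#     valid[0, :4, :4].any()           # tensor(False): whole 4x4 block masked out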

def task_transform(file, task: str, image_size: Optional[int] = None):
    transform = None
    if task in ['rgb']:
        transform = transforms.Compose([
            transform_8bit,
            STD_IMAGENET
        ])
    elif task in ['normal']:
        transform = transform_8bit
    elif task in ['mask_valid']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            make_valid_mask
        ])
    elif task in ['keypoints2d', 'keypoints3d', 'depth_euclidean', 'depth_zbuffer', 'edge_texture']:
        transform = transform_16bit_single_channel
    elif task in ['edge_occlusion']:
        transform = transforms.Compose([
            transform_16bit_single_channel,
            transforms.GaussianBlur(3, sigma=1)
        ])
    elif task in ['principal_curvature', 'curvature']:
        transform = transform_8bit_n_channel(2)
    elif task in ['reshading']:
        transform = transform_8bit_n_channel(1)
    elif task in ['segment_semantic', 'segment_instance', 'segment_panoptic', 'fragments', 'segment_unsup2d', 'segment_unsup25d']:
        # These are stored as 1-channel images (H, W) where each pixel value is a different class
        transform = transform_dense_labels
    elif task in ['class_object', 'class_scene']:
        transform = torch.Tensor
        image_size = None
    else:
        transform = None
    if 'threshold_min' in task_parameters[task]:
        threshold = task_parameters[task]['threshold_min']
        transform = transforms.Compose([
            transform,
            lambda x: torch.threshold(x, threshold, 0.0)
        ])
    if 'clamp_to' in task_parameters[task]:
        minn, maxx = task_parameters[task]['clamp_to']
        if minn > 0:
            raise NotImplementedError("Rescaling (min1, max1) -> (min2, max2) not implemented for min1, min2 != 0 (task {})".format(task))
        # Clamp to [minn, maxx], then rescale [0, maxx] -> [0, 1]
        transform = transforms.Compose([
            transform,
            lambda x: torch.clamp(x, minn, maxx),
            MAKE_RESCALE_0_MAX_0_POS1(maxx)
        ])
    if image_size is not None:
        if task == 'fragments':
            resize_frag = lambda frag: F.interpolate(frag.permute(2, 0, 1).unsqueeze(0).float(), image_size, mode='nearest').long()[0].permute(1, 2, 0)
            transform = transforms.Compose([
                transform,
                resize_frag
            ])
        else:
            resize_method = transforms.InterpolationMode.BILINEAR if task in ['rgb'] else transforms.InterpolationMode.NEAREST
            transform = transforms.Compose([
                transforms.Resize(image_size, interpolation=resize_method),
                transform
            ])
    if transform is not None:
        file = transform(file)
    return file
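
# Usage sketch (the synthetic PIL image stands in for a loaded Taskonomy file;
# this assumes 'rgb' is defined in task_parameters without threshold_min or
# clamp_to, as in the accompanying .task_configs):
#
#     from PIL import Image
#     rgb = Image.new('RGB', (512, 512))
#     x = task_transform(rgb, task='rgb', image_size=256)   # ImageNet-normalized, 3 x 256 x 256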