import torch
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F


class RGBDAugmentor:
    """ perform augmentation on RGB-D video """

    def __init__(self, crop_size):
        self.crop_size = crop_size
        self.augcolor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                   saturation=0.4, hue=0.2/3.14),
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomInvert(p=0.1),
            transforms.ToTensor()])

        self.max_scale = 0.5

    def spatial_transform(self, images, depths, poses, intrinsics):
        """ cropping and resizing """
        ht, wd = images.shape[2:]

        # sample a log-uniform scale; min_scale guarantees the resized
        # frame is still large enough to contain the crop window
        max_scale = self.max_scale
        min_scale = np.log2(np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd)))

        scale = 1
        if np.random.rand() < 0.8:
            scale = 2 ** np.random.uniform(min_scale, max_scale)

        # focal lengths and principal point scale with the resize
        intrinsics = scale * intrinsics

        ht1 = int(scale * ht)
        wd1 = int(scale * wd)

        depths = depths.unsqueeze(dim=1)

        # bicubic for color; nearest (the F.interpolate default) for depth,
        # so resampling does not blend depths across object boundaries
        images = F.interpolate(images, (ht1, wd1), mode='bicubic', align_corners=False)
        depths = F.interpolate(depths, (ht1, wd1), recompute_scale_factor=False)

        # always perform center crop (TODO: try non-center crops)
        y0 = (images.shape[2] - self.crop_size[0]) // 2
        x0 = (images.shape[3] - self.crop_size[1]) // 2

        # shifting the crop window shifts the principal point (cx, cy)
        intrinsics = intrinsics - torch.tensor([0.0, 0.0, x0, y0])
        images = images[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        depths = depths[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        depths = depths.squeeze(dim=1)

        return images, poses, depths, intrinsics

    def color_transform(self, images):
        """ color jittering """
        num, ch, ht, wd = images.shape

        # pack all frames into one wide image so the whole clip receives
        # identical jitter, and swap BGR -> RGB for the PIL-based transforms
        images = images.permute(1, 2, 3, 0).reshape(ch, ht, wd*num)
        images = 255 * self.augcolor(images[[2, 1, 0]] / 255.0)
        return images[[2, 1, 0]].reshape(ch, ht, wd, num).permute(3, 0, 1, 2).contiguous()

    def __call__(self, images, poses, depths, intrinsics):
        if np.random.rand() < 0.5:
            images = self.color_transform(images)
        return self.spatial_transform(images, depths, poses, intrinsics)
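

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module above).
# The shapes and value ranges are assumptions inferred from the code:
# `images` is a clip of (num_frames, 3, H, W) BGR frames with values in
# [0, 255], `depths` is (num_frames, H, W), `poses` is (num_frames, 7),
# and `intrinsics` holds (fx, fy, cx, cy) per frame.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    aug = RGBDAugmentor(crop_size=(240, 320))

    n, ht, wd = 4, 480, 640
    images = 255 * torch.rand(n, 3, ht, wd)      # fake BGR frames in [0, 255]
    depths = 0.1 + torch.rand(n, ht, wd)         # fake positive depths
    poses = torch.rand(n, 7)                     # fake poses (passed through)
    intrinsics = torch.tensor([320.0, 320.0, 320.0, 240.0]).repeat(n, 1)

    images, poses, depths, intrinsics = aug(images, poses, depths, intrinsics)

    # after augmentation the spatial size matches crop_size, and the
    # intrinsics have been rescaled and shifted to match the crop
    print(images.shape, depths.shape, intrinsics[0])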