import os

import numpy as np
import torch
# transforms.v2 is required for GaussianNoise (torchvision >= 0.20); it also
# provides the ColorJitter, RandomResizedCrop, and Compose used below.
import torchvision.transforms.v2 as tvf
from nuscenes.nuscenes import NuScenes
from PIL import Image
from pyquaternion import Quaternion
from torchvision.transforms.functional import to_tensor

from .splits_roddick import create_splits_scenes_roddick
from .utils import decode_binary_labels
from ..image import pad_image, rectify_image, resize_image
from ..schema import NuScenesDataConfiguration
from ..utils import decompose_rotmat
from ...utils.io import read_image
from ...utils.wrappers import Camera


class NuScenesDataset(torch.utils.data.Dataset):
    def __init__(self, cfg: NuScenesDataConfiguration, split="train"):
        self.cfg = cfg
        self.nusc = NuScenes(version=cfg.version, dataroot=str(cfg.data_dir))
        self.map_data_root = cfg.map_dir
        self.split = split

        # Custom scene splits based on Roddick et al.
        self.scenes = create_splits_scenes_roddick()
        scene_split = {
            'v1.0-trainval': {'train': 'train', 'val': 'val', 'test': 'val'},
            'v1.0-mini': {'train': 'mini_train', 'val': 'mini_val'},
        }[cfg.version][split]
        self.scenes = self.scenes[scene_split]

        # Keep only the samples whose scene belongs to the requested split.
        self.sample = [
            sample for sample in self.nusc.sample
            if self.nusc.get('scene', sample['scene_token'])['name'] in self.scenes
        ]

        self.tfs = self.get_augmentations() if split == "train" else tvf.Compose([])

        # Only the front camera is used.
        data_tokens = [sample['data']['CAM_FRONT'] for sample in self.sample]
        data = [self.nusc.get('sample_data', token) for token in data_tokens]

        self.data = []
        for d in data:
            sample = self.nusc.get('sample', d['sample_token'])
            scene = self.nusc.get('scene', sample['scene_token'])
            location = self.nusc.get('log', scene['log_token'])['location']
            file_name = d['filename']
            ego_pose = self.nusc.get('ego_pose', d['ego_pose_token'])
            calibrated_sensor = self.nusc.get(
                'calibrated_sensor', d['calibrated_sensor_token'])

            # Chain the ego->global and sensor->ego poses into sensor->global.
            ego2global = np.eye(4, dtype=np.float32)
            ego2global[:3, :3] = Quaternion(ego_pose['rotation']).rotation_matrix
            ego2global[:3, 3] = ego_pose['translation']
            sensor2ego = np.eye(4, dtype=np.float32)
            sensor2ego[:3, :3] = Quaternion(
                calibrated_sensor['rotation']).rotation_matrix
            sensor2ego[:3, 3] = calibrated_sensor['translation']
            sensor2global = ego2global @ sensor2ego

            rotation = sensor2global[:3, :3]
            roll, pitch, yaw = decompose_rotmat(rotation)

            fx = calibrated_sensor['camera_intrinsic'][0][0]
            fy = calibrated_sensor['camera_intrinsic'][1][1]
            cx = calibrated_sensor['camera_intrinsic'][0][2]
            cy = calibrated_sensor['camera_intrinsic'][1][2]
            width = d['width']
            height = d['height']
            # The -0.5 shifts the principal point from a corner-origin to a
            # center-origin pixel convention.
            cam = Camera(torch.tensor(
                [width, height, fx, fy, cx - 0.5, cy - 0.5])).float()

            self.data.append({
                'filename': file_name,
                'yaw': yaw,
                'pitch': pitch,
                'roll': roll,
                'cam': cam,
                'sensor2global': sensor2global,
                'token': d['token'],
                'sample_token': d['sample_token'],
                'location': location,
            })

        # Optionally train on only a fraction of the dataset.
        if self.cfg.percentage < 1.0 and split == "train":
            self.data = self.data[:int(len(self.data) * self.cfg.percentage)]

    def get_augmentations(self):
        augmentations = [
            tvf.ColorJitter(
                brightness=self.cfg.augmentations.brightness,
                contrast=self.cfg.augmentations.contrast,
                saturation=self.cfg.augmentations.saturation,
                hue=self.cfg.augmentations.hue,
            )
        ]
        if self.cfg.augmentations.random_resized_crop:
            # RandomResizedCrop requires an output size; reusing the configured
            # image size here is an assumption.
            augmentations.append(
                tvf.RandomResizedCrop(size=self.cfg.resize_image, scale=(0.8, 1.0))
            )
        if self.cfg.augmentations.gaussian_noise.enabled:
            # GaussianNoise expects float tensors and takes `sigma`, not `std`
            # (torchvision.transforms.v2, torchvision >= 0.20).
            augmentations.append(
                tvf.GaussianNoise(
                    mean=self.cfg.augmentations.gaussian_noise.mean,
                    sigma=self.cfg.augmentations.gaussian_noise.std,
                )
            )
        if self.cfg.augmentations.brightness_contrast.enabled:
            # Extra brightness/contrast jitter; saturation and hue stay fixed.
            augmentations.append(
                tvf.ColorJitter(
                    brightness=self.cfg.augmentations.brightness_contrast.brightness_factor,
                    contrast=self.cfg.augmentations.brightness_contrast.contrast_factor,
                    saturation=0,
                    hue=0,
                )
            )
        return tvf.Compose(augmentations)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        d = self.data[idx]
        image = read_image(os.path.join(self.nusc.dataroot, d['filename']))
        image = np.array(image)
        cam = d['cam']
        roll = d['roll']
        pitch = d['pitch']
        yaw = d['yaw']

        # Load the bit-packed ground-truth labels rendered for this sample.
        with Image.open(self.map_data_root / f"{d['token']}.png") as semantic_image:
            semantic_mask = to_tensor(semantic_image)
        semantic_mask = decode_binary_labels(semantic_mask, self.cfg.num_classes + 1)
        # Downsample the labels by a factor of 2.
        semantic_mask = torch.nn.functional.max_pool2d(
            semantic_mask.float(), (2, 2), stride=2)
        semantic_mask = semantic_mask.permute(1, 2, 0)
        semantic_mask = torch.flip(semantic_mask, [0])
        # The last channel encodes visibility; the rest are semantic classes.
        visibility_mask = semantic_mask[..., -1]
        semantic_mask = semantic_mask[..., :-1]
        if self.cfg.class_mapping is not None:
            semantic_mask = semantic_mask[..., self.cfg.class_mapping]

        image = (
            torch.from_numpy(np.ascontiguousarray(image))
            .permute(2, 0, 1)
            .float()
            .div_(255)
        )

        if not self.cfg.gravity_align:
            # Gravity alignment is turned off: rectify with zero roll and pitch.
            roll = 0.0
            pitch = 0.0
            image, valid = rectify_image(image, cam, roll, pitch)
        else:
            image, valid = rectify_image(
                image, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.resize_image is not None:
            image, _, cam, valid = resize_image(
                image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                image, valid, cam = pad_image(
                    image, self.cfg.resize_image, cam, valid)

        image = self.tfs(image)

        # Min-max normalize the visibility mask into a confidence map; the
        # epsilon guards against division by zero for a constant mask.
        confidence_map = visibility_mask.clone().float()
        confidence_map = (confidence_map - confidence_map.min()) / (
            confidence_map.max() - confidence_map.min() + 1e-8)

        return {
            "image": image,
            "roll_pitch_yaw": torch.tensor([roll, pitch, yaw]).float(),
            "camera": cam,
            "valid": valid,
            "seg_masks": semantic_mask.float(),
            "token": d['token'],
            "sample_token": d['sample_token'],
            "location": d['location'],
            "flood_masks": visibility_mask.float(),
            "confidence_map": confidence_map,
            "name": d['sample_token'],
        }
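

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes that
# NuScenesDataConfiguration can be constructed with the fields accessed above
# (version, data_dir, map_dir, ...); the exact constructor signature lives in
# ..schema, and the paths below are hypothetical placeholders.
#
#   from torch.utils.data import DataLoader
#
#   cfg = NuScenesDataConfiguration(
#       version="v1.0-mini",
#       data_dir="/data/sets/nuscenes",      # hypothetical path
#       map_dir="/data/sets/nuscenes_maps",  # hypothetical path
#   )
#   dataset = NuScenesDataset(cfg, split="train")
#   loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
#   batch = next(iter(loader))
#   print(batch["image"].shape, batch["seg_masks"].shape)
# ---------------------------------------------------------------------------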