| | from copy import deepcopy
|
| | from pathlib import Path
|
| | from typing import Any, Dict, List
|
| |
|
| | import numpy as np
|
| | import torch
|
| | import torch.utils.data as torchdata
|
| | import torchvision.transforms as tvf
|
| | from PIL import Image
|
| | from pathlib import Path
|
| |
|
| | from ...models.utils import deg2rad, rotmat2d
|
| | from ...utils.io import read_image
|
| | from ...utils.wrappers import Camera
|
| | from ..image import pad_image, rectify_image, resize_image
|
| | from ..utils import decompose_rotmat
|
| | from ..schema import MIADataConfiguration
|
| |
|
| |
|
class MapLocDataset(torchdata.Dataset):
    """Single-view map-localization dataset.

    Each item bundles an (optionally augmented and rectified) camera image,
    its camera model, segmentation / flood masks cropped around the stored
    map center, a min-max-normalized confidence map, and the roll/pitch/yaw
    pose angles.
    """

    def __init__(
        self,
        stage: str,
        cfg: MIADataConfiguration,
        names: List[str],
        data: Dict[str, Any],
        image_dirs: Dict[str, Path],
        seg_mask_dirs: Dict[str, Path],
        flood_masks_dirs: Dict[str, Path],
        image_ext: str = "",
    ):
        """Store the split metadata and build the augmentation pipeline.

        Args:
            stage: "train" enables augmentations and (optionally) random seeds.
            cfg: dataset configuration; deep-copied so in-place mutations
                (e.g. zeroing the flip probability for eval) don't leak back
                to the caller.
            names: per-sample (scene, sequence, name) triplets.
            data: tensors/dicts keyed by sample index (cameras, poses, ...).
            image_dirs: per-scene root for the input images.
            seg_mask_dirs: per-scene root for segmentation-mask .npy files.
            flood_masks_dirs: per-scene root for flood-mask .npy files.
            image_ext: suffix appended to the sample name to form the filename.
        """
        self.stage = stage
        self.cfg = deepcopy(cfg)
        self.data = data
        self.image_dirs = image_dirs
        self.seg_mask_dirs = seg_mask_dirs
        self.flood_masks_dirs = flood_masks_dirs
        self.names = names
        self.image_ext = image_ext

        # Tensor-space transforms applied after geometric processing;
        # currently an empty (no-op) pipeline.
        tfs = []
        self.tfs = tvf.Compose(tfs)
        self.augmentations = self.get_augmentations()

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        if self.stage == "train" and self.cfg.random:
            seed = None  # fresh entropy every epoch during training
        else:
            seed = [self.cfg.seed, idx]  # deterministic per-sample for eval
        (seed,) = np.random.SeedSequence(seed).generate_state(1)

        scene, seq, name = self.names[idx]

        view = self.get_view(idx, scene, seq, name, seed)

        return view

    def get_augmentations(self):
        """Build the torchvision pipeline applied to the raw PIL image.

        Outside training (or when augmentations are disabled) this also
        zeroes the random-flip probability and returns a no-op pipeline.
        """
        if self.stage != "train" or not self.cfg.augmentations.enabled:
            # NOTE(review): debug leftover — prints ten blank lines.
            print("No Augmentation!", "\n" * 10)
            self.cfg.augmentations.random_flip = 0.0
            return tvf.Compose([])

        # NOTE(review): debug leftover — prints ten blank lines.
        print("Augmentation!", "\n" * 10)
        augmentations = [
            tvf.ColorJitter(
                brightness=self.cfg.augmentations.brightness,
                contrast=self.cfg.augmentations.contrast,
                saturation=self.cfg.augmentations.saturation,
                hue=self.cfg.augmentations.hue,
            )
        ]

        if self.cfg.augmentations.random_resized_crop:
            # NOTE(review): tvf.RandomResizedCrop requires a `size`
            # argument; as written this raises TypeError when the flag is
            # enabled — confirm the intended output size.
            augmentations.append(
                tvf.RandomResizedCrop(scale=(0.8, 1.0))
            )

        if self.cfg.augmentations.gaussian_noise.enabled:
            # NOTE(review): GaussianNoise is only available in
            # torchvision.transforms.v2 (>= 0.20), not the classic
            # transforms namespace — verify the torchvision version/import.
            augmentations.append(
                tvf.GaussianNoise(
                    mean=self.cfg.augmentations.gaussian_noise.mean,
                    std=self.cfg.augmentations.gaussian_noise.std,
                )
            )

        if self.cfg.augmentations.brightness_contrast.enabled:
            augmentations.append(
                tvf.ColorJitter(
                    brightness=self.cfg.augmentations.brightness_contrast.brightness_factor,
                    contrast=self.cfg.augmentations.brightness_contrast.contrast_factor,
                    saturation=0,
                    hue=0,
                )
            )

        return tvf.Compose(augmentations)

    def random_flip(self, image, cam, valid, seg_mask, flood_mask, conf_mask):
        """Horizontally flip all sample tensors with probability
        ``cfg.augmentations.random_flip``; otherwise return them unchanged.

        Image/valid/flood/confidence are flipped along their last (width)
        dim; seg_mask is flipped along dim 1 — presumably it is laid out
        HWC so dim 1 is also width (TODO confirm).
        """
        if torch.rand(1) < self.cfg.augmentations.random_flip:
            image = torch.flip(image, [-1])
            cam = cam.flip()
            valid = torch.flip(valid, [-1])
            seg_mask = torch.flip(seg_mask, [1])
            flood_mask = torch.flip(flood_mask, [-1])
            conf_mask = torch.flip(conf_mask, [-1])

        return image, cam, valid, seg_mask, flood_mask, conf_mask

    def get_view(self, idx, scene, seq, name, seed):
        """Load and assemble one sample dict (image, camera, masks, pose)."""
        data = {
            "index": idx,
            "name": name,
            "scene": scene,
            "sequence": seq,
        }
        cam_dict = self.data["cameras"][scene][seq][self.data["camera_id"][idx]]
        cam = Camera.from_dict(cam_dict).float()

        # Prefer precomputed angles; otherwise decompose the rotation matrix.
        if "roll_pitch_yaw" in self.data:
            roll, pitch, yaw = self.data["roll_pitch_yaw"][idx].numpy()
        else:
            roll, pitch, yaw = decompose_rotmat(
                self.data["R_c2w"][idx].numpy())

        # Photometric augmentations operate on a PIL image.
        image = read_image(self.image_dirs[scene] / (name + self.image_ext))
        image = Image.fromarray(image)
        image = self.augmentations(image)
        image = np.array(image)

        if "plane_params" in self.data:
            # Rotate the plane parameters' xy component by yaw; keep the rest.
            plane_w = self.data["plane_params"][idx]
            data["ground_plane"] = torch.cat(
                [rotmat2d(deg2rad(torch.tensor(yaw)))
                 @ plane_w[:2], plane_w[2:]]
            )

        image, valid, cam, roll, pitch = self.process_image(
            image, cam, roll, pitch, seed
        )

        if "chunk_index" in self.data:
            data["chunk_id"] = (scene, seq, self.data["chunk_index"][idx])

        # Crop a 100x100 window from the stored mask: the 100 rows above
        # the mask center, and 50 columns to either side of it.
        seg_mask_path = self.seg_mask_dirs[scene] / \
            (name.split("_")[0] + ".npy")
        seg_masks_ours = np.load(seg_mask_path)
        mask_center = (
            seg_masks_ours.shape[0] // 2, seg_masks_ours.shape[1] // 2)

        seg_masks_ours = seg_masks_ours[mask_center[0] -
                                        100:mask_center[0], mask_center[1] - 50: mask_center[1] + 50]

        if self.cfg.num_classes == 6:
            # Keep a fixed 6-channel subset of the stored class channels.
            seg_masks_ours = seg_masks_ours[..., [0, 1, 2, 4, 6, 7]]

        flood_mask_path = self.flood_masks_dirs[scene] / \
            (name.split("_")[0] + ".npy")
        flood_mask = np.load(flood_mask_path)

        # Same crop window as the segmentation mask (shapes assumed equal —
        # TODO confirm both masks share the same spatial size).
        flood_mask = flood_mask[mask_center[0]-100:mask_center[0],
                                mask_center[1] - 50: mask_center[1] + 50]

        # Min-max normalize the flood mask into [0, 1]; epsilon guards a
        # constant (zero-range) mask against division by zero.
        confidence_map = flood_mask.copy()
        confidence_map = (confidence_map - confidence_map.min()) / \
            (confidence_map.max() - confidence_map.min() + 1e-6)

        seg_masks_ours = torch.from_numpy(seg_masks_ours).float()
        flood_mask = torch.from_numpy(flood_mask).float()
        confidence_map = torch.from_numpy(confidence_map).float()

        # Seeded flip inside a forked RNG state so global torch randomness
        # is untouched and the flip is reproducible per sample.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            image, cam, valid, seg_masks_ours, flood_mask, confidence_map = self.random_flip(
                image, cam, valid, seg_masks_ours, flood_mask, confidence_map)

        return {
            **data,
            "image": image,
            "valid": valid,
            "camera": cam,
            "seg_masks": seg_masks_ours,
            "flood_masks": flood_mask,
            "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
            "confidence_map": confidence_map
        }

    def process_image(self, image, cam, roll, pitch, seed):
        """Rectify, rescale and pad the image; update the camera to match.

        Returns:
            (image, valid, cam, roll, pitch) where `valid` masks pixels that
            survived rectification, and roll/pitch are zeroed for the
            components that were rectified away.
        """
        # HWC uint8 -> CHW float in [0, 1].
        image = (
            torch.from_numpy(np.ascontiguousarray(image))
            .permute(2, 0, 1)
            .float()
            .div_(255)
        )

        if not self.cfg.gravity_align:
            # No gravity alignment: rectify with zero angles purely to
            # obtain the valid-pixel mask.
            roll = 0.0
            pitch = 0.0
            image, valid = rectify_image(image, cam, roll, pitch)
        else:
            image, valid = rectify_image(
                image, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.target_focal_length is not None:
            # Rescale so the focal length matches the target value.
            factor = self.cfg.target_focal_length / cam.f.numpy()
            size = (np.array(image.shape[-2:][::-1]) * factor).astype(int)
            image, _, cam, valid = resize_image(
                image, size, camera=cam, valid=valid)
            size_out = self.cfg.resize_image
            if size_out is None:
                # Round up to the next multiple of the network stride.
                stride = self.cfg.pad_to_multiple
                size_out = (np.ceil((size / stride)) * stride).astype(int)

            image, valid, cam = pad_image(
                image, size_out, cam, valid, crop_and_center=True
            )
        elif self.cfg.resize_image is not None:
            image, _, cam, valid = resize_image(
                image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                image, valid, cam = pad_image(
                    image, self.cfg.resize_image, cam, valid)

        if self.cfg.reduce_fov is not None:
            # Crop width to a fraction of the current horizontal FOV.
            h, w = image.shape[-2:]
            f = float(cam.f[0])
            fov = np.arctan(w / f / 2)
            w_new = round(2 * f * np.tan(self.cfg.reduce_fov * fov))
            image, valid, cam = pad_image(
                image, (w_new, h), cam, valid, crop_and_center=True
            )

        # Seeded tensor-space transforms (currently a no-op Compose).
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            image = self.tfs(image)

        return image, valid, cam, roll, pitch