import math
import random
from pathlib import Path

import detectron2.data.transforms as DT
import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from detectron2.data import detection_utils as d2_utils
from detectron2.structures import BitMasks, Instances
from PIL import Image
from torch.utils.data import Dataset

from utils.data import read_flo, read_flow


def load_flow_tensor(path, resize=None, normalize=True, align_corners=True):
    """Load a .flo flow field and rescale its values to the target resolution.

    If `normalize` is True, return the flow in normalized pixel coordinates
    in the range [-1, 1]; otherwise keep (resized) pixel units.
    NOTE: RAFT uses align_corners=True, so the normalization must account
    for this.

    Returns a (2, h, w) float32 tensor.
    """
    flow = read_flo(path).astype(np.float32)
    H, W, _ = flow.shape
    h, w = (H, W) if resize is None else resize
    u, v = flow[..., 0], flow[..., 1]
    if normalize:
        if align_corners:
            # align_corners=True grids span [-1, 1] over (size - 1) intervals.
            u = 2.0 * u / (W - 1)
            v = 2.0 * v / (H - 1)
        else:
            u = 2.0 * u / W
            v = 2.0 * v / H
    else:
        # Keep pixel units, but scale displacements to the target resolution.
        u = w * u / W
        v = h * v / H
    if h != H or w != W:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        u = Image.fromarray(u).resize((w, h), Image.LANCZOS)
        v = Image.fromarray(v).resize((w, h), Image.LANCZOS)
        u, v = np.array(u), np.array(v)
    return torch.from_numpy(np.stack([u, v], axis=0))


class FlowPairDetectron(Dataset):
    def __init__(self, data_dir, resolution, to_rgb=False, size_divisibility=None,
                 enable_photo_aug=False, flow_clip=1., norm=True, read_big=True,
                 force1080p=False, flow_res=None):
        self.to_rgb = to_rgb
        self.data_dir = data_dir
        # data_dir[0] maps frame gaps to per-video arrays of flow-pair paths;
        # drop empty arrays, then gaps with no videos left.
        self.flow_dir = {k: [e for e in v if e.shape[0] > 0] for k, v in data_dir[0].items()}
        self.flow_dir = {k: v for k, v in self.flow_dir.items() if len(v) > 0}
        self.resolution = resolution
        self.size_divisibility = size_divisibility
        self.ignore_label = -1
        self.transforms = DT.AugmentationList([
            DT.Resize(self.resolution, interp=Image.BICUBIC),
        ])
        self.photometric_aug = T.Compose([
            T.RandomApply(torch.nn.ModuleList([
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
            ]), p=0.8),
            T.RandomGrayscale(p=0.2),
        ]) if enable_photo_aug else None
        self.flow_clip = flow_clip
        self.norm_flow = norm
        self.read_big = read_big
        self.force1080p_transforms = None
        if force1080p:
            self.force1080p_transforms = DT.AugmentationList([
                DT.Resize((1088, 1920), interp=Image.BICUBIC),
            ])
        self.big_flow_resolution = flow_res

    def __len__(self):
        if len(self.flow_dir) == 0:
            return 0
        return sum(cat.shape[0] for cat in next(iter(self.flow_dir.values())))

    def __getitem__(self, idx):
        dataset_dicts = []

        # Sample a random frame gap, then a random video and flow pair within it.
        random_gap = random.choice(list(self.flow_dir.keys()))
        flowgaps = self.flow_dir[random_gap]
        vid = random.choice(flowgaps)
        flos = random.choice(vid)

        dataset_dict = {}
        fname = Path(flos[0]).stem
        dname = Path(flos[0]).parent.name
        suffix = '.png' if 'CLEVR' in fname else '.jpg'
        rgb_dir = (self.data_dir[1] / dname / fname).with_suffix(suffix)
        gt_dir = (self.data_dir[2] / dname / fname).with_suffix('.png')

        flo0 = einops.rearrange(read_flow(str(flos[0]), self.resolution, self.to_rgb), 'c h w -> h w c')
        flo1 = einops.rearrange(read_flow(str(flos[1]), self.resolution, self.to_rgb), 'c h w -> h w c')
        if self.big_flow_resolution is not None:
            flo0_big = einops.rearrange(read_flow(str(flos[0]), self.big_flow_resolution, self.to_rgb),
                                        'c h w -> h w c')
            flo1_big = einops.rearrange(read_flow(str(flos[1]), self.big_flow_resolution, self.to_rgb),
                                        'c h w -> h w c')
        rgb = d2_utils.read_image(rgb_dir).astype(np.float32)
        original_rgb = torch.as_tensor(
            np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float()
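        # The loader keeps up to three RGB variants: `original_rgb` at native
        # resolution, `rgb` resized by the detectron2 augmentations below, and
        # optionally `rgb_big`, a high-resolution copy read from a parallel
        # '1080p' directory (the path substitution assumes a DAVIS-style
        # '480p'/'1080p' layout).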
        if self.read_big:
            rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32)
            rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.)
            if self.force1080p_transforms is not None:
                rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0]

        # Apply the resize augmentation to the RGB frame.
        aug_input = DT.AugInput(rgb)
        preprocessing_transforms = self.transforms(aug_input)  # type: DT.Transform
        rgb = aug_input.image
        if self.photometric_aug:
            rgb_aug = Image.fromarray(rgb.astype(np.uint8))
            rgb_aug = self.photometric_aug(rgb_aug)
            rgb_aug = d2_utils.convert_PIL_to_numpy(rgb_aug, 'RGB')
            rgb_aug = np.transpose(rgb_aug, (2, 0, 1)).astype(np.float32)
        rgb = np.transpose(rgb, (2, 0, 1))
        rgb = rgb.clip(0., 255.)
        d2_utils.check_image_size(dataset_dict, flo0)

        if gt_dir.exists():
            sem_seg_gt = d2_utils.read_image(str(gt_dir))
            sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt)
            if sem_seg_gt.ndim == 3:
                sem_seg_gt = sem_seg_gt[:, :, 0]
            if sem_seg_gt.max() == 255:
                # Binarise 0/255 masks.
                sem_seg_gt = (sem_seg_gt > 128).astype(int)
        else:
            sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1]))

        gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / dname / fname).with_suffix('.png')
        if gwm_dir.exists():
            gwm_seg_gt = d2_utils.read_image(str(gwm_dir))
            gwm_seg_gt = preprocessing_transforms.apply_segmentation(gwm_seg_gt)
            gwm_seg_gt = np.array(gwm_seg_gt)
            if gwm_seg_gt.ndim == 3:
                gwm_seg_gt = gwm_seg_gt[:, :, 0]
            if gwm_seg_gt.max() == 255:
                gwm_seg_gt[gwm_seg_gt == 255] = 1
        else:
            gwm_seg_gt = None

        if sem_seg_gt is None:
            raise ValueError(
                "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(rgb_dir)
            )

        # Pad image and segmentation label here!
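        # What follows converts the flows to tensors. With `to_rgb`, `read_flow`
        # has already rendered the flow as an RGB visualisation in [-1, 1], so
        # it is remapped to [0, 255]; otherwise the raw (u, v) field is
        # optionally normalised by its maximum magnitude and clipped to
        # [-flow_clip, flow_clip].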
        if self.to_rgb:
            flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1))) / 2 + .5
            flo0 = flo0 * 255
            flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1))) / 2 + .5
            flo1 = flo1 * 255
            if self.big_flow_resolution is not None:
                flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1))) / 2 + .5
                flo0_big = flo0_big * 255
                flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1))) / 2 + .5
                flo1_big = flo1_big * 255
        else:
            flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1)))
            flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1)))
            if self.norm_flow:
                # Normalise by the largest flow magnitude in the frame.
                flo0 = flo0 / (flo0 ** 2).sum(0).max().sqrt()
                flo1 = flo1 / (flo1 ** 2).sum(0).max().sqrt()
            flo0 = flo0.clip(-self.flow_clip, self.flow_clip)
            flo1 = flo1.clip(-self.flow_clip, self.flow_clip)
            if self.big_flow_resolution is not None:
                flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1)))
                flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1)))
                if self.norm_flow:
                    flo0_big = flo0_big / (flo0_big ** 2).sum(0).max().sqrt()
                    flo1_big = flo1_big / (flo1_big ** 2).sum(0).max().sqrt()
                flo0_big = flo0_big.clip(-self.flow_clip, self.flow_clip)
                flo1_big = flo1_big.clip(-self.flow_clip, self.flow_clip)

        rgb = torch.as_tensor(np.ascontiguousarray(rgb))
        if self.photometric_aug:
            rgb_aug = torch.as_tensor(np.ascontiguousarray(rgb_aug))

        if sem_seg_gt is not None:
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype(np.int64))
        if gwm_seg_gt is not None:
            gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype(np.int64))

        if self.size_divisibility is not None and self.size_divisibility > 0:
            image_size = (flo0.shape[-2], flo0.shape[-1])
            # Pad on the right and bottom so H and W become multiples of
            # size_divisibility.
            padding_size = [
                0,
                self.size_divisibility * math.ceil(image_size[1] / self.size_divisibility) - image_size[1],
                0,
                self.size_divisibility * math.ceil(image_size[0] / self.size_divisibility) - image_size[0],
            ]
            flo0 = F.pad(flo0, padding_size, value=0).contiguous()
            flo1 = F.pad(flo1, padding_size, value=0).contiguous()
            rgb = F.pad(rgb, padding_size, value=128).contiguous()
            if self.photometric_aug:
                rgb_aug = F.pad(rgb_aug, padding_size, value=128).contiguous()
            if sem_seg_gt is not None:
                sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
            if gwm_seg_gt is not None:
                gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous()

        image_shape = (rgb.shape[-2], rgb.shape[-1])  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared memory,
        # but not efficient on large generic data structures due to the use of
        # pickle & mp.Queue, so it is important to return torch.Tensors.
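        # Summary of the keys assembled below (as produced by this mapper, not
        # a general detectron2 contract): "flow"/"flow_2" are the two flow
        # tensors, "rgb" the resized frame, "original_rgb" the unresized frame,
        # and "sem_seg"/"gwm_seg"/"instances" the available ground truth.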
dataset_dict["flow"] = flo0 dataset_dict["flow_2"] = flo1 # dataset_dict["flow_fwd"] = flo_norm_fwd # dataset_dict["flow_bwd"] = flo_norm_bwd # dataset_dict["flow_rgb"] = rgb_flo0 # dataset_dict["flow_gap"] = gap dataset_dict["rgb"] = rgb dataset_dict["original_rgb"] = original_rgb if self.read_big: dataset_dict["RGB_BIG"] = rgb_big if self.photometric_aug: dataset_dict["rgb_aug"] = rgb_aug if self.big_flow_resolution is not None: dataset_dict["flow_big"] = flo0_big dataset_dict["flow_big_2"] = flo1_big if sem_seg_gt is not None: dataset_dict["sem_seg"] = sem_seg_gt.long() if gwm_seg_gt is not None: dataset_dict["gwm_seg"] = gwm_seg_gt.long() if "annotations" in dataset_dict: raise ValueError("Semantic segmentation dataset should not have 'annotations'.") # Prepare per-category binary masks if sem_seg_gt is not None: sem_seg_gt = sem_seg_gt.numpy() instances = Instances(image_shape) classes = np.unique(sem_seg_gt) # remove ignored region classes = classes[classes != self.ignore_label] instances.gt_classes = torch.tensor(classes, dtype=torch.int64) masks = [] for class_id in classes: masks.append(sem_seg_gt == class_id) if len(masks) == 0: # Some image does not have annotation (all ignored) instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])) else: masks = BitMasks( torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) ) instances.gt_masks = masks.tensor dataset_dict["instances"] = instances dataset_dicts.append(dataset_dict) return dataset_dicts