# Guess-What-Moves/datasets/flow_pair_detectron.py
import math
from pathlib import Path
import random
import detectron2.data.transforms as DT
import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from detectron2.data import detection_utils as d2_utils
from detectron2.structures import Instances, BitMasks
from torch.utils.data import Dataset
from utils.data import read_flow, read_flo
def load_flow_tensor(path, resize=None, normalize=True, align_corners=True):
"""
Load flow, scale the pixel values according to the resized scale.
If normalize is true, return rescaled in normalized pixel coordinates
where pixel coordinates are in range [-1, 1].
NOTE: RAFT USES ALIGN_CORNERS=TRUE SO WE NEED TO ACCOUNT FOR THIS
Returns (2, H, W) float32
"""
flow = read_flo(path).astype(np.float32)
H, W, _ = flow.shape
h, w = (H, W) if resize is None else resize
u, v = flow[..., 0], flow[..., 1]
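    # u: horizontal displacement in pixels; v: vertical displacement.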
    if normalize:
        # Normalized pixel coordinates, matching grid_sample conventions.
        if align_corners:
            u = 2.0 * u / (W - 1)
            v = 2.0 * v / (H - 1)
        else:
            u = 2.0 * u / W
            v = 2.0 * v / H
    else:
        # Keep the flow in pixel units, rescaled to the target resolution.
        u = w * u / W
        v = h * v / H
    if h != H or w != W:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is its replacement.
        u = Image.fromarray(u).resize((w, h), Image.LANCZOS)
        v = Image.fromarray(v).resize((w, h), Image.LANCZOS)
        u, v = np.array(u), np.array(v)
return torch.from_numpy(np.stack([u, v], axis=0))
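
# Example usage (hypothetical file path): load a .flo field at 240x427 in
# normalized [-1, 1] coordinates, matching RAFT's align_corners=True sampling:
#   flow = load_flow_tensor("00000.flo", resize=(240, 427))
#   flow.shape  # torch.Size([2, 240, 427])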
class FlowPairDetectron(Dataset):
    def __init__(self, data_dir, resolution, to_rgb=False, size_divisibility=None,
                 enable_photo_aug=False, flow_clip=1., norm=True, read_big=True,
                 force1080p=False, flow_res=None):
self.to_rgb = to_rgb
self.data_dir = data_dir
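        # data_dir[0]: dict mapping frame gap -> list of per-video arrays of
        # flow-file pairs; data_dir[1]/[2]: RGB and annotation roots.
        # Drop videos without flow pairs, then gaps with no videos left.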
self.flow_dir = {k: [e for e in v if e.shape[0] > 0] for k, v in data_dir[0].items()}
self.flow_dir = {k: v for k, v in self.flow_dir.items() if len(v) > 0}
self.resolution = resolution
self.size_divisibility = size_divisibility
self.ignore_label = -1
self.transforms = DT.AugmentationList([
DT.Resize(self.resolution, interp=Image.BICUBIC),
])
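        # Optional photometric augmentation: strong color jitter (p=0.8) plus
        # random grayscale (p=0.2), applied to the RGB frame only.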
self.photometric_aug = T.Compose([
T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)]),
p=0.8),
T.RandomGrayscale(p=0.2),
]) if enable_photo_aug else None
        self.flow_clip = flow_clip
        self.norm_flow = norm
self.read_big = read_big
self.force1080p_transforms = None
if force1080p:
self.force1080p_transforms = DT.AugmentationList([
DT.Resize((1088, 1920), interp=Image.BICUBIC),
])
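        # Optionally read the flow a second time at a larger resolution (flow_res).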
self.big_flow_resolution = flow_res
    def __len__(self):
        if not self.flow_dir:
            return 0
        # Total number of flow pairs available under the first gap.
        return sum(cat.shape[0] for cat in next(iter(self.flow_dir.values())))
def __getitem__(self, idx):
dataset_dicts = []
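        # Sample a random frame gap, then a random video, then a random flow pair.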
random_gap = random.choice(list(self.flow_dir.keys()))
flowgaps = self.flow_dir[random_gap]
vid = random.choice(flowgaps)
flos = random.choice(vid)
dataset_dict = {}
fname = Path(flos[0]).stem
dname = Path(flos[0]).parent.name
suffix = '.png' if 'CLEVR' in fname else '.jpg'
rgb_dir = (self.data_dir[1] / dname / fname).with_suffix(suffix)
gt_dir = (self.data_dir[2] / dname / fname).with_suffix('.png')
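        # Assumes a DAVIS-style layout: RGB under data_dir[1]/<video>/<frame>.jpg
        # and GT masks under data_dir[2]/<video>/<frame>.png.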
flo0 = einops.rearrange(read_flow(str(flos[0]), self.resolution, self.to_rgb), 'c h w -> h w c')
flo1 = einops.rearrange(read_flow(str(flos[1]), self.resolution, self.to_rgb), 'c h w -> h w c')
if self.big_flow_resolution is not None:
flo0_big = einops.rearrange(read_flow(str(flos[0]), self.big_flow_resolution, self.to_rgb), 'c h w -> h w c')
flo1_big = einops.rearrange(read_flow(str(flos[1]), self.big_flow_resolution, self.to_rgb), 'c h w -> h w c')
rgb = d2_utils.read_image(rgb_dir).astype(np.float32)
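        # Keep an unresized copy of the frame (CHW, float in [0, 255]) for eval/visualisation.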
original_rgb = torch.as_tensor(np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float()
if self.read_big:
rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32)
rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.)
if self.force1080p_transforms is not None:
rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0]
        aug_input = DT.AugInput(rgb)
        # Apply the geometric augmentations (bicubic resize to self.resolution):
        preprocessing_transforms = self.transforms(aug_input)  # type: DT.Transform
        rgb = aug_input.image
if self.photometric_aug:
rgb_aug = Image.fromarray(rgb.astype(np.uint8))
rgb_aug = self.photometric_aug(rgb_aug)
rgb_aug = d2_utils.convert_PIL_to_numpy(rgb_aug, 'RGB')
rgb_aug = np.transpose(rgb_aug, (2, 0, 1)).astype(np.float32)
rgb = np.transpose(rgb, (2, 0, 1))
rgb = rgb.clip(0., 255.)
d2_utils.check_image_size(dataset_dict, flo0)
if gt_dir.exists():
sem_seg_gt = d2_utils.read_image(str(gt_dir))
sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt)
if sem_seg_gt.ndim == 3:
sem_seg_gt = sem_seg_gt[:, :, 0]
if sem_seg_gt.max() == 255:
sem_seg_gt = (sem_seg_gt > 128).astype(int)
else:
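            # No GT for this frame: fall back to an all-background dummy mask.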
sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1]))
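        # Optional extra masks under a parallel 'gwm' directory (same layout as Annotations).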
gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / dname / fname).with_suffix('.png')
if gwm_dir.exists():
gwm_seg_gt = d2_utils.read_image(str(gwm_dir))
gwm_seg_gt = preprocessing_transforms.apply_segmentation(gwm_seg_gt)
gwm_seg_gt = np.array(gwm_seg_gt)
if gwm_seg_gt.ndim == 3:
gwm_seg_gt = gwm_seg_gt[:, :, 0]
if gwm_seg_gt.max() == 255:
gwm_seg_gt[gwm_seg_gt == 255] = 1
else:
gwm_seg_gt = None
        if sem_seg_gt is None:
            raise ValueError(
                f"Cannot find semantic segmentation annotation for {gt_dir}."
            )
        # Convert flow to CHW tensors; RGB-encoded flow maps from [-1, 1] to [0, 255].
if self.to_rgb:
flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1))) / 2 + .5
flo0 = flo0 * 255
flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1))) / 2 + .5
flo1 = flo1 * 255
if self.big_flow_resolution is not None:
flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1))) / 2 + .5
flo0_big = flo0_big * 255
flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1))) / 2 + .5
flo1_big = flo1_big * 255
else:
flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1)))
flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1)))
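            # Scale by the largest per-pixel flow magnitude so values fall in
            # [-1, 1], then clip to +/- flow_clip.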
if self.norm_flow:
flo0 = flo0 / (flo0 ** 2).sum(0).max().sqrt()
flo1 = flo1 / (flo1 ** 2).sum(0).max().sqrt()
flo0 = flo0.clip(-self.flow_clip, self.flow_clip)
flo1 = flo1.clip(-self.flow_clip, self.flow_clip)
if self.big_flow_resolution is not None:
flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1)))
flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1)))
if self.norm_flow:
flo0_big = flo0_big / (flo0_big ** 2).sum(0).max().sqrt()
flo1_big = flo1_big / (flo1_big ** 2).sum(0).max().sqrt()
flo0_big = flo0_big.clip(-self.flow_clip, self.flow_clip)
flo1_big = flo1_big.clip(-self.flow_clip, self.flow_clip)
rgb = torch.as_tensor(np.ascontiguousarray(rgb))
if self.photometric_aug:
rgb_aug = torch.as_tensor(np.ascontiguousarray(rgb_aug))
if sem_seg_gt is not None:
sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
if gwm_seg_gt is not None:
gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype("long"))
        if self.size_divisibility is not None and self.size_divisibility > 0:
            image_size = (flo0.shape[-2], flo0.shape[-1])
            # F.pad ordering is (left, right, top, bottom): pad right/bottom up
            # to the next multiple of size_divisibility.
            padding_size = [
                0,
                self.size_divisibility * math.ceil(image_size[1] / self.size_divisibility) - image_size[1],
                0,
                self.size_divisibility * math.ceil(image_size[0] / self.size_divisibility) - image_size[0],
            ]
flo0 = F.pad(flo0, padding_size, value=0).contiguous()
flo1 = F.pad(flo1, padding_size, value=0).contiguous()
rgb = F.pad(rgb, padding_size, value=128).contiguous()
if self.photometric_aug:
rgb_aug = F.pad(rgb_aug, padding_size, value=128).contiguous()
if sem_seg_gt is not None:
sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
if gwm_seg_gt is not None:
gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous()
image_shape = (rgb.shape[-2], rgb.shape[-1]) # h, w
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
# Therefore it's important to use torch.Tensor.
dataset_dict["flow"] = flo0
dataset_dict["flow_2"] = flo1
# dataset_dict["flow_fwd"] = flo_norm_fwd
# dataset_dict["flow_bwd"] = flo_norm_bwd
# dataset_dict["flow_rgb"] = rgb_flo0
# dataset_dict["flow_gap"] = gap
dataset_dict["rgb"] = rgb
dataset_dict["original_rgb"] = original_rgb
if self.read_big:
dataset_dict["RGB_BIG"] = rgb_big
if self.photometric_aug:
dataset_dict["rgb_aug"] = rgb_aug
if self.big_flow_resolution is not None:
dataset_dict["flow_big"] = flo0_big
dataset_dict["flow_big_2"] = flo1_big
if sem_seg_gt is not None:
dataset_dict["sem_seg"] = sem_seg_gt.long()
if gwm_seg_gt is not None:
dataset_dict["gwm_seg"] = gwm_seg_gt.long()
if "annotations" in dataset_dict:
raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
# Prepare per-category binary masks
if sem_seg_gt is not None:
sem_seg_gt = sem_seg_gt.numpy()
instances = Instances(image_shape)
classes = np.unique(sem_seg_gt)
# remove ignored region
classes = classes[classes != self.ignore_label]
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
masks = []
for class_id in classes:
masks.append(sem_seg_gt == class_id)
if len(masks) == 0:
                # This frame has no annotated pixels (everything is ignored).
instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
else:
masks = BitMasks(
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
)
instances.gt_masks = masks.tensor
dataset_dict["instances"] = instances
dataset_dicts.append(dataset_dict)
return dataset_dicts
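

if __name__ == "__main__":
    # Minimal construction sketch (hypothetical paths and layout; the real
    # training pipeline assembles `data_dir` from its dataset configs).
    # data_dir is (dict: gap -> list of arrays of flow-pair paths, rgb root, gt root).
    pairs = np.array([["flow/bear/00000.flo", "flow/bear/00001.flo"]])
    data_dir = ({"gap1": [pairs]},
                Path("JPEGImages/480p"),
                Path("Annotations/480p"))
    dataset = FlowPairDetectron(data_dir, resolution=(128, 224),
                                size_divisibility=32, read_big=False)
    print(len(dataset))  # 1 flow pair
    # Indexing the dataset reads the .flo/.jpg files from disk, so it requires
    # the files referenced above to actually exist.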