import math
import os
from pathlib import Path

import detectron2.data.transforms as DT
import einops
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from detectron2.data import detection_utils as d2_utils
from detectron2.structures import Instances, BitMasks
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

from utils.data import read_flow

class FlowEvalDetectron(Dataset):
    def __init__(self, data_dir, resolution, pair_list, val_seq, to_rgb=False, with_rgb=False, size_divisibility=None,
                 small_val=0, flow_clip=1., norm=True, read_big=True, eval_size=True, force1080p=False):
        self.val_seq = val_seq
        self.to_rgb = to_rgb
        self.with_rgb = with_rgb
        self.data_dir = data_dir
        self.pair_list = pair_list
        self.resolution = resolution
        self.eval_size = eval_size

        # Index every .flo frame of every validation sequence, remembering each
        # frame's position within its sequence (frame id).
        self.samples = []
        self.samples_fid = {}
        for v in self.val_seq:
            seq_dir = Path(self.data_dir[0]) / v
            frames_paths = sorted(seq_dir.glob('*.flo'))
            self.samples_fid[str(seq_dir)] = {fp: i for i, fp in enumerate(frames_paths)}
            self.samples.extend(frames_paths)
        self.samples = [os.path.join(x.parent.name, x.name) for x in self.samples]
        if small_val > 0:
            # Keep only a small random subset for quick validation runs.
            _, self.samples = train_test_split(self.samples, test_size=small_val, random_state=42)
        self.gaps = ['gap{}'.format(i) for i in pair_list]
        self.neg_gaps = ['gap{}'.format(-i) for i in pair_list]
        self.size_divisibility = size_divisibility
        self.ignore_label = -1
        self.transforms = DT.AugmentationList([
            DT.Resize(self.resolution, interp=Image.BICUBIC),
        ])
        self.flow_clip = flow_clip
        self.norm_flow = norm
        self.read_big = read_big
        self.force1080p_transforms = None
        if force1080p:
            self.force1080p_transforms = DT.AugmentationList([
                DT.Resize((1088, 1920), interp=Image.BICUBIC),
            ])
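
    # Expected directory layout, as inferred from the path handling above and
    # in __getitem__ (an assumption; it is not documented in this file):
    #   data_dir[0]/<seq>/*.flo  optical flow
    #   data_dir[1]/<seq>/*.jpg  RGB frames (.png for CLEVR)
    #   data_dir[2]/<seq>/*.png  ground-truth segmentation masks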
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        dataset_dicts = []
        dataset_dict = {}

        flow_dir = Path(self.data_dir[0]) / self.samples[idx]
        fid = self.samples_fid[str(flow_dir.parent)][flow_dir]
        flo = einops.rearrange(read_flow(str(flow_dir), self.resolution, self.to_rgb), 'c h w -> h w c')
        dataset_dict["gap"] = 'gap1'

        suffix = '.png' if 'CLEVR' in self.samples[idx] else '.jpg'
        rgb_dir = (self.data_dir[1] / self.samples[idx]).with_suffix(suffix)
        gt_dir = (self.data_dir[2] / self.samples[idx]).with_suffix('.png')

        rgb = d2_utils.read_image(str(rgb_dir)).astype(np.float32)
        original_rgb = torch.as_tensor(np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float()
        if self.read_big:
            rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32)
            rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.)
            if self.force1080p_transforms is not None:
                rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0]

        # Apply the resize augmentation to the RGB frame (a name that does not
        # shadow the `input` builtin):
        aug_input = DT.AugInput(rgb)
        preprocessing_transforms = self.transforms(aug_input)  # type: DT.Transform
        rgb = aug_input.image
        rgb = np.transpose(rgb, (2, 0, 1))
        rgb = rgb.clip(0., 255.)
        d2_utils.check_image_size(dataset_dict, flo)
        if gt_dir.exists():
            sem_seg_gt_ori = d2_utils.read_image(gt_dir)
            sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt_ori)
            if sem_seg_gt.ndim == 3:
                sem_seg_gt = sem_seg_gt[:, :, 0]
                sem_seg_gt_ori = sem_seg_gt_ori[:, :, 0]
            if sem_seg_gt.max() == 255:
                # Binarise 0/255 masks into 0/1 labels.
                sem_seg_gt = (sem_seg_gt > 128).astype(int)
                sem_seg_gt_ori = (sem_seg_gt_ori > 128).astype(int)
        else:
            # No annotation for this frame: fall back to all-background masks.
            sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1]))
            sem_seg_gt_ori = np.zeros((original_rgb.shape[-2], original_rgb.shape[-1]))

        # Optionally load a precomputed 'gwm' segmentation, if one is stored
        # alongside the annotations.
        gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / self.samples[idx]).with_suffix('.png')
        if gwm_dir.exists():
            gwm_seg_gt = preprocessing_transforms.apply_segmentation(d2_utils.read_image(str(gwm_dir)))
            gwm_seg_gt = np.array(gwm_seg_gt)
            if gwm_seg_gt.ndim == 3:
                gwm_seg_gt = gwm_seg_gt[:, :, 0]
            if gwm_seg_gt.max() == 255:
                gwm_seg_gt[gwm_seg_gt == 255] = 1
        else:
            gwm_seg_gt = None

        if sem_seg_gt is None:
            raise ValueError(
                "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
                    dataset_dict["file_name"]
                )
            )
        # Pad image and segmentation label here!
        if self.to_rgb:
            # Flow rendered to RGB is assumed to lie in [-1, 1]; map it to [0, 255].
            flo = torch.as_tensor(np.ascontiguousarray(flo.transpose(2, 0, 1))) / 2 + .5
            flo = flo * 255
        else:
            flo = torch.as_tensor(np.ascontiguousarray(flo.transpose(2, 0, 1)))
            if self.norm_flow:
                # Scale so that the largest flow vector has unit magnitude.
                flo = flo / (flo ** 2).sum(0).max().sqrt()
            flo = flo.clip(-self.flow_clip, self.flow_clip)

        rgb = torch.as_tensor(np.ascontiguousarray(rgb)).float()
        if sem_seg_gt is not None:
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
            sem_seg_gt_ori = torch.as_tensor(sem_seg_gt_ori.astype("long"))
        if gwm_seg_gt is not None:
            gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype("long"))
        if self.size_divisibility is not None and self.size_divisibility > 0:
            image_size = (flo.shape[-2], flo.shape[-1])
            # F.pad order is (left, right, top, bottom): pad on the right and
            # bottom up to the next multiple of size_divisibility. True
            # division is needed inside ceil; floor division would round down
            # and yield negative padding.
            padding_size = [
                0,
                int(self.size_divisibility * math.ceil(image_size[1] / self.size_divisibility)) - image_size[1],
                0,
                int(self.size_divisibility * math.ceil(image_size[0] / self.size_divisibility)) - image_size[0],
            ]
            flo = F.pad(flo, padding_size, value=0).contiguous()
            rgb = F.pad(rgb, padding_size, value=128).contiguous()
            if sem_seg_gt is not None:
                sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
            if gwm_seg_gt is not None:
                gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous()
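            # Worked example (illustrative numbers, not from this file): with
            # size_divisibility=32 and a 480x854 frame, ceil(854 / 32) = 27 so
            # the width pads to 864, while ceil(480 / 32) = 15 leaves the
            # height at 480.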
        image_shape = (flo.shape[-2], flo.shape[-1])  # h, w
        if self.eval_size:
            image_shape = (sem_seg_gt_ori.shape[-2], sem_seg_gt_ori.shape[-1])

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["flow"] = flo
        dataset_dict["rgb"] = rgb
        dataset_dict["original_rgb"] = F.interpolate(
            original_rgb[None], mode='bicubic', size=sem_seg_gt_ori.shape[-2:], align_corners=False
        ).clip(0., 255.)[0]
        if self.read_big:
            dataset_dict["RGB_BIG"] = rgb_big
        dataset_dict["category"] = str(gt_dir).split('/')[-2:]
        dataset_dict['frame_id'] = fid

        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = sem_seg_gt.long()
            dataset_dict["sem_seg_ori"] = sem_seg_gt_ori.long()
        if gwm_seg_gt is not None:
            dataset_dict["gwm_seg"] = gwm_seg_gt.long()

        if "annotations" in dataset_dict:
            raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
        # Prepare per-category binary masks
        if sem_seg_gt is not None:
            sem_seg_gt = sem_seg_gt.numpy()
            instances = Instances(image_shape)
            classes = np.unique(sem_seg_gt)
            # Remove the ignored region.
            classes = classes[classes != self.ignore_label]
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)

            masks = []
            for class_id in classes:
                masks.append(sem_seg_gt == class_id)

            if len(masks) == 0:
                # Some images have no annotation (everything ignored).
                instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
            else:
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
                )
                instances.gt_masks = masks.tensor

            dataset_dict["instances"] = instances

        dataset_dicts.append(dataset_dict)
        return dataset_dicts
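

# A minimal usage sketch (an assumption, not part of the original file): the
# paths, sequence name, and identity collate below are hypothetical and follow
# the DAVIS-style layout this class appears to read from.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = FlowEvalDetectron(
        data_dir=[Path('data/Flows_gap1/480p'),    # hypothetical flow root (*.flo)
                  Path('data/JPEGImages/480p'),    # hypothetical RGB root
                  Path('data/Annotations/480p')],  # hypothetical GT root
        resolution=(128, 224),
        pair_list=[1, 2],
        val_seq=['blackswan'],
        size_divisibility=32,
        read_big=False,
    )
    # __getitem__ returns a one-element list of dicts, so collate with the
    # identity function to hand Detectron2-style models a list of such lists.
    loader = DataLoader(dataset, batch_size=4, collate_fn=lambda batch: batch)
    for batch in loader:
        print(len(batch), batch[0][0]['flow'].shape)
        break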