import logging
import os
import subprocess
from functools import lru_cache
from pathlib import Path

import cv2
import einops
import numpy as np
import torch
from cvbase.optflow.visualize import flow2rgb
from detectron2.data import detection_utils as d2_utils

# Use the logging hierarchy (getLogger) rather than instantiating Logger
# directly, so that handlers/levels configured by the application apply.
__LOGGER = logging.getLogger(__name__)

__TAR_SP = [Path('/usr/bin/tar'), Path('/bin/tar')]

# Magic number stored as the first float32 of a valid .flo file.
TAG_FLOAT = 202021.25


def read_flo(file):
    """Read an optical-flow field from a ``.flo`` file.

    The file layout read here is: one float32 tag (must equal ``TAG_FLOAT``),
    one int32 width, one int32 height, then ``2 * width * height`` float32
    values.

    Args:
        file: Path to the ``.flo`` file, as ``str`` or ``os.PathLike``.

    Returns:
        ``np.ndarray`` of dtype float32 and shape ``(height, width, 2)``.

    Raises:
        AssertionError: If ``file`` is not a path, does not exist, does not
            end in ``.flo``, has a wrong tag, declares negative dimensions,
            or holds fewer values than its header declares. The type is kept
            as ``AssertionError`` for compatibility with existing callers,
            but it is raised explicitly so the checks survive ``python -O``.
    """
    if not isinstance(file, (str, os.PathLike)):
        raise AssertionError("file is not str %r" % str(file))
    file = os.fspath(file)
    if not os.path.isfile(file):
        raise AssertionError("file does not exist %r" % str(file))
    if file[-4:] != '.flo':
        raise AssertionError("file ending is not .flo %r" % file[-4:])

    # The context manager closes the handle even when validation fails.
    with open(file, 'rb') as f:
        flo_number = np.fromfile(f, np.float32, count=1)[0]
        if flo_number != TAG_FLOAT:
            raise AssertionError(
                'Flow number %r incorrect. Invalid .flo file' % flo_number)
        # Convert to Python ints immediately: avoids int32 overflow in the
        # element count and the deprecated int() of a 1-element array.
        w = int(np.fromfile(f, np.int32, count=1)[0])
        h = int(np.fromfile(f, np.int32, count=1)[0])
        if w < 0 or h < 0:
            raise AssertionError(
                'Invalid .flo file: negative dimensions %r x %r' % (w, h))
        expected = 2 * w * h
        data = np.fromfile(f, np.float32, count=expected)

    # np.resize would silently repeat/pad a short buffer; fail loudly instead.
    if data.size != expected:
        raise AssertionError(
            'Invalid .flo file: expected %d values, found %d'
            % (expected, data.size))

    # Reshape data into a 3D array (rows, columns, bands).
    return data.reshape(h, w, 2)


def read_flow(sample_dir, resolution=None, to_rgb=False):
    """Load a ``.flo`` file and return it channel-first.

    Args:
        sample_dir: Path to the ``.flo`` file (forwarded to ``read_flo``).
        resolution: Optional ``(height, width)`` target size. When given, the
            flow is resized with nearest-neighbour interpolation and channel
            0 / channel 1 are rescaled by the width / height ratio.
        to_rgb: When true, the flow is passed through ``flow2rgb`` and mapped
            with ``(x - 0.5) * 2`` clipped to ``[-1, 1]``.
            NOTE(review): presumably ``flow2rgb`` returns values in [0, 1] -
            confirm against cvbase.

    Returns:
        Array rearranged from ``(h, w, c)`` to ``(c, h, w)``.
    """
    flow = read_flo(sample_dir)
    h, w, _ = flow.shape
    if resolution:
        target_h, target_w = resolution[0], resolution[1]
        # cv2.resize takes the destination size as (width, height).
        flow = cv2.resize(flow, (target_w, target_h),
                          interpolation=cv2.INTER_NEAREST)
        # Rescale the displacement values to match the new image size.
        flow[:, :, 0] = flow[:, :, 0] * target_w / w
        flow[:, :, 1] = flow[:, :, 1] * target_h / h
    if to_rgb:
        flow = np.clip((flow2rgb(flow) - 0.5) * 2, -1., 1.)
    return einops.rearrange(flow, 'h w c -> c h w')


def read_rgb(sample_dir, resolution=None):
    """Load an image, map it from ``[0, 255]`` to ``[-1, 1]``, channel-first.

    Args:
        sample_dir: Path of the image, forwarded to
            ``detectron2.data.detection_utils.read_image``.
        resolution: Optional ``(height, width)`` target size; when given the
            image is resized with bilinear interpolation.

    Returns:
        Array rearranged from ``(h, w, c)`` to ``(c, h, w)``, clipped to
        ``[-1, 1]``.
    """
    rgb = d2_utils.read_image(sample_dir)
    rgb = ((rgb / 255.0) - 0.5) * 2.0
    if resolution:
        # cv2.resize takes the destination size as (width, height).
        rgb = cv2.resize(rgb, (resolution[1], resolution[0]),
                         interpolation=cv2.INTER_LINEAR)
    rgb = np.clip(rgb, -1., 1.)
    return einops.rearrange(rgb, 'h w c -> c h w')


### from: https://github.com/pytorch/pytorch/issues/15849#issuecomment-518126031
class _RepeatSampler(object):
    """Sampler that repeats forever.

    Args:
        sampler (Sampler): the wrapped sampler; it is re-iterated from the
            start each time it is exhausted.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Never terminates: restart the wrapped sampler whenever it runs out.
        while True:
            yield from iter(self.sampler)


# https://github.com/pytorch/pytorch/issues/15849#issuecomment-573921048
class FastDataLoader(torch.utils.data.dataloader.DataLoader):
    """DataLoader that reuses its CPU workers across epochs to save time.

    A single underlying iterator is created once over an endlessly repeating
    batch sampler; each ``__iter__`` call then draws ``len(self)`` batches
    from that same iterator instead of building a new one.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assign through object.__setattr__ because DataLoader rejects a
        # plain `self.batch_sampler = ...` once it has been initialised.
        object.__setattr__(self, 'batch_sampler',
                           _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        # Length of the original (wrapped) batch sampler, i.e. batches/epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


# Originally written by wkentaro
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py