import torch import torch.nn.functional as F all_times = [] class Timer: def __init__(self, name: str, enabled: bool = True): self.name = name self.enabled = enabled if self.enabled: self.start = torch.cuda.Event(enable_timing=True) self.end = torch.cuda.Event(enable_timing=True) def __enter__(self): if self.enabled: self.start.record() def __exit__(self, type, value, traceback): global all_times if self.enabled: self.end.record() torch.cuda.synchronize() elapsed = self.start.elapsed_time(self.end) all_times.append(elapsed) print(f"{self.name}: {elapsed:.2f}ms") def coords_grid(b, n, h, w, **kwargs): """coordinate grid""" x = torch.arange(0, w, dtype=torch.float, **kwargs) y = torch.arange(0, h, dtype=torch.float, **kwargs) coords = torch.stack(torch.meshgrid(y, x, indexing="ij")) return coords[[1, 0]].view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1) def coords_grid_with_index(d, **kwargs): """coordinate grid with frame index""" b, n, h, w = d.shape i = torch.ones_like(d) x = torch.arange(0, w, dtype=torch.float, **kwargs) y = torch.arange(0, h, dtype=torch.float, **kwargs) y, x = torch.stack(torch.meshgrid(y, x, indexing="ij")) y = y.view(1, 1, h, w).repeat(b, n, 1, 1) x = x.view(1, 1, h, w).repeat(b, n, 1, 1) coords = torch.stack([x, y, d], dim=2) index = torch.arange(0, n, dtype=torch.float, **kwargs) index = index.view(1, n, 1, 1, 1).repeat(b, 1, 1, h, w) return coords, index def patchify(x, patch_size=3): """extract patches from video""" b, n, c, h, w = x.shape x = x.view(b * n, c, h, w) y = F.unfold(x, patch_size) y = y.transpose(1, 2) return y.reshape(b, -1, c, patch_size, patch_size) def pyramidify(fmap, lvls=[1]): """turn fmap into a pyramid""" b, n, c, h, w = fmap.shape pyramid = [] for lvl in lvls: gmap = F.avg_pool2d(fmap.view(b * n, c, h, w), lvl, stride=lvl) pyramid += [gmap.view(b, n, c, h // lvl, w // lvl)] return pyramid def all_pairs_exclusive(n, **kwargs): ii, jj = torch.meshgrid(torch.arange(n, **kwargs), torch.arange(n, **kwargs)) k = ii != jj return ii[k].reshape(-1), jj[k].reshape(-1) def set_depth(patches, depth): patches[..., 2, :, :] = depth[..., None, None] return patches def flatmeshgrid(*args, **kwargs): grid = torch.meshgrid(*args, **kwargs) return (x.reshape(-1) for x in grid)