# MiniDPVO / mini_dpvo/utils.py
# Committer: pablovela5620
# Commit 899c526: "initial commit with working dpvo"
# File size: 2.58 kB (source page reported "No virus").
import torch
import torch.nn.functional as F
# Elapsed time (milliseconds) of every enabled Timer, appended in the order the
# timed blocks finish.
all_times = []


class Timer:
    """Context manager that times a ``with`` block using CUDA events.

    On exit (when enabled) it waits for the GPU, appends the elapsed time in
    milliseconds to the module-level ``all_times`` list, stores it on
    ``self.elapsed`` and prints ``"<name>: <ms>ms"``. When ``enabled`` is False
    it does nothing and creates no CUDA events.

    Args:
        name: label used in the printed message.
        enabled: whether to time the block at all.
    """

    def __init__(self, name: str, enabled: bool = True):
        self.name = name
        self.enabled = enabled
        self.elapsed = None  # set to the measured milliseconds on exit when enabled
        if self.enabled:
            # CUDA events are only created when timing is on, so a disabled
            # Timer never touches the CUDA runtime.
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        if self.enabled:
            self.start.record()
        # Return self so ``with Timer(...) as t`` can read ``t.elapsed`` afterwards.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enabled:
            self.end.record()
            # Wait for all queued GPU work so the end event has completed
            # before elapsed_time() reads it.
            torch.cuda.synchronize()
            self.elapsed = self.start.elapsed_time(self.end)
            all_times.append(self.elapsed)
            print(f"{self.name}: {self.elapsed:.2f}ms")
        # Implicitly returns None: exceptions raised in the block propagate.
def coords_grid(b, n, h, w, **kwargs):
    """Return a float32 pixel-coordinate grid of shape (b, n, 2, h, w).

    Index 0 along dim 2 holds the column index (x, ranging over ``w``) and
    index 1 holds the row index (y, ranging over ``h``). Extra keyword
    arguments (e.g. ``device``) are forwarded to ``torch.arange``.
    """
    xs = torch.arange(0, w, dtype=torch.float, **kwargs)
    ys = torch.arange(0, h, dtype=torch.float, **kwargs)
    # "xy" indexing yields two (h, w) grids with x varying along columns and
    # y along rows, so stacking them is already in (x, y) order.
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")
    grid = torch.stack((grid_x, grid_y))
    return grid.view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1)
def coords_grid_with_index(d, **kwargs):
    """Build per-pixel (x, y, d) coordinates plus a frame-index map.

    Args:
        d: 4-D tensor of shape (b, n, h, w) (enforced by the unpacking below);
            presumably per-pixel depth — TODO confirm with callers.
        **kwargs: forwarded to ``torch.arange`` (e.g. ``device``). If no
            ``device`` is given, the grids are created on ``d.device`` so they
            can be stacked with ``d``.

    Returns:
        coords: tensor of shape (b, n, 3, h, w); index 0 along dim 2 is the
            column index x, index 1 the row index y, index 2 is ``d``.
        index: float tensor of shape (b, n, 1, h, w) whose entries equal the
            position along dimension n.
    """
    b, n, h, w = d.shape
    # Build the grids where d lives; otherwise torch.stack below fails for
    # non-default-device inputs.
    kwargs.setdefault("device", d.device)
    x = torch.arange(0, w, dtype=torch.float, **kwargs)
    y = torch.arange(0, h, dtype=torch.float, **kwargs)
    # meshgrid already returns the two (h, w) grids; no stack/unstack needed.
    y, x = torch.meshgrid(y, x, indexing="ij")
    y = y.reshape(1, 1, h, w).repeat(b, n, 1, 1)
    x = x.reshape(1, 1, h, w).repeat(b, n, 1, 1)
    coords = torch.stack([x, y, d], dim=2)
    index = torch.arange(0, n, dtype=torch.float, **kwargs)
    index = index.view(1, n, 1, 1, 1).repeat(b, 1, 1, h, w)
    return coords, index
def patchify(x, patch_size=3):
    """Extract every dense ``patch_size`` x ``patch_size`` patch from a video.

    Args:
        x: tensor of shape (b, n, c, h, w) (enforced by the unpacking below).
        patch_size: side length of the square patches. Patches are taken with
            ``F.unfold`` defaults (stride 1, no padding, no dilation).

    Returns:
        Tensor of shape (b, n * L, c, patch_size, patch_size) where
        L = (h - patch_size + 1) * (w - patch_size + 1); patches are grouped
        frame by frame, in ``F.unfold`` block order within each frame.
    """
    b, n, c, h, w = x.shape
    # reshape (not view): view raises when the (b, n) axes cannot be merged,
    # e.g. for permuted / non-contiguous inputs.
    frames = x.reshape(b * n, c, h, w)
    cols = F.unfold(frames, patch_size)  # (b*n, c*p*p, L)
    cols = cols.transpose(1, 2)  # (b*n, L, c*p*p)
    return cols.reshape(b, -1, c, patch_size, patch_size)
def pyramidify(fmap, lvls=(1,)):
    """Turn a feature map into a pyramid by average-pooling each frame.

    Args:
        fmap: tensor of shape (b, n, c, h, w) (enforced by the unpacking below).
        lvls: iterable of integer pooling factors. Each level applies
            ``F.avg_pool2d`` with kernel size and stride equal to the factor;
            a factor of 1 yields the same values as ``fmap``. Defaults to a
            tuple to avoid a shared mutable default; any iterable works.

    Returns:
        List with one tensor per entry of ``lvls``, each of shape
        (b, n, c, h // lvl, w // lvl).
    """
    b, n, c, h, w = fmap.shape
    # reshape (not view) so non-contiguous inputs are accepted; hoisted out of
    # the loop because it does not depend on the level.
    frames = fmap.reshape(b * n, c, h, w)
    pyramid = []
    for lvl in lvls:
        gmap = F.avg_pool2d(frames, lvl, stride=lvl)
        pyramid.append(gmap.view(b, n, c, h // lvl, w // lvl))
    return pyramid
def all_pairs_exclusive(n, **kwargs):
    """Return every ordered index pair (i, j) with i != j and 0 <= i, j < n.

    Args:
        n: number of indices.
        **kwargs: forwarded to ``torch.arange`` (e.g. ``device``).

    Returns:
        Tuple ``(ii, jj)`` of 1-D tensors of length n * (n - 1), ordered by
        ``ii`` first and then ``jj``.
    """
    idx = torch.arange(n, **kwargs)
    # Pass "ij" explicitly (it is also meshgrid's implicit default) to keep
    # results unchanged and avoid torch's missing-indexing deprecation warning.
    ii, jj = torch.meshgrid(idx, idx, indexing="ij")
    off_diagonal = ii != jj
    # Boolean-mask indexing already returns 1-D tensors in row-major order.
    return ii[off_diagonal], jj[off_diagonal]
def set_depth(patches, depth):
    """Overwrite index 2 of dim -3 of ``patches`` with ``depth``, in place.

    ``depth`` gets two trailing singleton dims and is broadcast over the last
    two dimensions of ``patches``. Presumably ``depth`` has the shape of
    ``patches`` without its last three dims — TODO confirm with callers.

    Returns:
        The same ``patches`` tensor object, modified in place.
    """
    depth_plane = depth.unsqueeze(-1).unsqueeze(-1)
    patches[..., 2, :, :] = depth_plane
    return patches
def flatmeshgrid(*args, **kwargs):
    """Meshgrid the given tensors and flatten each resulting grid to 1-D.

    Args:
        *args: tensors passed to ``torch.meshgrid``.
        **kwargs: forwarded to ``torch.meshgrid``. ``indexing`` defaults to
            ``"ij"``, which is what ``torch.meshgrid`` uses when the argument
            is omitted; an explicit caller value is respected.

    Returns:
        A generator yielding one flattened tensor per grid.
    """
    # Pin the implicit default explicitly: identical results, but no
    # "indexing argument will be required" deprecation warning from torch.
    kwargs.setdefault("indexing", "ij")
    grid = torch.meshgrid(*args, **kwargs)
    return (x.reshape(-1) for x in grid)