MiniDPVO / mini_dpvo /dpvo.py
pablovela5620's picture
chore: Update dependencies and remove unused files
a8c8616
raw
history blame
No virus
13.9 kB
import torch
import numpy as np
import torch.nn.functional as F
from . import fastba
from . import altcorr
from . import lietorch
from .lietorch import SE3
from .net import VONet
from .utils import Timer, flatmeshgrid
from . import projective_ops as pops
# Short alias for the CUDA automatic-mixed-precision context manager used by DPVO.
autocast = torch.cuda.amp.autocast
# Identity SE3 element created once on the GPU when this module is imported
# (so importing requires a CUDA device). `Id[0]` is stored as the relative pose
# of frames that are skipped during initialization (see DPVO.__call__).
Id = SE3.Identity(1, device="cuda")
class DPVO:
def __init__(self, cfg, network, ht=480, wd=640):
    """Set up tracker state and pre-allocate all working buffers.

    Args:
        cfg: configuration object; PATCHES_PER_FRAME, BUFFER_SIZE and
            MIXED_PRECISION are read here.
        network: checkpoint path or network instance, forwarded to
            ``load_weights``.
        ht: image height in pixels.
        wd: image width in pixels.
    """
    self.cfg = cfg
    # load_weights also publishes DIM / RES / P, which size the buffers below.
    self.load_weights(network)
    self.is_initialized = False
    self.enable_timing = False

    self.n = 0  # number of frames
    self.m = 0  # number of patches
    self.M = self.cfg.PATCHES_PER_FRAME
    self.N = self.cfg.BUFFER_SIZE

    self.ht = ht  # image height
    self.wd = wd  # image width

    feat_dim = self.DIM
    stride = self.RES
    on_gpu = {"device": "cuda"}

    ### state attributes ###
    self.tlist = []
    self.counter = 0

    # dummy image for visualization (the only buffer kept on the CPU)
    self.image_ = torch.zeros(self.ht, self.wd, 3, dtype=torch.uint8, device="cpu")

    self.tstamps_ = torch.zeros(self.N, dtype=torch.float64, **on_gpu)
    self.poses_ = torch.zeros(self.N, 7, dtype=torch.float32, **on_gpu)
    self.patches_ = torch.zeros(
        self.N, self.M, 3, self.P, self.P, dtype=torch.float, **on_gpu
    )
    self.intrinsics_ = torch.zeros(self.N, 4, dtype=torch.float, **on_gpu)

    self.points_ = torch.zeros(self.N * self.M, 3, dtype=torch.float, **on_gpu)
    self.colors_ = torch.zeros(self.N, self.M, 3, dtype=torch.uint8, **on_gpu)

    self.index_ = torch.zeros(self.N, self.M, dtype=torch.long, **on_gpu)
    self.index_map_ = torch.zeros(self.N, dtype=torch.long, **on_gpu)

    ### network attributes ###
    # Feature buffers hold `mem` slots and are addressed modulo `mem` elsewhere.
    self.mem = 32

    feature_dtype = torch.half if self.cfg.MIXED_PRECISION else torch.float
    self.kwargs = kwargs = {"device": "cuda", "dtype": feature_dtype}

    self.imap_ = torch.zeros(self.mem, self.M, feat_dim, **kwargs)
    self.gmap_ = torch.zeros(self.mem, self.M, 128, self.P, self.P, **kwargs)

    # Feature maps live at 1/RES of the image resolution (and 1/4 of that).
    ht = ht // stride
    wd = wd // stride

    self.fmap1_ = torch.zeros(1, self.mem, 128, ht // 1, wd // 1, **kwargs)
    self.fmap2_ = torch.zeros(1, self.mem, 128, ht // 4, wd // 4, **kwargs)

    # feature pyramid
    self.pyramid = (self.fmap1_, self.fmap2_)

    # Hidden state and edge index lists start empty; append_factors grows them.
    self.net = torch.zeros(1, 0, feat_dim, **kwargs)
    self.ii, self.jj, self.kk = (
        torch.as_tensor([], dtype=torch.long, device="cuda") for _ in range(3)
    )

    # initialize poses to identity (last of the 7 pose components set to 1)
    self.poses_[:, 6] = 1.0

    # store relative poses for removed frames
    self.delta = {}
def load_weights(self, network):
    """Install the network and copy its size attributes onto ``self``.

    Args:
        network: either a path (``str``) to a checkpoint, or an already
            constructed network object that is used as-is.
    """
    if not isinstance(network, str):
        self.network = network
    else:
        from collections import OrderedDict

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(network)
        # Drop "update.lmbda" entries and strip "module." from the key names.
        cleaned = OrderedDict(
            (name.replace("module.", ""), tensor)
            for name, tensor in checkpoint.items()
            if "update.lmbda" not in name
        )
        self.network = VONet()
        self.network.load_state_dict(cleaned)

    # steal network attributes
    self.DIM = self.network.DIM
    self.RES = self.network.RES
    self.P = self.network.P

    self.network.cuda()
    self.network.eval()
    # No .half() cast is applied to the module here.
@property
def poses(self):
    """View of ``poses_`` with a leading batch dimension: (1, N, 7)."""
    batched_shape = (1, self.N, 7)
    return self.poses_.view(*batched_shape)
@property
def patches(self):
    """View of ``patches_`` flattened over frames: (1, N * M, 3, P, P).

    The patch size comes from ``self.P`` (the same value used to allocate
    ``patches_`` in ``__init__``) instead of a hard-coded 3, so the view stays
    valid for networks with a different patch size. The result shares storage
    with ``patches_``.
    """
    return self.patches_.view(1, self.N * self.M, 3, self.P, self.P)
@property
def intrinsics(self):
    """View of ``intrinsics_`` with a leading batch dimension: (1, N, 4)."""
    batched_shape = (1, self.N, 4)
    return self.intrinsics_.view(*batched_shape)
@property
def ix(self):
    """Flat 1-D view of ``index_``, one entry per patch slot."""
    flat_index = self.index_.view(-1)
    return flat_index
@property
def imap(self):
    """View of ``imap_`` flattened over buffer slots: (1, mem * M, DIM)."""
    slot_count = self.mem * self.M
    return self.imap_.view(1, slot_count, self.DIM)
@property
def gmap(self):
    """View of ``gmap_`` flattened over buffer slots: (1, mem * M, 128, P, P).

    The patch size comes from ``self.P`` (the same value used to allocate
    ``gmap_`` in ``__init__``) instead of a hard-coded 3. The result shares
    storage with ``gmap_``.
    """
    return self.gmap_.view(1, self.mem * self.M, 128, self.P, self.P)
def get_pose(self, t):
    """Return the pose for timestamp ``t`` as an SE3 element.

    If ``t`` is present in ``self.traj`` its stored pose is wrapped directly.
    Otherwise the ``self.delta`` entries ``t -> (t0, dP)`` are followed until a
    timestamp in ``self.traj`` is reached, and the relative poses are composed
    as ``dP * pose(t0)``.

    The chain is walked iteratively: every frame skipped during
    initialization links to the frame before it, so a long run of skipped
    frames could otherwise exceed Python's recursion limit.

    Raises:
        KeyError: if ``t`` is in neither ``self.traj`` nor ``self.delta``.
    """
    pending = []
    while t not in self.traj:
        t, dP = self.delta[t]
        pending.append(dP)

    pose = SE3(self.traj[t])
    # Apply the deltas innermost-first, matching dP_0 * (dP_1 * (... * base)).
    for dP in reversed(pending):
        pose = dP * pose
    return pose
def terminate(self):
    """Interpolate missing poses and return the full trajectory.

    Builds ``self.traj`` from the first ``n`` buffered frames, resolves a pose
    for every counter value in ``range(self.counter)`` via ``get_pose``, and
    returns the inverted poses as a NumPy array together with the recorded
    timestamps as a float64 array.
    """
    print("Terminating...")

    kept_stamps = self.tstamps_[: self.n].tolist()
    self.traj = {stamp: self.poses_[row] for row, stamp in enumerate(kept_stamps)}

    stacked = lietorch.stack(
        [self.get_pose(t) for t in range(self.counter)], dim=0
    )
    poses = stacked.inv().data.cpu().numpy()
    tstamps = np.array(self.tlist, dtype=np.float64)

    print("Done!")
    return poses, tstamps
def corr(self, coords, indicies=None):
    """Local correlation volume over both pyramid levels.

    Args:
        coords: coordinates passed to ``altcorr.corr``; divided by 1 and 4 for
            the two pyramid levels.
        indicies: optional ``(patch_index, frame_index)`` pair; defaults to
            ``(self.kk, self.jj)``.

    Returns:
        Tensor of shape ``(1, len(patch_index), -1)`` with the two levels
        stacked on the last axis before flattening.
    """
    if indicies is None:
        indicies = (self.kk, self.jj)
    ii, jj = indicies

    # Map global indices onto slots of the fixed-size feature buffers.
    patch_slot = ii % (self.M * self.mem)
    frame_slot = jj % (self.mem)

    per_level = [
        altcorr.corr(self.gmap, fmap, coords / scale, patch_slot, frame_slot, 3)
        for fmap, scale in ((self.pyramid[0], 1), (self.pyramid[1], 4))
    ]
    return torch.stack(per_level, -1).view(1, len(ii), -1)
def reproject(self, indicies=None):
    """Reproject patch k from frame i into frame j.

    Args:
        indicies: optional ``(ii, jj, kk)`` triple; defaults to the stored
            edge lists ``(self.ii, self.jj, self.kk)``.

    Returns:
        The output of ``pops.transform`` with its last axis moved to
        position 2, made contiguous.
    """
    if indicies is None:
        indicies = (self.ii, self.jj, self.kk)
    src_frames, dst_frames, patch_ids = indicies

    projected = pops.transform(
        SE3(self.poses),
        self.patches,
        self.intrinsics,
        src_frames,
        dst_frames,
        patch_ids,
    )
    return projected.permute(0, 1, 4, 2, 3).contiguous()
def append_factors(self, ii, jj):
    """Append new edges and zero-initialised hidden state for them.

    Args:
        ii: patch indices of the new edges (stored in ``self.kk``; their
            source frames ``self.ix[ii]`` are stored in ``self.ii``).
        jj: target frame indices of the new edges (stored in ``self.jj``).
    """
    fresh_state = torch.zeros(1, len(ii), self.DIM, **self.kwargs)
    source_frames = self.ix[ii]

    self.ii = torch.cat((self.ii, source_frames))
    self.jj = torch.cat((self.jj, jj))
    self.kk = torch.cat((self.kk, ii))
    self.net = torch.cat((self.net, fresh_state), dim=1)
def remove_factors(self, m):
    """Drop every edge selected by the boolean mask ``m``.

    The same mask is applied to the three edge index lists and to the
    per-edge hidden state (axis 1 of ``self.net``).
    """
    keep = ~m
    self.ii, self.jj, self.kk = self.ii[keep], self.jj[keep], self.kk[keep]
    self.net = self.net[:, keep]
def motion_probe(self):
    """Kinda hacky way to ensure enough motion for initialization.

    Runs one network update for the last ``M`` patches against frame index
    ``self.n`` starting from a zero hidden state, and returns the 0.5-quantile
    of the norm of the predicted ``delta``.
    """
    patch_ids = torch.arange(self.m - self.M, self.m, device="cuda")
    target_frames = self.n * torch.ones_like(patch_ids)
    source_frames = self.ix[patch_ids]

    hidden = torch.zeros(1, len(source_frames), self.DIM, **self.kwargs)
    coords = self.reproject(indicies=(source_frames, target_frames, patch_ids))

    with autocast(enabled=self.cfg.MIXED_PRECISION):
        corr = self.corr(coords, indicies=(patch_ids, target_frames))
        ctx = self.imap[:, patch_ids % (self.M * self.mem)]
        # Only the predicted displacement is needed; the rest is discarded.
        _, (delta, _, _) = self.network.update(
            hidden, ctx, corr, None, source_frames, target_frames, patch_ids
        )

    magnitudes = delta.norm(dim=-1).float()
    return torch.quantile(magnitudes, 0.5)
def motionmag(self, i, j):
    """Mean ``pops.flow_mag`` value over the edges going from frame i to j.

    Returns:
        A Python float (the mean over the selected edges).
    """
    on_edge = (self.ii == i) & (self.jj == j)
    flow = pops.flow_mag(
        SE3(self.poses),
        self.patches,
        self.intrinsics,
        self.ii[on_edge],
        self.jj[on_edge],
        self.kk[on_edge],
        beta=0.5,
    )
    return flow.mean().item()
def keyframe(self):
    """Drop a low-motion frame and prune edges outside the removal window.

    The candidate is frame ``n - KEYFRAME_INDEX``. If the average of the
    motion magnitudes between its two neighbours (both directions) is below
    ``KEYFRAME_THRESH``, the candidate is removed: its pose relative to the
    previous frame is recorded in ``self.delta``, its edges are deleted, and
    all later frames are shifted down by one slot. Afterwards, edges whose
    source frame is older than ``n - REMOVAL_WINDOW`` are removed regardless.
    """
    candidate = self.n - self.cfg.KEYFRAME_INDEX
    before, after = candidate - 1, candidate + 1
    motion = self.motionmag(before, after) + self.motionmag(after, before)

    if motion / 2 < self.cfg.KEYFRAME_THRESH:
        k = candidate

        # Remember how to recover the dropped frame's pose from its predecessor.
        t_prev = self.tstamps_[k - 1].item()
        t_drop = self.tstamps_[k].item()
        rel_pose = SE3(self.poses_[k]) * SE3(self.poses_[k - 1]).inv()
        self.delta[t_drop] = (t_prev, rel_pose)

        self.remove_factors((self.ii == k) | (self.jj == k))

        # Re-index the surviving edges; the kk shift must use ii before it changes.
        later_source = self.ii > k
        self.kk[later_source] -= self.M
        self.ii[later_source] -= 1
        self.jj[self.jj > k] -= 1

        per_frame_buffers = (
            self.tstamps_,
            self.colors_,
            self.poses_,
            self.patches_,
            self.intrinsics_,
        )
        for dst in range(k, self.n - 1):
            src = dst + 1
            for buf in per_frame_buffers:
                buf[dst] = buf[src]

            # Feature buffers are addressed modulo `mem`.
            dst_slot, src_slot = dst % self.mem, src % self.mem
            self.imap_[dst_slot] = self.imap_[src_slot]
            self.gmap_[dst_slot] = self.gmap_[src_slot]
            self.fmap1_[0, dst_slot] = self.fmap1_[0, src_slot]
            self.fmap2_[0, dst_slot] = self.fmap2_[0, src_slot]

        self.n -= 1
        self.m -= self.M

    stale = self.ix[self.kk] < self.n - self.cfg.REMOVAL_WINDOW
    self.remove_factors(stale)
def update(self):
    """Run one network update followed by one bundle-adjustment step.

    Reprojects all edges, queries the network for per-edge ``delta`` and
    ``weight``, runs ``fastba.BA`` over the frame window ``[t0, n)`` and then
    refreshes ``self.points_`` from the current poses and patches.

    Bundle adjustment stays best-effort: a failure is reported and the point
    cloud is still refreshed. Only ``Exception`` subclasses are caught, so
    ``KeyboardInterrupt`` / ``SystemExit`` are no longer swallowed.
    """
    center = self.P // 2  # index of the patch centre

    with Timer("other", enabled=self.enable_timing):
        coords = self.reproject()

        # NOTE(review): autocast is forced on here, whereas motion_probe and
        # __call__ honour cfg.MIXED_PRECISION — confirm this is intentional.
        with autocast(enabled=True):
            corr = self.corr(coords)
            ctx = self.imap[:, self.kk % (self.M * self.mem)]
            self.net, (delta, weight, _) = self.network.update(
                self.net, ctx, corr, None, self.ii, self.jj, self.kk
            )

        lmbda = torch.as_tensor([1e-4], device="cuda")
        weight = weight.float()
        # Target = reprojected patch centre plus the predicted correction.
        target = coords[..., center, center] + delta.float()

    with Timer("BA", enabled=self.enable_timing):
        # Before initialization every frame except frame 0 is optimised.
        t0 = self.n - self.cfg.OPTIMIZATION_WINDOW if self.is_initialized else 1
        t0 = max(t0, 1)

        try:
            fastba.BA(
                self.poses,
                self.patches,
                self.intrinsics,
                target,
                weight,
                lmbda,
                self.ii,
                self.jj,
                self.kk,
                t0,
                self.n,
                2,
            )
        except Exception as err:
            print("Warning BA failed...", err)

        points = pops.point_cloud(
            SE3(self.poses),
            self.patches[:, : self.m],
            self.intrinsics,
            self.ix[: self.m],
        )
        # Dehomogenise the patch-centre point of every patch.
        points = (
            points[..., center, center, :3] / points[..., center, center, 3:]
        ).reshape(-1, 3)
        self.points_[: len(points)] = points[:]
def __edges_all(self):
    """Pair every patch index in [0, m) with every frame index in [0, n)."""
    patch_ids = torch.arange(0, self.m, device="cuda")
    frame_ids = torch.arange(0, self.n, device="cuda")
    return flatmeshgrid(patch_ids, frame_ids, indexing="ij")
def __edges_forw(self):
    """Pair patches of the previous frames with the newest frame.

    Patches come from frames ``max(n - PATCH_LIFETIME, 0)`` up to (but not
    including) frame ``n - 1``; the only target frame is ``n - 1``.
    """
    lifetime = self.cfg.PATCH_LIFETIME
    first_patch = self.M * max((self.n - lifetime), 0)
    end_patch = self.M * max((self.n - 1), 0)

    patch_ids = torch.arange(first_patch, end_patch, device="cuda")
    frame_ids = torch.arange(self.n - 1, self.n, device="cuda")
    return flatmeshgrid(patch_ids, frame_ids, indexing="ij")
def __edges_back(self):
    """Pair patches of the newest frame with the recent frames.

    Patches come from frame ``n - 1`` only; target frames range from
    ``max(n - PATCH_LIFETIME, 0)`` up to (but not including) ``n``.
    """
    lifetime = self.cfg.PATCH_LIFETIME
    first_patch = self.M * max((self.n - 1), 0)
    end_patch = self.M * max(self.n, 0)

    patch_ids = torch.arange(first_patch, end_patch, device="cuda")
    frame_ids = torch.arange(max(self.n - lifetime, 0), self.n, device="cuda")
    return flatmeshgrid(patch_ids, frame_ids, indexing="ij")
def __call__(self, tstamp: int, image, intrinsics) -> None:
    """Track a new frame.

    Extracts features and patches from ``image``, writes them into slot
    ``self.n`` of the state buffers, predicts an initial pose, adds new edges
    and runs the update / keyframe steps once the system is initialized.

    Before initialization, a frame whose ``motion_probe()`` value is below 2.0
    is not added: only an identity relative pose to the previous counter value
    is recorded in ``self.delta``.

    Args:
        tstamp: timestamp of the frame; appended to ``self.tlist``.
        image: input image tensor (presumably 0..255 valued — TODO confirm).
        intrinsics: camera intrinsics; divided by ``self.RES`` before storage.

    Raises:
        Exception: if the frame buffer (``BUFFER_SIZE``) is full.
    """
    if (self.n + 1) >= self.N:
        raise Exception(
            f'The buffer size is too small. You can increase it using "--buffer {self.N*2}"'
        )

    # Add two leading singleton dimensions and rescale the pixel values.
    image = 2 * (image[None, None] / 255.0) - 0.5

    with autocast(enabled=self.cfg.MIXED_PRECISION):
        fmap, gmap, imap, patches, _, clr = self.network.patchify(
            image,
            patches_per_image=self.cfg.PATCHES_PER_FRAME,
            gradient_bias=self.cfg.GRADIENT_BIAS,
            return_color=True,
        )

    ### update state attributes ###
    self.tlist.append(tstamp)
    # Internally frames are keyed by the running counter, not by `tstamp`.
    self.tstamps_[self.n] = self.counter
    self.intrinsics_[self.n] = intrinsics / self.RES

    # color info for visualization (channel order reversed)
    clr = (clr[0, :, [2, 1, 0]] + 0.5) * (255.0 / 2)
    self.colors_[self.n] = clr.to(torch.uint8)

    self.index_[self.n + 1] = self.n + 1
    self.index_map_[self.n + 1] = self.m + self.M

    if self.n > 1:
        if self.cfg.MOTION_MODEL == "DAMPED_LINEAR":
            # Extrapolate the last inter-frame motion, scaled by MOTION_DAMPING.
            P1 = SE3(self.poses_[self.n - 1])
            P2 = SE3(self.poses_[self.n - 2])

            xi = self.cfg.MOTION_DAMPING * (P1 * P2.inv()).log()
            tvec_qvec = (SE3.exp(xi) * P1).data
            self.poses_[self.n] = tvec_qvec
        else:
            # Copy the previous pose. This must read the raw (N, 7) buffer:
            # the `poses` property has a leading dimension of size 1, so
            # indexing it with n - 1 >= 1 raises IndexError.
            tvec_qvec = self.poses_[self.n - 1]
            self.poses_[self.n] = tvec_qvec

    # TODO better depth initialization
    patches[:, :, 2] = torch.rand_like(patches[:, :, 2, 0, 0, None, None])
    if self.is_initialized:
        # Use the median of channel 2 over the last three frames instead.
        s = torch.median(self.patches_[self.n - 3 : self.n, :, 2])
        patches[:, :, 2] = s

    self.patches_[self.n] = patches

    ### update network attributes ###
    slot = self.n % self.mem
    self.imap_[slot] = imap.squeeze()
    self.gmap_[slot] = gmap.squeeze()
    self.fmap1_[:, slot] = F.avg_pool2d(fmap[0], 1, 1)
    self.fmap2_[:, slot] = F.avg_pool2d(fmap[0], 4, 4)

    self.counter += 1
    if self.n > 0 and not self.is_initialized:
        if self.motion_probe() < 2.0:
            # Not enough motion: skip the frame and link it to its predecessor.
            self.delta[self.counter - 1] = (self.counter - 2, Id[0])
            return

    self.n += 1
    self.m += self.M

    # relative pose
    self.append_factors(*self.__edges_forw())
    self.append_factors(*self.__edges_back())

    if self.n == 8 and not self.is_initialized:
        self.is_initialized = True

        for _ in range(12):
            self.update()

    elif self.is_initialized:
        self.update()
        self.keyframe()