Spaces:
Runtime error
Runtime error
| import torch | |
| import lietorch | |
| import numpy as np | |
| from lietorch import SE3 | |
| from factor_graph import FactorGraph | |
class DroidFrontend:
    """Frontend of the SLAM pipeline.

    Once exactly ``warmup`` frames are in the shared ``video`` buffer, the
    factor graph is initialized over them; afterwards each call optimizes the
    most recent frames and decides whether the second-newest frame is kept as
    a keyframe.
    """

    def __init__(self, net, video, args):
        """
        Args:
            net: network whose ``update`` operator is handed to the factor graph.
            video: shared frame buffer; must expose ``counter``, ``poses``,
                ``disps``, ``disps_sens``, ``tstamp``, ``dirty``, ``ready``,
                ``distance`` and ``get_lock`` (all used by the methods below).
            args: config namespace providing ``upsample``, ``warmup``, ``beta``,
                ``frontend_nms``, ``keyframe_thresh``, ``frontend_window``,
                ``frontend_thresh`` and ``frontend_radius``.
        """
        self.video = video
        self.update_op = net.update  # not read inside this class; kept for external users
        self.graph = FactorGraph(video, net.update, max_factors=48, upsample=args.upsample)

        # local optimization window; t1 is one past the newest processed frame
        self.t0 = 0
        self.t1 = 0

        # frontend state
        self.is_initialized = False
        self.count = 0  # number of incremental updates performed

        # snapshot of the newest frame taken at the end of initialization;
        # None until __initialize has run
        self.last_pose = None
        self.last_disp = None
        self.last_time = None

        self.max_age = 25  # factors older than this are removed (and stored) on each update
        self.iters1 = 4  # graph updates before the keyframe decision
        self.iters2 = 2  # extra graph updates when frame t1-2 is kept

        self.warmup = args.warmup
        self.beta = args.beta
        self.frontend_nms = args.frontend_nms
        self.keyframe_thresh = args.keyframe_thresh
        self.frontend_window = args.frontend_window
        self.frontend_thresh = args.frontend_thresh
        self.frontend_radius = args.frontend_radius

    def __update(self):
        """Add edges for the newest frame, optimize, and drop frame t1-2 if it is redundant."""
        self.count += 1
        self.t1 += 1

        # retire factors that have been in the graph for too long
        if self.graph.corr is not None:
            self.graph.rm_factors(self.graph.age > self.max_age, store=True)

        self.graph.add_proximity_factors(
            self.t1 - 5,
            max(self.t1 - self.frontend_window, 0),
            rad=self.frontend_radius,
            nms=self.frontend_nms,
            thresh=self.frontend_thresh,
            beta=self.beta,
            remove=True,
        )

        # wherever disps_sens is positive, use it instead of the current estimate
        # for the newest frame (presumably sensor-provided values -- confirm)
        newest = self.t1 - 1
        self.video.disps[newest] = torch.where(
            self.video.disps_sens[newest] > 0,
            self.video.disps_sens[newest],
            self.video.disps[newest],
        )

        for _ in range(self.iters1):
            self.graph.update(None, None, use_inactive=True)

        # distance between the two frames preceding the newest one
        d = self.video.distance([self.t1 - 3], [self.t1 - 2], beta=self.beta, bidirectional=True)

        if d.item() < self.keyframe_thresh:
            # below threshold: frame t1-2 is removed and the buffer shrinks by one
            self.graph.rm_keyframe(self.t1 - 2)
            with self.video.get_lock():
                self.video.counter.value -= 1
                self.t1 -= 1
        else:
            for _ in range(self.iters2):
                self.graph.update(None, None, use_inactive=True)

        # seed the slot for the next frame from the latest estimate
        self.video.poses[self.t1] = self.video.poses[self.t1 - 1]
        self.video.disps[self.t1] = self.video.disps[self.t1 - 1].mean()

        # update visualization: flag frames ii.min() .. t1-1 as dirty
        self.video.dirty[self.graph.ii.min():self.t1] = True

    def __initialize(self):
        """Initialize the SLAM system over the frames currently in the buffer."""
        self.t0 = 0
        self.t1 = self.video.counter.value

        self.graph.add_neighborhood_factors(self.t0, self.t1, r=3)
        for _ in range(8):
            self.graph.update(1, use_inactive=True)

        self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)
        for _ in range(8):
            self.graph.update(1, use_inactive=True)

        # seed the slot for the next frame: copy the newest pose, and use the
        # mean disparity of the last four frames
        self.video.poses[self.t1] = self.video.poses[self.t1 - 1].clone()
        self.video.disps[self.t1] = self.video.disps[self.t1 - 4:self.t1].mean()

        # initialization complete
        self.is_initialized = True
        self.last_pose = self.video.poses[self.t1 - 1].clone()
        self.last_disp = self.video.disps[self.t1 - 1].clone()
        self.last_time = self.video.tstamp[self.t1 - 1].clone()

        with self.video.get_lock():
            self.video.ready.value = 1
            self.video.dirty[:self.t1] = True

        # remove (and store) the factors with ii < warmup - 4
        self.graph.rm_factors(self.graph.ii < self.warmup - 4, store=True)

    def __call__(self):
        """Main entry point.

        Initializes once the buffer holds exactly ``warmup`` frames; after
        that, performs one update step whenever the buffer holds frames
        beyond ``t1``.
        """
        # NOTE(review): initialization needs counter == warmup exactly; if the
        # counter could skip past warmup between calls, the frontend would never
        # initialize -- confirm callers invoke this after every added frame.
        if not self.is_initialized and self.video.counter.value == self.warmup:
            self.__initialize()
        elif self.is_initialized and self.t1 < self.video.counter.value:
            self.__update()