import copy
import torch
import numpy as np
import torch.nn as nn

from functools import partial
from torch.nn import functional as F

from croco.models.blocks import Block
from dust3r.model import AsymmetricCroCo3DStereo


class SpatialMemory:
    """Token memory bank read with attention and pruned by attention usage.

    Keys and values are concatenated along dim 1 as frames are added, so the
    banks are flat token sequences (the in-code shape comments describe them
    as ``[bs, tokens, c]``).  ``mem_attn`` accumulates the attention each
    token has received and ``mem_count`` how many frames it has lived
    through; ``memory_prune`` uses both to decide what to keep.

    Args:
        norm_q, norm_k, norm_v: Callables applied to the query, the stored
            keys and the stored values before attention.
        mem_dropout: Optional callable applied to the attention weights.
        long_mem_size: Token budget for long-term memory; also used as the
            number of tokens kept by ``memory_prune`` (``top_k``).  ``0``
            switches to a sliding window that drops the oldest frame.
        work_mem_size: Number of most recent frames treated as working
            memory.
        attn_thresh: Attention weights below this value are zeroed in
            ``memory_read``; ``0`` disables the thresholding.
        sim_thresh: Similarity above which ``add_mem_check`` skips a frame.
        save_attn: When true, every attention map is flattened and appended
            to ``attn_vis``.
    """

    def __init__(self, norm_q, norm_k, norm_v, mem_dropout=None,
                 long_mem_size=4000, work_mem_size=5,
                 attn_thresh=5e-4, sim_thresh=0.95,
                 save_attn=False):
        self.norm_q = norm_q
        self.norm_k = norm_k
        self.norm_v = norm_v
        self.mem_dropout = mem_dropout
        self.attn_thresh = attn_thresh
        self.long_mem_size = long_mem_size
        self.work_mem_size = work_mem_size
        self.top_k = long_mem_size
        self.save_attn = save_attn
        self.sim_thresh = sim_thresh
        self.init_mem()

    def init_mem(self):
        """Reset every bank and both size counters to the empty state."""
        self.mem_k = None
        self.mem_v = None
        self.mem_c = None
        self.mem_count = None
        self.mem_attn = None
        self.mem_pts = None
        self.mem_imgs = None
        self.lm = 0  # long-term memory size, counted in tokens
        self.wm = 0  # working memory size, counted in frames
        if self.save_attn:
            self.attn_vis = None

    @staticmethod
    def _append(bank, feat):
        """Return ``feat`` appended to ``bank`` along dim 1 (``bank`` may be None)."""
        if bank is None:
            return feat
        return torch.cat((bank, feat), dim=1)

    def add_mem_k(self, feat):
        """Append ``feat`` to the key bank and return the updated bank."""
        self.mem_k = self._append(self.mem_k, feat)
        return self.mem_k

    def add_mem_v(self, feat):
        """Append ``feat`` to the value bank and return the updated bank."""
        self.mem_v = self._append(self.mem_v, feat)
        return self.mem_v

    def add_mem_c(self, feat):
        """Append ``feat`` to the ``mem_c`` bank and return the updated bank."""
        self.mem_c = self._append(self.mem_c, feat)
        return self.mem_c

    def add_mem_pts(self, pts_cur):
        """Append ``pts_cur`` to ``mem_pts``; a ``None`` argument is ignored."""
        if pts_cur is not None:
            self.mem_pts = self._append(self.mem_pts, pts_cur)

    def add_mem_img(self, img_cur):
        """Append ``img_cur`` to ``mem_imgs``; a ``None`` argument is ignored."""
        if img_cur is not None:
            self.mem_imgs = self._append(self.mem_imgs, img_cur)

    def add_mem(self, feat_k, feat_v, pts_cur=None, img_cur=None):
        """Unconditionally store one frame of keys/values in memory.

        Existing tokens get their age counter incremented; the new tokens
        start with zero age and zero accumulated attention.
        """
        if self.mem_count is None:
            # Two separate tensors: mem_attn is later updated in place and
            # must not alias mem_count.
            self.mem_count = torch.zeros_like(feat_k[:, :, :1])
            self.mem_attn = torch.zeros_like(feat_k[:, :, :1])
        else:
            self.mem_count += 1
            self.mem_count = torch.cat(
                (self.mem_count, torch.zeros_like(feat_k[:, :, :1])), dim=1)
            self.mem_attn = torch.cat(
                (self.mem_attn, torch.zeros_like(feat_k[:, :, :1])), dim=1)

        self.add_mem_k(feat_k)
        self.add_mem_v(feat_v)
        self.add_mem_pts(pts_cur)
        self.add_mem_img(img_cur)

    def check_sim(self, feat_k, thresh=0.7):
        """Return True when ``feat_k`` is too similar to a working-memory frame.

        The cosine similarity between ``feat_k`` and each working-memory
        frame is computed per token and averaged over tokens; the frame is
        flagged when the best average exceeds ``thresh``.  Returns False
        when memory is empty or ``thresh`` is exactly 1.0.
        """
        # Do correlation with working memory
        if self.mem_k is None or thresh == 1.0:
            return False

        # Tokens per frame is read from the query instead of being
        # hard-coded to 196; the einsum below requires both to agree anyway.
        n_tok = feat_k.shape[1]
        wmem_size = self.wm * n_tok

        # wm: BS, T, n_tok, C
        wm = self.mem_k[:, -wmem_size:].reshape(
            self.mem_k.shape[0], -1, n_tok, self.mem_k.shape[-1])

        feat_k_norm = F.normalize(feat_k, p=2, dim=-1)
        wm_norm = F.normalize(wm, p=2, dim=-1)

        corr = torch.einsum('bpc,btpc->btp', feat_k_norm, wm_norm)
        mean_corr = torch.mean(corr, dim=-1)

        max_corr = mean_corr.max()
        if max_corr > thresh:
            print('Similarity detected:', max_corr)
            return True

        return False

    def add_mem_check(self, feat_k, feat_v, pts_cur=None, img_cur=None):
        """Store a frame unless it duplicates working memory, then enforce limits.

        Once the working memory is full, the oldest working frame either is
        dropped (``long_mem_size == 0``) or is accounted to long-term memory,
        which is pruned as soon as it exceeds ``long_mem_size`` tokens.
        """
        if self.check_sim(feat_k, thresh=self.sim_thresh):
            return

        self.add_mem(feat_k, feat_v, pts_cur, img_cur)
        self.wm += 1

        # Actual number of tokens contributed by one frame (was a literal 196).
        n_tok = feat_k.shape[1]

        if self.wm > self.work_mem_size:
            self.wm -= 1
            if self.long_mem_size == 0:
                # Sliding window: forget the oldest frame entirely.
                self.mem_k = self.mem_k[:, n_tok:]
                self.mem_v = self.mem_v[:, n_tok:]
                self.mem_count = self.mem_count[:, n_tok:]
                self.mem_attn = self.mem_attn[:, n_tok:]
                print('Memory pruned:', self.mem_k.shape)
            else:
                self.lm += n_tok

        if self.lm > self.long_mem_size:
            self.memory_prune()
            self.lm = self.top_k - self.wm * n_tok

    def memory_read(self, feat, res=True):
        '''
        Read from memory with attention, using ``feat`` as the query.

        Params:
            - feat: [bs, p, c]
            - mem_k: [bs, x, c] after flattening (x = stored tokens)
            - mem_v: [bs, x, c] after flattening
            - mem_c: optional per-token weight, viewed as [bs, 1, x]
            - res: add ``feat`` to the read-out as a residual

        Returns the read-out of shape [bs, p, c].  As a side effect the
        attention received by every memory token is added to ``mem_attn``.
        '''
        mem_k = self.mem_k.reshape(self.mem_k.shape[0], -1, self.mem_k.shape[-1])
        affinity = torch.einsum('bpc,bxc->bpx', self.norm_q(feat), self.norm_k(mem_k))
        affinity /= torch.sqrt(torch.tensor(feat.shape[-1]).float())

        if self.mem_c is not None:
            affinity = affinity * self.mem_c.view(self.mem_c.shape[0], 1, -1)

        attn = torch.softmax(affinity, dim=-1)

        if self.save_attn:
            if self.attn_vis is None:
                self.attn_vis = attn.reshape(-1)
            else:
                self.attn_vis = torch.cat((self.attn_vis, attn.reshape(-1)), dim=0)

        if self.mem_dropout is not None:
            attn = self.mem_dropout(attn)

        if self.attn_thresh > 0:
            # NOTE(review): the text between "attn[attn<" and "->bpc'" was
            # lost in the copy under review (stripped like an HTML tag).
            # The thresholding and renormalisation below are a
            # reconstruction -- confirm against the pristine file.
            attn[attn < self.attn_thresh] = 0
            attn = attn / attn.sum(dim=-1, keepdim=True)

        mem_v = self.mem_v.reshape(self.mem_v.shape[0], -1, self.mem_v.shape[-1])
        out = torch.einsum('bpx,bxc->bpc', attn, self.norm_v(mem_v))

        if res:
            out = out + feat

        # Accumulate, per memory token, the attention summed over all queries.
        total_attn = torch.sum(attn, dim=-2)
        self.mem_attn += total_attn[..., None]

        return out

    @staticmethod
    def _gather_tokens(bank, top_idx):
        """Select the tokens ``top_idx`` ([bs, k, 1]) from ``bank`` along dim 1."""
        extra_dims = bank.dim() - 2
        idx = top_idx.reshape(top_idx.shape[0], top_idx.shape[1], *([1] * extra_dims))
        idx = idx.expand(-1, -1, *bank.shape[2:])
        return torch.gather(bank, 1, idx)

    def memory_prune(self):
        """Keep the ``top_k`` memory tokens with the highest average attention.

        NOTE(review): everything between the ``mem_count`` comparison and
        the final ``print`` was lost in the copy under review.  The
        protection threshold (``work_mem_size + 5``), the protection weight
        and the top-k/gather selection are a reconstruction -- confirm
        against the pristine file before relying on them.
        """
        weights = self.mem_attn / self.mem_count
        # Young tokens have had little chance to collect attention (and a
        # zero count yields NaN above): give them a huge weight so they
        # always survive.
        weights[self.mem_count < self.work_mem_size + 5] = 1e8

        num_mem_b = self.mem_k.shape[1]

        # Never ask topk for more tokens than are stored.
        k = min(self.top_k, weights.shape[1])
        _, top_k_indices = torch.topk(weights, k, dim=1)

        self.mem_k = self._gather_tokens(self.mem_k, top_k_indices)
        self.mem_v = self._gather_tokens(self.mem_v, top_k_indices)
        self.mem_attn = self._gather_tokens(self.mem_attn, top_k_indices)
        self.mem_count = self._gather_tokens(self.mem_count, top_k_indices)

        # NOTE(review): assumes mem_pts / mem_imgs hold one entry per memory
        # token along dim 1 -- nothing in this file populates them.
        if self.mem_pts is not None:
            self.mem_pts = self._gather_tokens(self.mem_pts, top_k_indices)
        if self.mem_imgs is not None:
            self.mem_imgs = self._gather_tokens(self.mem_imgs, top_k_indices)

        num_mem_a = self.mem_k.shape[1]

        print('Memory pruned:', num_mem_b, '->', num_mem_a)


class Spann3R(nn.Module):
    """DUSt3R backbone extended with a spatial memory for sequential frames.

    Args:
        dus3r_name: Checkpoint passed to
            ``AsymmetricCroCo3DStereo.from_pretrained``.
        use_feat: If true, memory values are encoded from the last decoder
            feature (768-d); otherwise from the patch-embedded predicted
            pointmap (1024-d).
        mem_pos_enc: If true, the value-encoder blocks reuse the DUSt3R
            ``rope`` positional encoding.
        memory_dropout: Dropout probability applied to memory attention.
    """

    def __init__(self, dus3r_name="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
                 use_feat=False, mem_pos_enc=False, memory_dropout=0.15):
        super().__init__()
        # config
        self.use_feat = use_feat
        self.mem_pos_enc = mem_pos_enc

        # DUSt3R
        self.dust3r = AsymmetricCroCo3DStereo.from_pretrained(dus3r_name, landscape_only=True)

        # Memory encoder
        self.set_memory_encoder(enc_embed_dim=768 if use_feat else 1024,
                                memory_dropout=memory_dropout)
        self.set_attn_head()

    def set_memory_encoder(self, enc_depth=6, enc_embed_dim=1024, out_dim=1024, enc_num_heads=16,
                           mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                           memory_dropout=0.15):
        """Create the value encoder, the q/k/v norms and the memory dropout."""
        rope = self.dust3r.rope if self.mem_pos_enc else None
        self.value_encoder = nn.ModuleList([
            Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=norm_layer, rope=rope)
            for _ in range(enc_depth)])

        self.value_norm = norm_layer(enc_embed_dim)
        self.value_out = nn.Linear(enc_embed_dim, out_dim)

        if not self.use_feat:
            # Separate patch embedding for pointmaps, initialised from the
            # image patch embedding of the backbone.
            self.pos_patch_embed = copy.deepcopy(self.dust3r.patch_embed)
            self.pos_patch_embed.load_state_dict(self.dust3r.patch_embed.state_dict())

        # Normalization layers
        self.norm_q = nn.LayerNorm(1024)
        self.norm_k = nn.LayerNorm(1024)
        self.norm_v = nn.LayerNorm(1024)
        self.mem_dropout = nn.Dropout(memory_dropout)

    def set_attn_head(self, enc_embed_dim=1024+768, out_dim=1024):
        """Create the two MLP heads that turn (feature, decoder token) into keys."""
        self.attn_head_1 = nn.Sequential(
            nn.Linear(enc_embed_dim, enc_embed_dim),
            nn.GELU(),
            nn.Linear(enc_embed_dim, out_dim)
        )

        self.attn_head_2 = nn.Sequential(
            nn.Linear(enc_embed_dim, enc_embed_dim),
            nn.GELU(),
            nn.Linear(enc_embed_dim, out_dim)
        )

    @staticmethod
    def _true_shape(view):
        """Return ``view['true_shape']`` or the image H, W repeated per batch item."""
        img = view['img']
        return view.get('true_shape',
                        torch.tensor(img.shape[-2:])[None].repeat(img.shape[0], 1))

    @staticmethod
    def _pair_confidence(conf1, conf2):
        """Sum of the mean ``(conf - 1) / conf`` scores of two confidence maps."""
        conf1_sig = (conf1 - 1) / conf1
        conf2_sig = (conf2 - 1) / conf2
        return conf1_sig.mean() + conf2_sig.mean()

    @staticmethod
    def _record_pair(preds, preds_all, res1, res2):
        """Append one decoded pair to the running prediction lists (in place).

        ``res2`` always has ``pts3d`` renamed to ``pts3d_in_other_view``;
        ``res1`` keeps ``pts3d`` only for the very first pair.
        """
        res2['pts3d_in_other_view'] = res2.pop('pts3d')
        if preds:
            res1['pts3d_in_other_view'] = res1.pop('pts3d')
        preds.append(res1)
        preds_all.append((res1, res2))

    def encode_image(self, view):
        """Encode one view; returns ``(features, positions, true_shape)``."""
        im_shape = self._true_shape(view)
        out, pos, _ = self.dust3r._encode_image(view['img'], im_shape)
        return out, pos, im_shape

    def encode_image_pairs(self, view1, view2):
        """Encode two views in a single batched pass through the encoder.

        Returns ``(feat1, feat2, pos1, pos2, shape1, shape2)``.
        """
        shape1 = self._true_shape(view1)
        shape2 = self._true_shape(view2)

        out, pos, _ = self.dust3r._encode_image(
            torch.cat((view1['img'], view2['img']), dim=0),
            torch.cat((shape1, shape2), dim=0))
        out, out2 = out.chunk(2, dim=0)
        pos, pos2 = pos.chunk(2, dim=0)

        return out, out2, pos, pos2, shape1, shape2

    def encode_frames(self, view1, view2, feat1, feat2, pos1, pos2, shape1, shape2):
        """Encode the next frame pair, reusing the previous second frame.

        On the first call (``feat1 is None``) both views are encoded;
        afterwards the previous frame-2 encoding becomes frame 1 and only
        ``view2`` is encoded.
        """
        if feat1 is None:
            feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_image_pairs(view1, view2)
        else:
            feat1, pos1, shape1 = feat2, pos2, shape2
            feat2, pos2, shape2 = self.encode_image(view2)

        return feat1, feat2, pos1, pos2, shape1, shape2

    def encode_feat_key(self, feat1, feat2, num=1):
        """Concatenate both inputs on the channel dim and apply ``attn_head_{num}``."""
        feat = torch.cat((feat1, feat2), dim=-1)
        feat_k = getattr(self, f'attn_head_{num}')(feat)
        return feat_k

    def encode_value(self, x, pos):
        """Run ``x`` through the value encoder blocks, norm and output projection."""
        for block in self.value_encoder:
            x = block(x, pos)
        x = self.value_norm(x)
        x = self.value_out(x)
        return x

    def encode_cur_value(self, res1, dec1, pos1, shape1):
        """Encode the memory value of the current frame.

        Uses the last decoder feature when ``use_feat`` is set, otherwise the
        patch-embedded predicted pointmap ``res1['pts3d']``.
        """
        if self.use_feat:
            cur_v = self.encode_value(dec1[-1], pos1)
        else:
            out, pos_v = self.pos_patch_embed(res1['pts3d'].permute(0, 3, 1, 2), true_shape=shape1)
            cur_v = self.encode_value(out, pos_v)
        return cur_v

    def decode(self, feat1, pos1, feat2, pos2):
        """Run the DUSt3R decoder on a feature pair; returns ``(dec1, dec2)``."""
        dec1, dec2 = self.dust3r._decoder(feat1, pos1, feat2, pos2)
        return dec1, dec2

    def downstream_head(self, dec, true_shape, num=1):
        """Run DUSt3R head ``num`` on float32 decoder tokens with autocast disabled."""
        with torch.cuda.amp.autocast(enabled=False):
            res = self.dust3r._downstream_head(num, [tok.float() for tok in dec], true_shape)
        return res

    def find_initial_pair(self, graph, n_frames):
        """Pick the frame pair with the highest summed confidence score.

        ``graph`` provides ``view1``/``view2`` (with an ``idx`` entry per
        pair) and ``pred1``/``pred2`` (with a ``conf`` entry per pair).
        Returns the ``(row, col)`` index of the best entry of the
        ``n_frames x n_frames`` confidence matrix.
        """
        view1, view2, pred1, pred2 = graph['view1'], graph['view2'], graph['pred1'], graph['pred2']
        n_pairs = len(view1['idx'])

        conf_matrix = torch.zeros(n_frames, n_frames)
        for i in range(n_pairs):
            idx1, idx2 = view1['idx'][i], view2['idx'][i]
            conf_matrix[idx1, idx2] = self._pair_confidence(pred1['conf'][i], pred2['conf'][i])

        pair_idx = np.unravel_index(conf_matrix.argmax(), conf_matrix.shape)

        print(f'init pair:{pair_idx}, conf: {conf_matrix.max()}')

        return pair_idx

    def find_next_best_view(self, frames, idx_todo, feat_fuse, pos1, shape1):
        """Decode every remaining frame against ``feat_fuse`` and keep the best.

        Returns ``(best_id, dec1, dec2, res1, res2, feat2, pos2, shape2,
        best_conf)`` for the candidate with the highest confidence score.

        Raises:
            ValueError: if ``idx_todo`` is empty.
        """
        if not idx_todo:
            raise ValueError('find_next_best_view needs at least one candidate frame')

        best = None
        best_conf = 0.0

        for i in idx_todo:
            feat2, pos2, shape2 = self.encode_image(frames[i])

            dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)

            res1 = self.downstream_head(dec1, shape1, 1)
            res2 = self.downstream_head(dec2, shape2, 2)

            total_conf_mean = self._pair_confidence(res1['conf'], res2['conf'])

            # The first candidate is always kept so the result is defined
            # even when no score exceeds the initial 0.0.
            if best is None or total_conf_mean > best_conf:
                best_conf = total_conf_mean
                # NOTE(review): deepcopy presumably relies on running without
                # autograd (non-leaf tensors cannot be deep-copied).
                best = (i,
                        copy.deepcopy(dec1), copy.deepcopy(dec2),
                        copy.deepcopy(res1), copy.deepcopy(res2),
                        feat2, pos2, shape2)

        return (*best, best_conf)

    def offline_reconstruction(self, frames, graph):
        """Reconstruct an unordered frame set, greedily choosing the next view.

        Starts from the most confident pair in ``graph`` and then repeatedly
        adds the remaining frame that decodes with the highest confidence
        against the memory read-out.

        Returns ``(preds, preds_all, idx_used)`` where ``idx_used`` lists the
        frame indices in processing order.
        """
        n_frames = len(frames)
        idx_todo = list(range(n_frames))
        idx_used = []

        sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v,
                               mem_dropout=self.mem_dropout)

        pair_idx = self.find_initial_pair(graph, n_frames)
        f1, f2 = frames[pair_idx[0]], frames[pair_idx[1]]
        idx_used.append(pair_idx[0])
        idx_used.append(pair_idx[1])

        # remove those idxs from idx_todo
        idx_todo.remove(pair_idx[0])
        idx_todo.remove(pair_idx[1])

        ##### Encode frames
        feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_image_pairs(f1, f2)
        feat_fuse = feat1

        dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)

        ##### Regress pointmaps (downstream_head already disables autocast)
        res1 = self.downstream_head(dec1, shape1, 1)
        res2 = self.downstream_head(dec2, shape2, 2)

        ##### Encode feat key
        feat_k2 = None
        preds, preds_all = [], []

        while True:
            if feat_k2 is not None:
                # The previous second frame becomes the reference frame.
                feat1 = feat2
                pos1, shape1 = pos2, shape2
                feat_fuse = sp_mem.memory_read(feat_k2, res=True)

                (id_n, dec1, dec2, res1, res2,
                 feat2, pos2, shape2, best_conf) = self.find_next_best_view(
                    frames, idx_todo, feat_fuse, pos2, shape2)
                idx_todo.remove(id_n)
                idx_used.append(id_n)
                print(f'next best view: {id_n}, conf: {best_conf}')

            # encode feat
            feat_k1 = self.encode_feat_key(feat1, dec1[-1], 1)
            feat_k2 = self.encode_feat_key(feat2, dec2[-1], 2)

            ##### Memory update
            cur_v = self.encode_cur_value(res1, dec1, pos1, shape1)
            sp_mem.add_mem_check(feat_k1, cur_v+feat_k1)

            self._record_pair(preds, preds_all, res1, res2)

            if len(idx_todo) == 0:
                break

        preds.append(res2)

        return preds, preds_all, idx_used

    def forward(self, frames, return_memory=False):
        """Process an ordered frame sequence pair by pair.

        Args:
            frames: Sequence of at least two views (dicts with an ``img``
                entry and optionally ``true_shape``).
            return_memory: Also return the ``SpatialMemory`` that was built.

        Returns:
            ``(preds, preds_all)`` or ``(preds, preds_all, sp_mem)``.
            ``preds`` holds one prediction per frame, ``preds_all`` the
            ``(res1, res2)`` tuple of every consecutive pair.

        Raises:
            ValueError: if fewer than two frames are given.
        """
        if len(frames) < 2:
            raise ValueError(f'Spann3R.forward needs at least 2 frames, got {len(frames)}')

        if self.training:
            # Training: keep every frame, dropout on, no attention threshold.
            sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v,
                                   mem_dropout=self.mem_dropout, attn_thresh=0)
        else:
            sp_mem = SpatialMemory(self.norm_q, self.norm_k, self.norm_v)

        feat1, feat2, pos1, pos2, shape1, shape2 = None, None, None, None, None, None
        feat_k1, feat_k2 = None, None

        preds, preds_all = [], []

        for i in range(len(frames) - 1):
            view1 = frames[i]
            view2 = frames[i + 1]

            ##### Encode frames
            # feat1: [bs, p=196, c=1024]
            feat1, feat2, pos1, pos2, shape1, shape2 = self.encode_frames(
                view1, view2, feat1, feat2, pos1, pos2, shape1, shape2)

            ##### Memory readout
            if feat_k2 is not None:
                feat_fuse = sp_mem.memory_read(feat_k2, res=True)
            else:
                feat_fuse = feat1

            ##### Decode features
            # dec1[-1]: [bs, p, c=768]
            dec1, dec2 = self.decode(feat_fuse, pos1, feat2, pos2)

            ##### Encode feat key
            feat_k1 = self.encode_feat_key(feat1, dec1[-1], 1)
            feat_k2 = self.encode_feat_key(feat2, dec2[-1], 2)

            ##### Regress pointmaps (downstream_head already disables autocast)
            res1 = self.downstream_head(dec1, shape1, 1)
            res2 = self.downstream_head(dec2, shape2, 2)

            ##### Memory update
            cur_v = self.encode_cur_value(res1, dec1, pos1, shape1)

            if self.training:
                sp_mem.add_mem(feat_k1, cur_v+feat_k1)
            else:
                sp_mem.add_mem_check(feat_k1, cur_v+feat_k1)

            self._record_pair(preds, preds_all, res1, res2)

        preds.append(res2)

        if return_memory:
            return preds, preds_all, sp_mem

        return preds, preds_all