Spaces:
Runtime error
Runtime error
import torch | |
from einops import rearrange | |
# Merge the point maps of two views into one point cloud with per-point colors.
def merge_points(mast3r_output, view1, view2, grid_size=0.01):
    """Merge two views' point maps and colors into one flattened, normalized cloud.

    The point map ``mast3r_output[0]['pts3d']`` and the point map
    ``mast3r_output[1]['pts3d_in_other_view']`` are stacked view-major.
    Each point is paired with the pixel color of its view's ``'img'``.
    Each batch element is centered on its mean point and divided by the
    largest distance from that mean. The farthest point of a non-degenerate
    batch element therefore ends up at unit distance from the origin.

    Args:
        mast3r_output: Indexable pair of prediction dicts. Their point maps
            are laid out as ``(B, H, W, 3)``.
        view1: Dict whose ``'img'`` tensor is laid out as ``(B, 3, H, W)``.
        view2: Same as ``view1``, for the second view.
        grid_size: Returned unchanged under ``'grid_size'`` (presumably a
            voxel size for a downstream consumer -- confirm with caller).

    Returns:
        dict with (``N = 2 * H * W`` points per batch element):
            ``'coord'``: ``(B * N, 3)`` normalized coordinates.
            ``'color'``: ``(B * N, 3)`` per-point colors.
            ``'feat'``: ``(B * N, 6)`` coordinates concatenated with colors.
            ``'offset'``: ``(B,)`` cumulative end index of each batch element.
            ``'grid_size'``: the ``grid_size`` argument.
            ``'center'``: ``(B, 1, 3)`` per-batch mean that was subtracted.
            ``'scale'``: ``(B, 1, 1)`` per-batch divisor that was applied.
            ``'shape'``: shape of the first view's point map.
        Original coordinates are recovered with ``coord * scale + center``
        after reshaping ``coord`` back to ``(B, N, 3)``.
    """
    # Predicted point maps, detached from the autograd graph.
    points1 = mast3r_output[0]['pts3d'].detach()  # B, H, W, 3
    points2 = mast3r_output[1]['pts3d_in_other_view'].detach()  # B, H, W, 3
    shape = points1.shape

    # 'b v h w c -> b (v h w) c': flat index = v * H * W + h * W + w.
    points = torch.stack([points1, points2], dim=1).flatten(1, 3)  # B, V*H*W, 3

    # 'b v c h w -> b (v h w) c': same ordering as the points above.
    colors = torch.stack([view1['img'], view2['img']], dim=1)  # B, V, 3, H, W
    colors = colors.permute(0, 1, 3, 4, 2).flatten(1, 3)  # B, V*H*W, 3
    # torch.cat below needs both operands on the same device.
    colors = colors.to(points.device)

    batch_size, num_points, _ = points.shape
    # Every batch element holds exactly num_points points, so the cumulative
    # end offsets are simple multiples of num_points.
    offset = torch.arange(1, batch_size + 1, device=points.device) * num_points

    # Center and normalize each batch element independently.
    center = points.mean(dim=1, keepdim=True)  # B, 1, 3
    points = points - center
    scale = torch.linalg.vector_norm(points, dim=2, keepdim=True)  # B, N, 1
    scale = scale.amax(dim=1, keepdim=True)  # B, 1, 1
    # If all points of a batch element coincide, the max distance is 0 and
    # the division would produce 0/0 = NaN. Clamping keeps the result finite
    # (all zeros). The same clamped value is returned, so
    # coord * scale + center still reconstructs the input.
    scale = scale.clamp_min(torch.finfo(scale.dtype).eps)
    points = points / scale

    feat = torch.cat([points, colors], dim=-1)  # B, V*H*W, 6

    # 'b n c -> (b n) c' for the per-point tensors.
    return {
        'coord': points.flatten(0, 1),
        'color': colors.flatten(0, 1),
        'feat': feat.flatten(0, 1),
        'offset': offset,
        'grid_size': grid_size,
        'center': center,
        'scale': scale,
        'shape': shape,
    }