File size: 5,791 Bytes
414b431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import numpy as np
import torch
import threading
import mcubes
import trimesh
from utils.util_vis import show_att_on_image
from utils.camera import get_rotation_sphere

@torch.no_grad()
def get_dense_3D_grid(opt, var, N=None):
    """Build a dense cubic lattice of 3D query points, replicated per batch item.

    Args:
        opt: options object; reads ``opt.eval.vox_res``, ``opt.eval.range``
            (a ``(min, max)`` pair) and ``opt.device``.
        var: object whose ``idx`` attribute has one entry per batch item.
        N: number of cells per axis; when falsy, ``opt.eval.vox_res`` is used.

    Returns:
        Tensor of shape ``[B, N+1, N+1, N+1, 3]``. Each axis is sampled at
        ``N+1`` evenly spaced positions (both end points of the range are
        included), and the last dimension holds the three coordinates in
        ``'ij'`` meshgrid order.
    """
    batch_size = len(var.idx)
    cells_per_axis = N or opt.eval.vox_res
    lower, upper = opt.eval.range
    # N cells per axis means N+1 sample positions once both ends are included
    axis = torch.linspace(lower, upper, cells_per_axis + 1, device=opt.device)
    coords_a, coords_b, coords_c = torch.meshgrid(axis, axis, axis, indexing='ij')
    lattice = torch.stack((coords_a, coords_b, coords_c), dim=-1)  # [N+1, N+1, N+1, 3]
    # add the batch axis explicitly, then copy the lattice once per batch item
    return lattice.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

@torch.no_grad()
def compute_level_grid(opt, impl_network, latent_depth, latent_semantic, points_3D, images, vis_attn=False):
    """Evaluate the implicit network on a dense cubic grid, one slice at a time.

    Args:
        opt: options object; ``opt.H``, ``opt.W`` and ``opt.arch.win_size`` are
            only read when ``vis_attn`` is true.
        impl_network: callable invoked as
            ``impl_network(latent_depth, latent_semantic, points_slice)`` and
            returning ``(occ_slice, attn_slice)``. The attention output is only
            used when ``vis_attn`` is true.
        latent_depth: latent tensor or ``None``; cast to float32 when present.
        latent_semantic: latent tensor or ``None``; cast to float32 when present.
        points_3D: query points of shape ``[B, N, N, N, 3]``.
        images: batch of images indexed as ``images[b]`` and permuted from
            channel-first to channel-last; only read when ``vis_attn`` is true.
        vis_attn: when true, also render attention overlays.

    Returns:
        Tuple ``(occ, images_vis)``. ``occ`` is the sigmoid of the network
        output reshaped to ``[B, N, N, N]``. ``images_vis`` is ``None`` unless
        ``vis_attn`` is true, in which case it is a list (one entry per batch
        item) of lists of frames produced by ``show_att_on_image``.
    """
    # needed for amp
    if latent_depth is not None:
        latent_depth = latent_depth.to(torch.float32)
    if latent_semantic is not None:
        latent_semantic = latent_semantic.to(torch.float32)

    # process points in sliced way
    batch_size = points_3D.shape[0]
    N = points_3D.shape[1]
    assert N == points_3D.shape[2] == points_3D.shape[3]
    assert points_3D.shape[4] == 3

    points_3D = points_3D.view(batch_size, N, N*N, 3)
    occ = []
    attn = []
    for i in range(N):
        # [B, N*N, 3]
        points_slice = points_3D[:, i]
        # [B, N*N, 3] -> [B, N*N], [B, N*N, 1+feat_res**2]
        occ_slice, attn_slice = impl_network(latent_depth, latent_semantic, points_slice)
        occ.append(occ_slice)
        # only keep the attention maps when they will actually be rendered;
        # holding N of them otherwise is pure memory overhead
        if vis_attn:
            attn.append(attn_slice.detach())
    # [B, N, N*N] -> [B, N, N, N]
    occ = torch.stack(occ, dim=1).view(batch_size, N, N, N)
    occ = torch.sigmoid(occ)
    if not vis_attn:
        return occ, None

    N_global = 1
    feat_res = opt.H // opt.arch.win_size
    attn = torch.stack(attn, dim=1).view(batch_size, N, N, N, N_global+feat_res**2)
    # average along Z, [B, N, N, N_global+feat_res**2]
    attn = torch.mean(attn, dim=3)
    # [B, N, N, N_global] -> [B, N, N, 1]
    attn_global = attn[:, :, :, :N_global].sum(dim=-1, keepdim=True)
    # [B, N, N, feat_res, feat_res]
    attn_local = attn[:, :, :, N_global:].view(batch_size, N, N, feat_res, feat_res)
    # [B, N, N, feat_res, feat_res]
    attn_vis = attn_global.unsqueeze(-1) + attn_local

    # largest multiple of 8 that is still a valid index (< N); using N//8*8
    # would step out of bounds whenever N itself is a multiple of 8
    last_col = (N - 1) // 8 * 8
    # list of frame lists
    images_vis = []
    for b in range(batch_size):
        images_vis_sample = []
        for row in range(0, N, 8):
            # serpentine scan: alternate the column direction on every visited row
            if row % 16 == 0:
                col_range = range(0, last_col + 1, 8)
            else:
                col_range = range(last_col, -1, -8)
            for col in col_range:
                # [feat_res, feat_res], x is col
                attn_curr = attn_vis[b, col, row]
                attn_curr = torch.nn.functional.interpolate(
                    attn_curr.unsqueeze(0).unsqueeze(0), size=(opt.H, opt.W),
                    mode='bilinear', align_corners=False
                ).squeeze(0).squeeze(0).cpu().numpy()
                # NOTE(review): an all-zero attention map would divide by zero here
                attn_curr /= attn_curr.max()
                # [H, W, 3] after moving channels last
                image_curr = images[b].permute(1, 2, 0).cpu().numpy()
                # merge the image and the attention
                images_vis_sample.append(show_att_on_image(image_curr, attn_curr))
        images_vis.append(images_vis_sample)
    return occ, images_vis

@torch.no_grad()
def standardize_pc(pc):
    """Center a batch of point clouds and rescale by twice their RMS radius.

    Args:
        pc: tensor with three dimensions; the mean is taken over dim 1 and the
            per-point squared norm over dim 2 (i.e. points along dim 1,
            coordinates along dim 2).

    Returns:
        Tensor of the same shape: each cloud has its mean subtracted and is
        divided by ``2 * sqrt(mean squared distance to the mean)``. A
        degenerate cloud whose points all coincide is returned as zeros
        rather than NaN.
    """
    assert len(pc.shape) == 3
    pc_mean = pc.mean(dim=1, keepdim=True)
    pc_zmean = pc - pc_mean
    # distance of every point to the cloud's mean, [B, P, 1]
    origin_distance = (pc_zmean**2).sum(dim=2, keepdim=True).sqrt()
    # root-mean-square distance per cloud, [B, 1, 1]
    scale = torch.sqrt(torch.sum(origin_distance**2, dim=1, keepdim=True) / pc.shape[1])
    # a zero scale (all points identical) would turn 0/0 into NaN; clamping to
    # the smallest positive normal float leaves every other input unchanged
    scale = scale.clamp_min(torch.finfo(scale.dtype).tiny)
    pc_standardized = pc_zmean / (scale * 2)
    return pc_standardized

@torch.no_grad()
def normalize_pc(pc):
    """Center a batch of point clouds and divide by their larger x/y extent.

    Args:
        pc: tensor with three dimensions; points run along dim 1 and the
            first two entries of dim 2 are used to measure the extent.

    Returns:
        Tensor of the same shape: each cloud has its mean subtracted and is
        divided by ``max(extent_x, extent_y) + 1e-7``. The third coordinate is
        rescaled too, but its own extent does not influence the scale.
    """
    assert pc.dim() == 3
    centered = pc - pc.mean(dim=1, keepdim=True)
    # per-cloud extent along the first two coordinate axes, [B, 2]
    planar = centered[:, :, :2]
    extents = planar.max(dim=1).values - planar.min(dim=1).values
    # larger of the two extents, shaped [B, 1, 1] for broadcasting
    largest = extents.max(dim=-1).values.view(-1, 1, 1)
    return centered / (largest + 1.e-7)

def convert_to_explicit(opt, level_grids, isoval=0., to_pointcloud=False):
    """Convert a batch of level grids to meshes, using one thread per grid.

    Args:
        opt: options object forwarded unchanged to each worker.
        level_grids: indexable collection of level grids; ``len`` gives the
            number of workers started.
        isoval: iso-value forwarded to each worker.
        to_pointcloud: when true, workers also fill a point cloud per grid.

    Returns:
        The list of meshes when ``to_pointcloud`` is false; otherwise a tuple
        ``(meshes, pointclouds)`` where the per-grid point clouds are stacked
        along a new leading axis into one NumPy array.
    """
    num_grids = len(level_grids)
    # shared output slots; worker ``k`` receives index ``k`` and writes there
    meshes = [None] * num_grids
    pointclouds = [None] * num_grids if to_pointcloud else None

    workers = []
    for k in range(num_grids):
        workers.append(threading.Thread(
            target=convert_to_explicit_worker,
            args=(opt, k, level_grids[k], isoval, meshes),
            kwargs={"pointclouds": pointclouds},
            daemon=False,
        ))
    for worker in workers:
        worker.start()
    # wait for every conversion to finish before touching the results
    for worker in workers:
        worker.join()

    if not to_pointcloud:
        return meshes
    return meshes, np.stack(pointclouds, axis=0)

def convert_to_explicit_worker(opt, i, level_vox_i, isoval, meshes, pointclouds=None):
    """Run marching cubes on one level grid and store the result in slot ``i``.

    Args:
        opt: options object; reads ``opt.eval.range`` and, when point clouds
            are requested, ``opt.eval.num_points``.
        i: index of the output slot to fill.
        level_vox_i: cubic level grid (its first three dimensions must match).
        isoval: iso-value passed to marching cubes.
        meshes: list receiving the resulting ``trimesh.Trimesh`` at index ``i``.
        pointclouds: optional list receiving sampled points at index ``i``.

    Returns:
        None; results are written into ``meshes`` and ``pointclouds``.
    """
    # use marching cubes to convert implicit surface to mesh
    verts, tris = mcubes.marching_cubes(level_vox_i, isovalue=isoval)
    assert(level_vox_i.shape[0]==level_vox_i.shape[1]==level_vox_i.shape[2])
    side = level_vox_i.shape[0]
    lower, upper = opt.eval.range
    extent = upper - lower
    # marching cubes treat every cube as unit length, so map grid units to
    # world coordinates.
    # NOTE(review): this divides by the number of samples per axis; if the grid
    # was sampled at both end points of the range, the divisor may need to be
    # side - 1 -- confirm against how the grid is built.
    verts = verts / side * extent + lower
    surface = trimesh.Trimesh(verts, tris)
    meshes[i] = surface
    if pointclouds is None:
        return
    # randomly sample points on the mesh surface; an empty mesh yields zeros
    if len(surface.triangles) == 0:
        sampled = np.zeros([opt.eval.num_points, 3])
    else:
        sampled = surface.sample(opt.eval.num_points)
    pointclouds[i] = sampled