File size: 4,799 Bytes
2ac1c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from typing import Union, Tuple, List

import numpy as np
import torch
from skimage import measure


class MeshExtractResult:
    """Container for a triangle mesh produced by a surface extractor.

    Face and vertex normals are computed eagerly at construction time.

    Args:
        verts: Vertex positions, gathered as ``verts[idx, :]``; presumably an
            (N, 3) float tensor -- TODO confirm with callers.
        faces: Triangle corner indices (last dim holds the three corners);
            converted to int64.
        vertex_attrs: Optional per-vertex attributes, stored unchanged.
        res: Resolution the mesh was extracted at, stored unchanged.
    """

    def __init__(self, verts, faces, vertex_attrs=None, res=64):
        self.verts = verts
        self.faces = faces.long()
        self.vertex_attrs = vertex_attrs
        # Unit normal per face, repeated for each of its 3 corners: (F, 3, 3).
        self.face_normal = self.comput_face_normals()
        # Unit normal per vertex (area-weighted sum of incident face normals).
        self.vert_normal = self.comput_v_normals()
        self.res = res
        # False when extraction produced no vertices or no faces.
        self.success = verts.shape[0] != 0 and faces.shape[0] != 0

        # training only
        self.tsdf_v = None
        self.tsdf_s = None
        self.reg_loss = None

    def _face_cross(self):
        """Return the three corner index tensors and the unnormalized face normals.

        The cross product ``(v1 - v0) x (v2 - v0)`` has a length of twice the
        triangle area, so summing it per vertex gives area-weighted normals.
        """
        i0 = self.faces[..., 0].long()
        i1 = self.faces[..., 1].long()
        i2 = self.faces[..., 2].long()

        v0 = self.verts[i0, :]
        v1 = self.verts[i1, :]
        v2 = self.verts[i2, :]
        return i0, i1, i2, torch.cross(v1 - v0, v2 - v0, dim=-1)

    def comput_face_normals(self):
        """Return unit face normals repeated once per triangle corner, shape (F, 3, 3)."""
        _, _, _, face_normals = self._face_cross()
        face_normals = torch.nn.functional.normalize(face_normals, dim=-1)
        # repeat (not expand) so callers get a contiguous, writable tensor.
        return face_normals[:, None, :].repeat(1, 3, 1)

    def comput_v_normals(self):
        """Return unit vertex normals, one row per vertex in ``self.verts``.

        Vertices not referenced by any face get a zero vector, because
        ``normalize`` clamps the norm with an epsilon instead of dividing by zero.
        """
        i0, i1, i2, face_normals = self._face_cross()
        v_normals = torch.zeros_like(self.verts)
        # Accumulate each face's (area-weighted) normal into all three corners.
        for corner in (i0, i1, i2):
            v_normals.index_add_(0, corner, face_normals)
        return torch.nn.functional.normalize(v_normals, dim=-1)


def center_vertices(vertices):
    """Shift ``vertices`` so their axis-aligned bounding box is centred on the origin.

    Min/max are reduced over dim 0 (one row per vertex); a new tensor is returned.
    """
    lowest = vertices.min(dim=0).values
    highest = vertices.max(dim=0).values
    midpoint = 0.5 * (lowest + highest)
    return vertices - midpoint


class SurfaceExtractor:
    """Base class that turns a batch of dense logit grids into meshes.

    Subclasses implement :meth:`run` for a single grid and return a
    ``(verts, faces)`` pair of tensors.
    """

    def _compute_box_stat(
        self, bounds: Union[Tuple[float, ...], List[float], float], octree_resolution: int
    ):
        """Return ``(grid_size, bbox_min, bbox_size)`` for the sampling grid.

        Args:
            bounds: Either a scalar ``b`` (meaning the cube ``[-b, b]^3``) or six
                values ``(xmin, ymin, zmin, xmax, ymax, zmax)``.
            octree_resolution: Cells per axis; the grid holds one more sample
                than cells along each axis.

        Returns:
            ``grid_size`` as a list of three ints, plus ``bbox_min`` and
            ``bbox_size`` as length-3 numpy arrays.

        Raises:
            ValueError: If a sequence ``bounds`` does not have exactly 6 entries.
        """
        # Accept int and numpy scalars too; previously only a Python float was
        # expanded and e.g. ``bounds=1`` failed when sliced below.
        if isinstance(bounds, (int, float, np.integer, np.floating)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
        if len(bounds) != 6:
            raise ValueError(
                f"bounds must be a scalar or 6 values, got {len(bounds)} values"
            )

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        grid_size = [int(octree_resolution) + 1] * 3
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """Extract ``(verts, faces)`` from a single grid; implemented by subclasses."""
        # Must *raise*: returning the exception class silently produced a
        # confusing unpacking error in ``__call__`` instead.
        raise NotImplementedError

    def __call__(self, grid_logits, **kwargs):
        """Run :meth:`run` on each grid along dim 0 of ``grid_logits``.

        Extraction is best-effort per sample. Any exception is printed with its
        traceback and ``None`` is stored for that sample, so one bad grid does
        not abort the whole batch. ``kwargs`` are forwarded to :meth:`run` and
        must include ``octree_resolution``.

        Returns:
            A list with one ``MeshExtractResult`` (or ``None``) per grid.
        """
        outputs = []
        for i in range(grid_logits.shape[0]):
            try:
                verts, faces = self.run(grid_logits[i], **kwargs)
                outputs.append(
                    MeshExtractResult(
                        verts=verts.float(),
                        faces=faces,
                        res=kwargs["octree_resolution"],
                    )
                )

            except Exception:
                import traceback

                traceback.print_exc()
                outputs.append(None)

        return outputs


class MCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor backed by scikit-image's Lewiner marching cubes (CPU)."""

    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """Extract the ``mc_level`` iso-surface of one logit grid.

        Args:
            grid_logit: Dense 3-D logit grid for one sample (a torch tensor).
            mc_level: Iso-value passed to ``marching_cubes``.
            bounds: Scalar or 6-value bounding box; see ``_compute_box_stat``.
            octree_resolution: Grid resolution used to rescale vertices.

        Returns:
            ``(verts, faces)`` on ``grid_logit``'s device: float32 vertex
            positions in world space and int64 triangles with reversed winding.
        """
        target_device = grid_logit.device
        volume = grid_logit.float().cpu().numpy()
        positions, triangles, _normals, _values = measure.marching_cubes(
            volume, mc_level, method="lewiner"
        )

        grid_size, bbox_min, bbox_size = self._compute_box_stat(
            bounds, octree_resolution
        )
        # marching_cubes (default spacing) yields voxel-index coordinates.
        # NOTE(review): dividing by grid_size (= res + 1) rather than res puts the
        # last sample slightly inside bbox_max; confirm this matches how the
        # logit grid was sampled.
        positions = positions / grid_size * bbox_size + bbox_min

        verts = torch.tensor(positions, device=target_device, dtype=torch.float32)
        faces = torch.tensor(
            np.ascontiguousarray(triangles), device=target_device, dtype=torch.long
        )
        # Reverse the corner order of every triangle (flips its winding).
        return verts, faces[:, [2, 1, 0]]


class DMCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor backed by ``diso.DiffDMC``.

    ``diso`` is imported lazily on first use, so it remains an optional
    dependency for callers that only use the marching-cubes extractor.
    """

    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """Extract a mesh from one logit grid with DiffDMC.

        Args:
            grid_logit: Dense logit grid for one sample (a torch tensor).
            octree_resolution: Grid resolution; the negated logits are divided
                by it to form the SDF passed to DiffDMC.
            **kwargs: Must contain ``bounds``, a scalar ``b``; output vertices
                are mapped onto ``[-b, b]``.

        Returns:
            ``(verts, faces)`` as produced by DiffDMC, with ``verts`` rescaled.

        Raises:
            KeyError: If ``bounds`` is not provided.
            TypeError: If ``bounds`` is not a scalar.
            ImportError: If ``diso`` is not installed.
        """
        bounds = kwargs["bounds"]
        # The rescaling below is only defined for a symmetric scalar bound; the
        # previous unused _compute_box_stat call also rejected integer scalars.
        if not isinstance(bounds, (int, float, np.integer, np.floating)):
            raise TypeError(
                f"DMC extraction supports only scalar bounds, got {type(bounds).__name__}"
            )

        device = grid_logit.device
        if not hasattr(self, "dmc"):
            try:
                from diso import DiffDMC
            except ImportError as err:
                raise ImportError(
                    "Please install diso via `pip install diso`, or set mc_algo to 'mc'"
                ) from err
            # NOTE(review): the module is created on the first call's device only;
            # confirm later grids always live on that same device.
            self.dmc = DiffDMC(dtype=torch.float32).to(device)

        sdf = -grid_logit / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
        # Map normalized coordinates (assumed to lie in [0, 1]) onto [-b, b].
        verts = verts * bounds * 2 - bounds
        return verts, faces