File size: 2,519 Bytes
b621857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-

import torch
from tqdm import tqdm
from einops import repeat
import numpy as np
from typing import Callable, Tuple, List, Union, Optional
from skimage import measure

from michelangelo.graphics.primitives import generate_dense_grid_points


@torch.no_grad()
def extract_geometry(geometric_func: Callable,
                     device: torch.device,
                     batch_size: int = 1,
                     bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
                     octree_depth: int = 7,
                     num_chunks: int = 10000,
                     disable: bool = True,
                     mc_level: float = 0.0):
    """Evaluate an implicit function on a dense grid and mesh its iso-surface.

    The bounding box is sampled with ``generate_dense_grid_points``. The samples
    are fed to ``geometric_func`` in chunks, and the resulting grid of values is
    triangulated for each batch entry with scikit-image's Lewiner marching cubes.

    Args:
        geometric_func: callable receiving a ``(batch_size, P, 3)`` float32
            tensor of query points on ``device``. Its outputs are concatenated
            along dim 1 and reshaped to ``(batch_size, *grid_size)``, so it is
            assumed to return ``(batch_size, P)`` or ``(batch_size, P, 1)``
            values -- TODO confirm against callers.
        device: device the query points are moved to before evaluation.
        batch_size: number of copies of each query chunk passed to
            ``geometric_func``; one mesh is extracted per batch entry.
        bounds: ``(xmin, ymin, zmin, xmax, ymax, zmax)``, or a single scalar
            ``b`` meaning the cube ``[-b, b]^3``.
        octree_depth: resolution parameter forwarded to
            ``generate_dense_grid_points``.
        num_chunks: number of query points per ``geometric_func`` call
            (a chunk size, despite the name).
        disable: if True, the tqdm progress bar is hidden.
        mc_level: iso-value at which the surface is extracted.

    Returns:
        ``(mesh_v_f, has_surface)``:

        - ``mesh_v_f`` is a list of length ``batch_size``. Each entry is
          ``(vertices, faces)``: float32 ``(V, 3)`` vertices in the
          coordinate frame of the query points and C-contiguous ``(F, 3)``
          face indices. The entry is ``(None, None)`` when marching cubes
          could not extract a surface.
        - ``has_surface`` is a boolean array of shape ``(batch_size,)``
          marking which entries produced a mesh.
    """
    # Any 0-d scalar (int, float, numpy scalar) describes a symmetric cube.
    if np.ndim(bounds) == 0:
        half_extent = float(bounds)
        bounds = [-half_extent, -half_extent, -half_extent,
                  half_extent, half_extent, half_extent]

    bbox_min = np.array(bounds[0:3])
    bbox_max = np.array(bounds[3:6])

    xyz_samples, grid_size, _ = generate_dense_grid_points(
        bbox_min=bbox_min,
        bbox_max=bbox_max,
        octree_depth=octree_depth,
        indexing="ij"
    )
    xyz_samples = torch.as_tensor(xyz_samples, dtype=torch.float32)

    # Index -> world mapping taken from the actual sample positions. A grid of
    # n points along an axis spans n - 1 intervals between its first and last
    # sample. Dividing by n instead would shrink every mesh towards the
    # minimum corner.
    grid_dims = np.asarray(grid_size, dtype=np.float64)
    grid_origin = xyz_samples.min(dim=0).values.double().numpy()
    grid_extent = xyz_samples.max(dim=0).values.double().numpy() - grid_origin
    voxel_size = grid_extent / np.maximum(grid_dims - 1, 1)

    batch_logits = []
    for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
                      desc="Implicit Function:", disable=disable, leave=False):
        queries = xyz_samples[start: start + num_chunks, :].to(device)
        batch_queries = repeat(queries, "p c -> b p c", b=batch_size)

        logits = geometric_func(batch_queries)
        batch_logits.append(logits.cpu())

    all_logits = torch.cat(batch_logits, dim=1)
    if all_logits.dtype in (torch.float16, torch.bfloat16):
        # numpy has no bfloat16; upcast reduced-precision outputs before .numpy().
        all_logits = all_logits.float()
    grid_logits = all_logits.view(
        (batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()

    mesh_v_f = []
    has_surface = np.zeros((batch_size,), dtype=np.bool_)
    for i in range(batch_size):
        try:
            vertices, faces, _, _ = measure.marching_cubes(
                grid_logits[i], mc_level, method="lewiner")
        except (ValueError, RuntimeError):
            # skimage raises when no surface can be extracted at `mc_level`;
            # report an empty result for this batch entry instead of failing.
            mesh_v_f.append((None, None))
            continue

        vertices = vertices * voxel_size + grid_origin
        mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
        has_surface[i] = True

    return mesh_v_f, has_surface