Spaces:
Sleeping
Sleeping
import time | |
import torch | |
import trimesh | |
import numpy as np | |
import torch.optim as optim | |
from torch import autograd | |
from torch.utils.data import TensorDataset, DataLoader | |
from .common import make_3d_grid | |
from .utils import libmcubes | |
from .utils.libmise import MISE | |
from .utils.libsimplify import simplify_mesh | |
from .common import transform_pointcloud | |
class Generator3D(object):
    ''' Generator class for DVRs.

    It provides functions to generate the final mesh as well as refinement
    options.

    Args:
        model (nn.Module): trained DVR model
        points_batch_size (int): batch size for points evaluation
        threshold (float): occupancy probability threshold in (0, 1); it is
            converted to a logit for marching cubes / MISE
        refinement_step (int): number of refinement steps
        device (device): pytorch device
        resolution0 (int): start resolution for MISE
        upsampling_steps (int): number of upsampling steps; 0 disables MISE
            and evaluates a dense grid with resolution0 points per axis
        with_normals (bool): whether normals should be estimated
        padding (float): how much padding should be used for MISE
        simplify_nfaces (int): number of faces the mesh should be simplified to
        with_color (bool): whether vertex colors should be estimated
        refine_max_faces (int): max number of faces which are used as batch
            size for refinement process (we added this functionality in this
            work)
    '''

    def __init__(
        self,
        model,
        points_batch_size=100000,
        threshold=0.5,
        refinement_step=0,
        device=None,
        resolution0=16,
        upsampling_steps=3,
        with_normals=False,
        padding=0.1,
        simplify_nfaces=None,
        with_color=False,
        refine_max_faces=10000
    ):
        self.model = model.to(device)
        self.points_batch_size = points_batch_size
        self.refinement_step = refinement_step
        self.threshold = threshold
        self.device = device
        self.resolution0 = resolution0
        self.upsampling_steps = upsampling_steps
        self.with_normals = with_normals
        self.padding = padding
        self.simplify_nfaces = simplify_nfaces
        self.with_color = with_color
        self.refine_max_faces = refine_max_faces

    def generate_mesh(self, data, return_stats=True):
        ''' Generates the output mesh.

        Args:
            data (dict): data dictionary; ``data['inputs']`` (if present) is
                passed to ``model.encode_inputs``
            return_stats (bool): currently unused; the stats dictionary is
                always returned

        Returns:
            tuple: (mesh, stats_dict)
        '''
        self.model.eval()
        stats_dict = {}
        inputs = data.get('inputs', torch.empty(1, 0)).to(self.device)
        c = self.model.encode_inputs(inputs)
        mesh = self.generate_from_latent(c, stats_dict=stats_dict, data=data)
        return mesh, stats_dict

    def generate_meshes(self, data, return_stats=True):
        ''' Generates the output meshes with data of batch size >= 1.

        Each batch element is encoded and meshed independently.

        Args:
            data (dict): data dictionary; ``data['inputs']`` is indexed along
                its first dimension
            return_stats (bool): currently unused; no stats are returned

        Returns:
            list: one mesh per batch element
        '''
        self.model.eval()
        # Shared by all meshes, so each timing entry holds the last mesh's value
        stats_dict = {}
        inputs = data.get('inputs', torch.empty(1, 1, 0)).to(self.device)
        meshes = []
        for input_i in inputs:
            c = self.model.encode_inputs(input_i.unsqueeze(0))
            meshes.append(self.generate_from_latent(c, stats_dict=stats_dict))
        return meshes

    def generate_pointcloud(self, mesh, data=None, n_points=2000000, scale_back=True):
        ''' Generates a point cloud from the mesh.

        Args:
            mesh (trimesh): mesh
            data (dict): data dictionary (may be None)
            n_points (int): number of point cloud points
            scale_back (bool): whether to undo scaling (requires a scale
                matrix under 'camera.scale_mat_0' in the data dictionary;
                a warning is printed if it is missing)

        Returns:
            trimesh.Trimesh: vertex-only mesh holding the sampled points
        '''
        pcl = mesh.sample(n_points).astype(np.float32)
        if scale_back:
            # data defaults to None, so guard before looking up the matrix
            scale_mat = None if data is None else data.get('camera.scale_mat_0', None)
            if scale_mat is not None:
                pcl = transform_pointcloud(pcl, scale_mat[0])
            else:
                print('Warning: No scale_mat found!')
        return trimesh.Trimesh(vertices=pcl, process=False)

    def generate_from_latent(self, c=None, pl=None, stats_dict=None, data=None, **kwargs):
        ''' Generates mesh from latent.

        Args:
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters
            stats_dict (dict): stats dictionary that timings are written to;
                a fresh dict is used if None
            data (dict): currently unused
            **kwargs: forwarded to ``eval_points`` / ``model.decode``

        Returns:
            trimesh.Trimesh: extracted mesh
        '''
        if stats_dict is None:
            stats_dict = {}
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)
        t0 = time.time()
        # Half-extent of the cube [-box_size, box_size]^3 that is meshed
        box_size = 1 + self.padding
        if self.upsampling_steps == 0:
            # Shortcut: evaluate a dense grid directly. extract_mesh maps grid
            # indices 0..nx-1 onto [-box_size, box_size] (as does the MISE
            # branch below), so the grid must span that same range.
            # NOTE(review): assumes make_3d_grid samples nx evenly spaced
            # points per axis between bb_min and bb_max inclusive.
            nx = self.resolution0
            pointsf = box_size * make_3d_grid((-1.0, ) * 3, (1.0, ) * 3, (nx, ) * 3)
            values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
            value_grid = values.reshape(nx, nx, nx)
        else:
            mesh_extractor = MISE(self.resolution0, self.upsampling_steps, threshold)
            points = mesh_extractor.query()
            while points.shape[0] != 0:
                pointsf = torch.FloatTensor(points).to(self.device)
                # Map grid coordinates [0, resolution] to [-box_size, box_size]
                pointsf = 2 * pointsf / mesh_extractor.resolution
                pointsf = box_size * (pointsf - 1.0)
                values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
                mesh_extractor.update(points, values.astype(np.float64))
                points = mesh_extractor.query()
            value_grid = mesh_extractor.to_dense()
        stats_dict['time (eval points)'] = time.time() - t0
        return self.extract_mesh(value_grid, c, stats_dict=stats_dict)

    def eval_points(self, p, c=None, pl=None, **kwargs):
        ''' Evaluates the occupancy values (logits) for the points.

        Args:
            p (tensor): points, split into chunks of ``points_batch_size``
                along the first dimension
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters

        Returns:
            tensor: concatenated logits on the CPU
        '''
        occ_hats = []
        with torch.no_grad():
            for pi in torch.split(p, self.points_batch_size):
                pi = pi.unsqueeze(0).to(self.device)
                occ_hat = self.model.decode(pi, c, pl, **kwargs).logits
                occ_hats.append(occ_hat.squeeze(0).detach().cpu())
        return torch.cat(occ_hats, dim=0)

    def extract_mesh(self, occ_hat, c=None, stats_dict=None):
        ''' Extracts the mesh from the predicted occupancy grid.

        Vertices are mapped from grid indices onto [-box_size, box_size]^3
        with box_size = 1 + padding.

        Args:
            occ_hat (numpy array): 3D grid of occupancy logits
            c (tensor): latent conditioned code c
            stats_dict (dict): stats dictionary that timings are written to;
                a fresh dict is used if None

        Returns:
            trimesh.Trimesh: the (optionally simplified/refined/colored) mesh
        '''
        if stats_dict is None:
            stats_dict = {}
        n_x, n_y, n_z = occ_hat.shape
        box_size = 1 + self.padding
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)
        # Pad with a very negative logit so surfaces touching the grid
        # border are closed (watertight mesh)
        t0 = time.time()
        occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6)
        vertices, triangles = libmcubes.marching_cubes(occ_hat_padded, threshold)
        stats_dict['time (marching cubes)'] = time.time() - t0
        # Strange behaviour in libmcubes: vertices are shifted by 0.5
        vertices -= 0.5
        # Undo padding
        vertices -= 1
        # Normalize grid indices [0, n - 1] to [-box_size, box_size]
        vertices /= np.array([n_x - 1, n_y - 1, n_z - 1])
        vertices *= 2
        vertices = box_size * (vertices - 1)

        if self.with_normals and vertices.shape[0] != 0:
            t0 = time.time()
            normals = self.estimate_normals(vertices, c)
            stats_dict['time (normals)'] = time.time() - t0
        else:
            normals = None

        mesh = trimesh.Trimesh(
            vertices,
            triangles,
            vertex_normals=normals,
            process=False
        )
        # Directly return if mesh is empty
        if vertices.shape[0] == 0:
            return mesh

        # TODO: normals are lost here
        if self.simplify_nfaces is not None:
            t0 = time.time()
            mesh = simplify_mesh(mesh, self.simplify_nfaces, 5.)
            stats_dict['time (simplify)'] = time.time() - t0

        if self.refinement_step > 0:
            t0 = time.time()
            self.refine_mesh(mesh, occ_hat, c)
            stats_dict['time (refine)'] = time.time() - t0

        # The mesh is known to be non-empty here (early return above)
        if self.with_color:
            t0 = time.time()
            vertex_colors = self.estimate_colors(np.array(mesh.vertices), c)
            stats_dict['time (color)'] = time.time() - t0
            mesh = trimesh.Trimesh(
                vertices=mesh.vertices,
                faces=mesh.faces,
                vertex_normals=mesh.vertex_normals,
                vertex_colors=vertex_colors,
                process=False
            )
        return mesh

    def estimate_colors(self, vertices, c=None):
        ''' Estimates vertex colors by evaluating the texture field.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c

        Returns:
            numpy array (uint8): the decoded colors clipped to [0, 1] and
                scaled to 0..255, with an extra opaque (255) last column
        '''
        vertices = torch.as_tensor(np.asarray(vertices), dtype=torch.float32)
        colors = []
        with torch.no_grad():
            for vi in torch.split(vertices, self.points_batch_size):
                ci = self.model.decode_color(vi.to(self.device).unsqueeze(0), c)
                colors.append(ci.squeeze(0).cpu())
        colors = torch.cat(colors, dim=0).numpy()
        colors = (np.clip(colors, 0, 1) * 255).astype(np.uint8)
        alpha = np.full((colors.shape[0], 1), 255, dtype=np.uint8)
        return np.concatenate([colors, alpha], axis=1)

    def estimate_normals(self, vertices, c=None):
        ''' Estimates the normals by computing the gradient of the objective.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c

        Returns:
            numpy array: one normalized normal per vertex, pointing along the
                negative gradient of the occupancy logits
        '''
        vertices = torch.as_tensor(np.asarray(vertices), dtype=torch.float32)
        if c is not None:
            # Detach c (as refine_mesh does) so per-batch gradient computation
            # never needs the encoder graph to be retained.
            # NOTE(review): eval_points passes c unchanged, while here an extra
            # leading dim is added — confirm decode() expects this shape.
            c = c.detach().unsqueeze(0)
        normals = []
        for vi in torch.split(vertices, self.points_batch_size):
            vi = vi.unsqueeze(0).to(self.device)
            vi.requires_grad_()
            out = self.model.decode(vi, c).logits.sum()
            # autograd.grad returns d(out)/d(vi) without accumulating .grad
            # into the model parameters (unlike out.backward()).
            ni = -autograd.grad(out, vi)[0]
            ni = ni / torch.norm(ni, dim=-1, keepdim=True)
            normals.append(ni.squeeze(0).cpu().numpy())
        return np.concatenate(normals, axis=0)

    @staticmethod
    def _face_normals(face_vertex):
        ''' Returns unit normals (v1 - v0) x (v2 - v1) for triangles.

        Args:
            face_vertex (tensor): (F, 3, 3) triangle corner positions

        Returns:
            tensor: (F, 3) normalized face normals
        '''
        face_v1 = face_vertex[:, 1, :] - face_vertex[:, 0, :]
        face_v2 = face_vertex[:, 2, :] - face_vertex[:, 1, :]
        # dim must be explicit: without it torch.cross picks the first
        # dimension of size 3, which is the face dimension when F == 3.
        face_normal = torch.cross(face_v1, face_v2, dim=1)
        return face_normal / (face_normal.norm(dim=1, keepdim=True) + 1e-10)

    def refine_mesh(self, mesh, occ_hat, c=None):
        ''' Refines the predicted mesh in place by optimizing vertex positions.

        Random points on randomly batched faces are pulled towards the
        ``threshold`` iso-surface of the predicted occupancy probability,
        and face normals are pulled towards the negative occupancy gradient.

        Args:
            mesh (trimesh object): predicted mesh (vertices are overwritten)
            occ_hat (numpy array): predicted occupancy grid; only its shape is
                used, and it must be cubic
            c (tensor): latent conditioned code c

        Returns:
            the same mesh object
        '''
        self.model.eval()
        n_x, n_y, n_z = occ_hat.shape
        assert (n_x == n_y == n_z)
        # Compared against sigmoid(logits), so the probability threshold is used
        threshold = self.threshold

        v0 = torch.as_tensor(np.asarray(mesh.vertices), dtype=torch.float32).to(self.device)
        v = torch.nn.Parameter(v0.clone())
        faces = torch.as_tensor(np.asarray(mesh.faces), dtype=torch.long)
        # Without faces the dataloader yields nothing and the loop below
        # would never advance it_r.
        if faces.shape[0] == 0:
            return mesh

        # detach c; otherwise graph needs to be retained
        # caused by new Pytorch version?
        if c is not None:
            c = c.detach()

        optimizer = optim.RMSprop([v], lr=1e-5)
        # We updated the refinement algorithm to subsample faces; this is
        # useful when using a high extraction resolution / when working on
        # small GPUs
        dataloader = DataLoader(
            TensorDataset(faces), batch_size=self.refine_max_faces, shuffle=True)

        it_r = 0
        while it_r < self.refinement_step:
            for (f_it,) in dataloader:
                f_it = f_it.to(self.device)
                optimizer.zero_grad()

                face_vertex = v[f_it]
                # Random barycentric sample point on each face
                eps = np.random.dirichlet((0.5, 0.5, 0.5), size=f_it.shape[0])
                eps = torch.as_tensor(eps, dtype=torch.float32, device=self.device)
                face_point = (face_vertex * eps[:, :, None]).sum(dim=1)
                face_normal = self._face_normals(face_vertex)

                face_value = torch.cat(
                    [
                        torch.sigmoid(self.model.decode(p_split, c).logits)
                        for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1)
                    ],
                    dim=1
                )
                normal_target = -autograd.grad(
                    [face_value.sum()], [face_point], create_graph=True)[0]
                normal_target = normal_target / \
                    (normal_target.norm(dim=1, keepdim=True) + 1e-10)

                loss_target = (face_value - threshold).pow(2).mean()
                loss_normal = (face_normal - normal_target).pow(2).sum(dim=1).mean()
                loss = loss_target + 0.01 * loss_normal

                loss.backward()
                optimizer.step()

                it_r += 1
                if it_r >= self.refinement_step:
                    break

        mesh.vertices = v.detach().cpu().numpy()
        return mesh