kyleleey
first commit
98a77e0
raw
history blame
8.7 kB
# Cages code used from https://github.com/yifita/deep_cage
import torch
import numpy as np
import trimesh
def deform_with_MVC(cage, cage_deformed, cage_face, query, verbose=False):
"""
cage (B,C,3)
cage_deformed (B,C,3)
cage_face (B,F,3) int64
query (B,Q,3)
"""
weights, weights_unnormed = mean_value_coordinates_3D(query, cage, cage_face, verbose=True)
# weights = weights.detach()
deformed = torch.sum(weights.unsqueeze(-1)*cage_deformed.unsqueeze(1), dim=2)
if verbose:
return deformed, weights, weights_unnormed
return deformed
def loadInitCage(template):
init_cage_V, init_cage_F = read_trimesh(template)
init_cage_V = torch.from_numpy(init_cage_V[:,:3].astype(np.float32)).unsqueeze(0)*2.0
init_cage_F = torch.from_numpy(init_cage_F[:,:3].astype(np.int64)).unsqueeze(0)
return init_cage_V, init_cage_F
def read_trimesh(path):
mesh = trimesh.load(path)
return mesh.vertices, mesh.faces
# util functions from pytorch_points
PI = 3.1415927
def normalize_to_box(input):
"""
normalize point cloud to unit bounding box
center = (max - min)/2
scale = max(abs(x))
input: pc [N, P, dim] or [P, dim]
output: pc, centroid, furthest_distance
"""
if len(input.shape) == 2:
axis = 0
P = input.shape[0]
D = input.shape[1]
elif len(input.shape) == 3:
axis = 1
P = input.shape[1]
D = input.shape[2]
if isinstance(input, np.ndarray):
maxP = np.amax(input, axis=axis, keepdims=True)
minP = np.amin(input, axis=axis, keepdims=True)
centroid = (maxP+minP)/2
input = input - centroid
furthest_distance = np.amax(np.abs(input), axis=(axis, -1), keepdims=True)
input = input / furthest_distance
elif isinstance(input, torch.Tensor):
maxP = torch.max(input, dim=axis, keepdim=True)[0]
minP = torch.min(input, dim=axis, keepdim=True)[0]
centroid = (maxP+minP)/2
input = input - centroid
in_shape = list(input.shape[:axis])+[P*D]
furthest_distance = torch.max(torch.abs(input).view(in_shape), dim=axis, keepdim=True)[0]
furthest_distance = furthest_distance.unsqueeze(-1)
input = input / furthest_distance
return input, centroid, furthest_distance
def normalize(tensor, dim=-1):
"""normalize tensor in specified dimension"""
return torch.nn.functional.normalize(tensor, p=2, dim=dim, eps=1e-12, out=None)
def check_values(tensor):
"""return true if tensor doesn't contain NaN or Inf"""
return not (torch.any(torch.isnan(tensor)).item() or torch.any(torch.isinf(tensor)).item())
class ScatterAdd(torch.autograd.Function):
@staticmethod
def forward(ctx, src, idx, dim, out_size, fill=0.0):
out = torch.full(out_size, fill, device=src.device, dtype=src.dtype)
ctx.save_for_backward(idx)
out.scatter_add_(dim, idx, src)
ctx.mark_non_differentiable(idx)
ctx.dim = dim
return out
@staticmethod
def backward(ctx, ograd):
idx, = ctx.saved_tensors
grad = torch.gather(ograd, ctx.dim, idx)
return grad, None, None, None, None
_scatter_add = ScatterAdd.apply
def scatter_add(src, idx, dim, out_size=None, fill=0.0):
if out_size is None:
out_size = list(src.size())
dim_size = idx.max().item()+1
out_size[dim] = dim_size
return _scatter_add(src, idx, dim, out_size, fill)
def mean_value_coordinates_3D(query, vertices, faces, verbose=False):
"""
Tao Ju et.al. MVC for 3D triangle meshes
params:
query (B,P,3)
vertices (B,N,3)
faces (B,F,3)
return:
wj (B,P,N)
"""
B, F, _ = faces.shape
_, P, _ = query.shape
_, N, _ = vertices.shape
# u_i = p_i - x (B,P,N,3)
uj = vertices.unsqueeze(1) - query.unsqueeze(2)
# \|u_i\| (B,P,N,1)
dj = torch.norm(uj, dim=-1, p=2, keepdim=True)
uj = normalize(uj, dim=-1)
# gather triangle B,P,F,3,3
ui = torch.gather(uj.unsqueeze(2).expand(-1,-1,F,-1,-1),
3,
faces.unsqueeze(1).unsqueeze(-1).expand(-1,P,-1,-1,3))
# li = \|u_{i+1}-u_{i-1}\| (B,P,F,3)
li = torch.norm(ui[:,:,:,[1, 2, 0],:] - ui[:, :, :,[2, 0, 1],:], dim=-1, p=2)
eps = 2e-5
li = torch.where(li>=2, li-(li.detach()-(2-eps)), li)
li = torch.where(li<=-2, li-(li.detach()+(2-eps)), li)
# asin(x) is inf at +/-1
# θi = 2arcsin[li/2] (B,P,F,3)
theta_i = 2*torch.asin(li/2)
assert(check_values(theta_i))
# B,P,F,1
h = torch.sum(theta_i, dim=-1, keepdim=True)/2
# wi← sin[θi]d{i−1}d{i+1}
# (B,P,F,3) ci ← (2sin[h]sin[h−θi])/(sin[θ_{i+1}]sin[θ_{i−1}])−1
ci = 2*torch.sin(h)*torch.sin(h-theta_i)/(torch.sin(theta_i[:,:,:,[1, 2, 0]])*torch.sin(theta_i[:,:,:,[2, 0, 1]]))-1
# NOTE: because of floating point ci can be slightly larger than 1, causing problem with sqrt(1-ci^2)
# NOTE: sqrt(x)' is nan for x=0, hence use eps
eps = 1e-5
ci = torch.where(ci>=1, ci-(ci.detach()-(1-eps)), ci)
ci = torch.where(ci<=-1, ci-(ci.detach()+(1-eps)), ci)
# si← sign[det[u1,u2,u3]]sqrt(1-ci^2)
# (B,P,F)*(B,P,F,3)
si = torch.sign(torch.det(ui)).unsqueeze(-1)*torch.sqrt(1-ci**2) # sqrt gradient nan for 0
assert(check_values(si))
# (B,P,F,3)
di = torch.gather(dj.unsqueeze(2).squeeze(-1).expand(-1,-1,F,-1), 3,
faces.unsqueeze(1).expand(-1,P,-1,-1))
assert(check_values(di))
# if si.requires_grad:
# vertices.register_hook(save_grad("mvc/dv"))
# li.register_hook(save_grad("mvc/dli"))
# theta_i.register_hook(save_grad("mvc/dtheta"))
# ci.register_hook(save_grad("mvc/dci"))
# si.register_hook(save_grad("mvc/dsi"))
# di.register_hook(save_grad("mvc/ddi"))
# wi← (θi −c[i+1]θ[i−1] −c[i−1]θ[i+1])/(disin[θi+1]s[i−1])
# B,P,F,3
# CHECK is there a 2* in the denominator
wi = (theta_i-ci[:,:,:,[1,2,0]]*theta_i[:,:,:,[2,0,1]]-ci[:,:,:,[2,0,1]]*theta_i[:,:,:,[1,2,0]])/(di*torch.sin(theta_i[:,:,:,[1,2,0]])*si[:,:,:,[2,0,1]])
# if ∃i,|si| ≤ ε, set wi to 0. coplaner with T but outside
# ignore coplaner outside triangle
# alternative check
# (B,F,3,3)
# triangle_points = torch.gather(vertices.unsqueeze(1).expand(-1,F,-1,-1), 2, faces.unsqueeze(-1).expand(-1,-1,-1,3))
# # (B,P,F,3), (B,1,F,3) -> (B,P,F,1)
# determinant = dot_product(triangle_points[:,:,:,0].unsqueeze(1)-query.unsqueeze(2),
# torch.cross(triangle_points[:,:,:,1]-triangle_points[:,:,:,0],
# triangle_points[:,:,:,2]-triangle_points[:,:,:,0], dim=-1).unsqueeze(1), dim=-1, keepdim=True).detach()
# # (B,P,F,1)
# sqrdist = determinant*determinant / (4 * sqrNorm(torch.cross(triangle_points[:,:,:,1]-triangle_points[:,:,:,0], triangle_points[:,:,:,2]-triangle_points[:,:,:,0], dim=-1), keepdim=True))
wi = torch.where(torch.any(torch.abs(si) <= 1e-5, keepdim=True, dim=-1), torch.zeros_like(wi), wi)
# wi = torch.where(sqrdist <= 1e-5, torch.zeros_like(wi), wi)
# if π −h < ε, x lies on t, use 2D barycentric coordinates
# inside triangle
inside_triangle = (PI-h).squeeze(-1)<1e-4
# set all F for this P to zero
wi = torch.where(torch.any(inside_triangle, dim=-1, keepdim=True).unsqueeze(-1), torch.zeros_like(wi), wi)
# CHECK is it di https://www.cse.wustl.edu/~taoju/research/meanvalue.pdf or li http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.516.1856&rep=rep1&type=pdf
wi = torch.where(inside_triangle.unsqueeze(-1).expand(-1,-1,-1,wi.shape[-1]), torch.sin(theta_i)*di[:,:,:,[2,0,1]]*di[:,:,:,[1,2,0]], wi)
# sum over all faces face -> vertex (B,P,F*3) -> (B,P,N)
wj = scatter_add(wi.reshape(B,P,-1).contiguous(), faces.unsqueeze(1).expand(-1,P,-1,-1).reshape(B,P,-1), 2, out_size=(B,P,N))
# close to vertex (B,P,N)
close_to_point = dj.squeeze(-1) < 1e-8
# set all F for this P to zero
wj = torch.where(torch.any(close_to_point, dim=-1, keepdim=True), torch.zeros_like(wj), wj)
wj = torch.where(close_to_point, torch.ones_like(wj), wj)
# (B,P,1)
sumWj = torch.sum(wj, dim=-1, keepdim=True)
sumWj = torch.where(sumWj==0, torch.ones_like(sumWj), sumWj)
wj_normalised = wj / sumWj
# if wj.requires_grad:
# saved_variables["mvc/wi"] = wi
# wi.register_hook(save_grad("mvc/dwi"))
# wj.register_hook(save_grad("mvc/dwj"))
if verbose:
return wj_normalised, wi
else:
return wj_normalised