kywind
update
f96995c
import torch
import numpy as np
import time
import kornia
def _quaternion_multiply(q1, q2):
    """Return the Hamilton product ``q1 * q2`` of two (bsz, 4) quaternion batches.

    Column 0 is treated as the scalar part and columns 1-3 as the vector part.
    """
    w1, x1, y1, z1 = q1.unbind(dim=-1)
    w2, x2, y2, z2 = q2.unbind(dim=-1)
    return torch.stack(
        [
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ],
        dim=-1,
    )


def _fit_rank_1_rotations(F_1, device, step):
    """Return, per rank-1 covariance, the rotation taking the rest axis onto the moved axis.

    F_1: (m, 3, 3) covariances of rank 1. Returns (m, 3, 3) rotation matrices;
    bones whose two axes are (anti-)parallel keep the identity.
    """
    R = torch.eye(3, device=device, dtype=F_1.dtype)[None].repeat(F_1.shape[0], 1, 1)
    try:
        U, _, Vh = torch.linalg.svd(F_1)
    except RuntimeError:
        # Best effort, as before: a failed decomposition leaves these bones unrotated.
        print(f'[step {step}] SVD failed')
        return R

    # For F = sigma * u v^T the pair (u, v) is only defined up to a common sign,
    # so the rotation has to be built from both vectors: v is the neighbour
    # direction before the motion, u the direction after it.
    axis = U[:, :, 0]  # (m, 3)
    rest_axis = Vh[:, 0, :]  # (m, 3)
    perp_axis = torch.linalg.cross(axis, rest_axis)  # (m, 3) rotation axis, unnormalised
    perp_norm = torch.norm(perp_axis, dim=1, keepdim=True)  # (m, 1)

    perp_axis_norm_mask = perp_norm[:, 0] < 1e-6
    if perp_axis_norm_mask.any():
        print(f'[step {step}] Perp axis norm 0 for {perp_axis_norm_mask.sum()} bones')

    valid = ~perp_axis_norm_mask
    if valid.any():
        # Every operand is filtered with the same mask so the batch sizes agree.
        axis = axis[valid]
        rest_axis = rest_axis[valid]
        perp_axis = perp_axis[valid] / perp_norm[valid]
        # X and Y are orthonormal frames sharing perp_axis; Y @ X^T maps X onto Y.
        X = torch.stack([rest_axis, perp_axis, torch.linalg.cross(rest_axis, perp_axis)], dim=-1)
        Y = torch.stack([axis, perp_axis, torch.linalg.cross(axis, perp_axis)], dim=-1)
        R[valid] = Y @ X.permute(0, 2, 1)
    return R


def _fit_bone_rotations(bones, motions, relations, device, step):
    """Estimate one rotation per bone from how its neighbouring bones move.

    Returns an (n_bones, 3, 3) tensor. Bones whose neighbour covariance has
    rank 0 carry no orientation information and keep the identity.
    """
    n_bones = bones.shape[0]

    moved = bones + motions
    adj_bones = bones[relations] - bones[:, None]  # (n_bones, n_adj, 3)
    adj_bones_new = moved[relations] - moved[:, None]  # (n_bones, n_adj, 3)

    # Unweighted covariance between neighbour offsets after and before the motion.
    F = adj_bones_new.permute(0, 2, 1) @ adj_bones  # (n_bones, 3, 3)

    cov_rank = torch.linalg.matrix_rank(F)  # (n_bones,)
    cov_rank_1_mask = cov_rank == 1
    cov_rank_2_3_mask = cov_rank >= 2

    # Results of each rank class are scattered into this per-bone tensor so the
    # classes can never overwrite or broadcast over one another.
    R_all = torch.eye(3, device=device, dtype=F.dtype)[None].repeat(n_bones, 1, 1)

    if cov_rank_2_3_mask.any():
        F_2_3 = F[cov_rank_2_3_mask]  # (m, 3, 3)
        try:
            U, _, Vh = torch.linalg.svd(F_2_3)
        except RuntimeError:
            print(f'[step {step}] SVD failed')
            raise

        # S is the sign-correction matrix of the fit, not the singular values.
        S = torch.eye(3, device=device, dtype=F.dtype)[None].repeat(F_2_3.shape[0], 1, 1)
        neg_det_mask = torch.linalg.det(F_2_3) < 0
        if neg_det_mask.any():
            print(f'[step {step}] F det < 0 for {neg_det_mask.sum()} bones')
            S[neg_det_mask, -1, -1] = -1
        R = U @ S @ Vh

        # det(F) is ~0 for rank-2 covariances, so its sign is unreliable there;
        # re-check the result and flip the last axis of any remaining reflection.
        det_R = torch.linalg.det(R)
        neg_1_det_mask = torch.abs(det_R + 1) < 1e-3
        pos_1_det_mask = torch.abs(det_R - 1) < 1e-3
        if neg_1_det_mask.any():
            print(f'[step {step}] det -1')
            S[neg_1_det_mask, -1, -1] = -S[neg_1_det_mask, -1, -1]
            R = U @ S @ Vh
        if (~(neg_1_det_mask | pos_1_det_mask)).any():
            # Neither +1 nor -1 (e.g. non-finite input): report and carry on.
            print(f'[step {step}] Bad det')

        R_all[cov_rank_2_3_mask] = R

    if cov_rank_1_mask.any():
        print(f'[step {step}] F rank 1 for {cov_rank_1_mask.sum()} bones')
        R_all[cov_rank_1_mask] = _fit_rank_1_rotations(F[cov_rank_1_mask], device, step)

    return R_all


def interpolate_motions(bones, motions, relations, xyz, rot=None, quat=None, weights=None, device='cuda', step='n/a'):
    """Move particles by blending per-bone rigid transforms.

    A rotation is fitted for every bone from the offsets to its neighbouring
    bones before and after the motion; each particle is then transformed by
    every bone (rotation about the bone plus the bone's translation) and the
    results are averaged with ``weights``.

    Args:
        bones: (n_bones, 3) bone positions.
        motions: (n_bones, 3) bone displacements.
        relations: (n_bones, k) integer indices of each bone's neighbours.
        xyz: (n_particles, 3) particle positions.
        rot: returned unchanged when ``quat`` is None.
        quat: optional (n_particles, 4) particle quaternions to rotate.
        weights: optional (n_particles, n_bones) blend weights; when None,
            normalised inverse particle-to-bone distances are used.
        device: device on which intermediate tensors are allocated; it is
            expected to match the device of the inputs.
        step: label included in diagnostic messages.

    Returns:
        ``(xyz_transformed, rot, weights)``: the (n_particles, 3) moved
        particles, the rotated quaternions (or the ``rot`` argument when
        ``quat`` is None) and the (n_particles, n_bones) weights used.

    Raises:
        RuntimeError: if the SVD of the rank-2/3 covariances fails.
    """
    rotations = _fit_bone_rotations(bones, motions, relations, device, step)  # (n_bones, 3, 3)

    # Compute the weights
    if weights is None:
        dist = torch.cdist(xyz[None], bones[None])[0]  # (n_particles, n_bones)
        dist = torch.clamp(dist, min=1e-4)  # avoid division by zero on top of a bone
        weights = 1 / dist
        weights = weights / weights.sum(dim=1, keepdim=True)  # (n_particles, n_bones)

    # Compute the transformed particles: rotate about each bone, translate, blend.
    local = xyz[:, None] - bones[None]  # (n_particles, n_bones, 3)
    per_bone = torch.einsum('ijk,jkl->ijl', local, rotations.permute(0, 2, 1))  # (n_particles, n_bones, 3)
    per_bone = per_bone + motions[None] + bones[None]  # (n_particles, n_bones, 3)
    xyz_transformed = (per_bone * weights[:, :, None]).sum(dim=1)  # (n_particles, 3)

    if quat is not None:
        # NOTE(review): the product below treats component 0 as the scalar part;
        # confirm that the installed kornia version returns (w, x, y, z).
        base_quats = kornia.geometry.conversions.rotation_matrix_to_quaternion(rotations)  # (n_bones, 4)
        base_quats = torch.nn.functional.normalize(base_quats, dim=-1)  # (n_bones, 4)
        quats = (base_quats[None] * weights[:, :, None]).sum(dim=1)  # (n_particles, 4)
        quats = torch.nn.functional.normalize(quats, dim=-1)
        rot = _quaternion_multiply(quats, quat)

    return xyz_transformed, rot, weights