|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
class ThinPlateSplines:
    """Thin-plate-spline (TPS) mapping fitted on pairs of control/target points.

    The spline coefficients are computed once, at construction time, by
    solving the linear system built in ``__make_T``. ``interpolate`` then
    applies the fitted mapping to arbitrary points.

    All internal computations are carried out in ``tf.float32``.
    """

    def __init__(self, ctrl_pts: tf.Tensor, target_pts: tf.Tensor, reg=0.0):
        """
        :param ctrl_pts: [N, d] tensor of control d-dimensional points
        :param target_pts: [N, d] tensor of target d-dimensional points
        :param reg: regularization coefficient (0.0 disables regularization)
        """
        # The kernel and the affine block are float32, so the points are cast
        # once here to keep every later concat/matmul dtype-consistent.
        self.__ctrl_pts = tf.cast(ctrl_pts, tf.float32)
        self.__target_pts = tf.cast(target_pts, tf.float32)
        self.__reg = reg
        self.__num_ctrl_pts = ctrl_pts.shape[0]
        self.__dim = ctrl_pts.shape[1]

        self.__compute_coeffs()

        # The first N rows of the coefficients are the non-affine (kernel)
        # weights; the remaining d + 1 rows are the affine part.
        self.__non_aff_params = self.__coeffs[:self.__num_ctrl_pts, ...]

    def __compute_coeffs(self):
        """Solve ``T @ coeffs = [target_pts; 0]`` and store ``self.__coeffs``.

        The target points are padded with ``d + 1`` rows of zeros to match the
        side conditions encoded in the lower block of ``T``. The resulting
        coefficients have shape [N + d + 1, d].
        """
        padding = tf.zeros([self.__dim + 1, self.__dim], dtype=tf.float32)
        target_pts_aug = tf.concat([self.__target_pts, padding], axis=0)

        # Solving the system directly avoids forming the explicit inverse of T.
        self.__coeffs = tf.linalg.solve(self.__make_T(), target_pts_aug)

    def __make_T(self):
        """Assemble the [N + d + 1, N + d + 1] TPS system matrix.

        Layout::

            [[K + reg * alfa^2 * I,  P],
             [P^T,                   0]]

        where ``K`` is the kernel evaluated between every pair of control
        points and ``P = [1, ctrl_pts]``. The unregularized ``K`` is kept in
        ``self.__K`` for ``bending_energy``.
        """
        ones = tf.ones([self.__num_ctrl_pts, 1], dtype=tf.float32)
        P = tf.concat([ones, self.__ctrl_pts], axis=1)
        zeros = tf.zeros([self.__dim + 1, self.__dim + 1], dtype=tf.float32)

        self.__K = self.__U_dist(self.__ctrl_pts)
        alfa = tf.reduce_mean(self.__K)

        # The regularization term must sit on the diagonal. Adding a constant
        # to *every* entry of K has no effect on the solution, because the
        # side condition P^T w = 0 implies 1^T w = 0 and the term cancels.
        identity = tf.eye(tf.shape(self.__K)[0], dtype=tf.float32)
        K_reg = self.__K + identity * tf.pow(alfa, 2) * self.__reg

        upper = tf.concat([K_reg, P], axis=1)
        lower = tf.concat([tf.transpose(P), zeros], axis=1)
        return tf.concat([upper, lower], axis=0)

    def __U_dist(self, ctrl_pts, int_pts=None):
        """Evaluate the radial kernel on pairwise squared distances.

        :param ctrl_pts: [A, d] points (rows of the result)
        :param int_pts: optional [B, d] points (columns of the result); when
            omitted the kernel is evaluated between ``ctrl_pts`` and itself
        :return: [A, B] (or [A, A]) float32 kernel matrix
        """
        if int_pts is None:
            dist = self.__pairwise_distance_equal(ctrl_pts)
        else:
            dist = self.__pairwise_distance_different(ctrl_pts, int_pts)

        # ``dist`` holds squared distances.
        if ctrl_pts.shape[-1] == 2:
            # 2-D kernel; the epsilon keeps log() finite at zero distance.
            u_dist = dist * tf.math.log(dist + 1e-6)
        else:
            # Any other dimensionality: plain Euclidean distance.
            # NOTE(review): the gradient of sqrt is unbounded at 0 (e.g. on
            # the diagonal of K) - confirm whether gradients w.r.t. the points
            # are needed by callers.
            u_dist = tf.sqrt(dist)

        return u_dist

    def __pairwise_distance_sq(self, pts_a, pts_b):
        """Squared distances between ``pts_a`` [A, d] and ``pts_b`` [B, d].

        Not called from within this class; kept as a thin dispatcher over the
        two static helpers.
        """
        with tf.name_scope('pairwise_distance'):
            if pts_a is pts_b:
                return self.__pairwise_distance_equal(pts_a)
            return self.__pairwise_distance_different(pts_a, pts_b)

    @staticmethod
    def __pairwise_distance_equal(pts):
        """Squared distances between every pair of rows of ``pts`` ([N, d]).

        Uses explicit differences, so the diagonal is exactly zero. Works for
        any point dimensionality ``d``.

        :return: [N, N] float32 tensor
        """
        pts = tf.cast(pts, tf.float32)
        diff = tf.expand_dims(pts, 1) - tf.expand_dims(pts, 0)
        return tf.reduce_sum(tf.square(diff), 2)

    @staticmethod
    def __pairwise_distance_different(pts_a, pts_b):
        """Squared distances between rows of ``pts_a`` and rows of ``pts_b``.

        Uses the expansion ``|a|^2 - 2 a.b + |b|^2``.

        :param pts_a: [A, d] tensor
        :param pts_b: [B, d] tensor
        :return: [A, B] float32 tensor
        """
        pts_a = tf.cast(pts_a, tf.float32)
        pts_b = tf.cast(pts_b, tf.float32)

        sq_norm_a = tf.expand_dims(tf.reduce_sum(tf.square(pts_a), 1), 1)
        sq_norm_b = tf.expand_dims(tf.reduce_sum(tf.square(pts_b), 0 + 1), 0)
        cross = tf.matmul(pts_a, pts_b, adjoint_b=True)

        dist = sq_norm_a - 2 * cross + sq_norm_b
        # Round-off in the expansion can yield tiny negative values, which
        # would turn into NaN in the sqrt/log applied by the kernel.
        return tf.maximum(dist, 0.0)

    def __lift_pts(self, int_pts: tf.Tensor, num_pts):
        """Build the design matrix ``[U(int_pts, ctrl_pts), 1, int_pts]``.

        :param int_pts: [K, d] points to be interpolated
        :param num_pts: number of points K (int or scalar tensor)
        :return: [K, N + d + 1] float32 tensor
        """
        int_pts = tf.cast(int_pts, tf.float32)
        int_pts_lift = tf.concat([self.__U_dist(int_pts, self.__ctrl_pts),
                                  tf.ones([num_pts, 1], dtype=tf.float32),
                                  int_pts], axis=1)
        return int_pts_lift

    @property
    def bending_energy(self):
        """[d, d] tensor ``W^T K W``, with W the non-affine coefficients."""
        aux = tf.matmul(self.__non_aff_params, self.__K, transpose_a=True)
        return tf.matmul(aux, self.__non_aff_params)

    def interpolate(self, int_points):
        """
        Apply the fitted spline to a set of points.

        :param int_points: [K, d] flattened d-points of a mesh
        :return: [K, d] float32 tensor with the mapped points
        """
        num_pts = tf.shape(int_points)[0]
        int_points_lift = self.__lift_pts(int_points, num_pts)

        return tf.matmul(int_points_lift, self.__coeffs)

    def __call__(self, int_points, num_pts=None, **kwargs):
        """Alias of ``interpolate``.

        ``num_pts`` and ``kwargs`` are accepted for backward compatibility and
        ignored; the number of points is read from ``int_points``.
        """
        return self.interpolate(int_points)
|
|
|
|
|
def thin_plate_splines_batch(ctrl_pts: tf.Tensor, target_pts: tf.Tensor, int_pts: tf.Tensor, reg=0.0):
    """Fit and apply one thin-plate spline per batch element.

    The three inputs are iterated together along their first axis with
    ``tf.map_fn``; for each element a ``ThinPlateSplines`` is fitted on the
    control/target points and evaluated at the interpolation points.

    :param ctrl_pts: tensor of control points, batch on the first axis
    :param target_pts: tensor of target points, batch on the first axis
    :param int_pts: tensor of points to interpolate, batch on the first axis
    :param reg: regularization coefficient forwarded to ``ThinPlateSplines``
    :return: float32 tensor stacking the interpolated points of every element
    """
    def tps_sample(in_data):
        # in_data is one (control, target, interpolation) triplet of the batch
        cp, tp, ip = in_data
        return ThinPlateSplines(cp, tp, reg).interpolate(ip)

    return tf.map_fn(tps_sample, elems=(ctrl_pts, target_pts, int_pts), dtype=tf.float32)
|
|
|
|