|
|
import pycpd |
|
|
from builtins import super |
|
|
import numbers |
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
class EMRegistration(object):
    """
    Expectation maximization point cloud registration.

    Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
    https://github.com/siavashk/pycpd

    Subclasses must implement ``update_transform``, ``transform_point_cloud``,
    ``update_variance`` and ``get_registration_parameters``.

    Attributes
    ----------
    X: numpy array
        NxD array of target points.

    Y: numpy array
        MxD array of source points.

    TY: numpy array
        MxD array of transformed source points.

    sigma2: float (positive)
        Current variance of the Gaussian mixture model. Initialised from the
        ``sigma2`` argument, or from ``initialize_sigma2(X, Y)`` when omitted.

    N: int
        Number of target points.

    M: int
        Number of source points.

    D: int
        Dimensionality of source and target points.

    iteration: int
        The current iteration throughout registration.

    max_iterations: int
        Registration will terminate once the algorithm has taken this
        many iterations (default 100).

    tolerance: float (positive)
        Registration will terminate once ``diff`` falls to or below this
        tolerance (default 0.001).

    w: float (between 0 and 1)
        Contribution of the uniform distribution to account for outliers.
        Valid values span 0 (inclusive) and 1 (exclusive). Default 0.0.

    q: float
        The objective function value that represents the misalignment between
        source and target point clouds.

    diff: float (positive)
        Change between consecutive iterations used for the stopping test;
        subclasses decide what is compared.

    P: numpy array
        MxN array of probabilities.
        P[m, n] represents the probability that the m-th source point
        corresponds to the n-th target point.

    Pt1: numpy array
        Length-N array: column sums of P (``P.T @ 1``).

    P1: numpy array
        Length-M array: row sums of P (``P @ 1``).

    PX: numpy array
        MxD array: ``P @ X``.

    Np: float (positive)
        The sum of all elements in P.
    """

    def __init__(self, X, Y, sigma2=None, max_iterations=None, tolerance=None, w=None, *args, **kwargs):
        """
        Validate the inputs and initialise the EM state.

        Raises
        ------
        ValueError
            If X or Y is not a 2D numpy array, their dimensionality differs,
            or sigma2 / max_iterations / tolerance / w are out of range.
        """
        if type(X) is not np.ndarray or X.ndim != 2:
            raise ValueError(
                "The target point cloud (X) must be a 2D numpy array.")

        if type(Y) is not np.ndarray or Y.ndim != 2:
            raise ValueError(
                "The source point cloud (Y) must be a 2D numpy array.")

        if X.shape[1] != Y.shape[1]:
            raise ValueError(
                "Both point clouds need to have the same number of dimensions.")

        if sigma2 is not None and (not isinstance(sigma2, numbers.Number) or sigma2 <= 0):
            raise ValueError(
                "Expected a positive value for sigma2 instead got: {}".format(sigma2))

        if max_iterations is not None and (not isinstance(max_iterations, numbers.Number) or max_iterations < 0):
            raise ValueError(
                "Expected a positive integer for max_iterations instead got: {}".format(max_iterations))
        elif isinstance(max_iterations, numbers.Number) and not isinstance(max_iterations, int):
            # `warn` was previously referenced without ever being imported,
            # which turned this warning path into a NameError.
            from warnings import warn
            warn("Received a non-integer value for max_iterations: {}. Casting to integer.".format(max_iterations))
            max_iterations = int(max_iterations)

        if tolerance is not None and (not isinstance(tolerance, numbers.Number) or tolerance < 0):
            raise ValueError(
                "Expected a positive float for tolerance instead got: {}".format(tolerance))

        if w is not None and (not isinstance(w, numbers.Number) or w < 0 or w >= 1):
            raise ValueError(
                "Expected a value between 0 (inclusive) and 1 (exclusive) for w instead got: {}".format(w))

        self.X = X
        self.Y = Y
        self.TY = Y
        self.sigma2 = initialize_sigma2(X, Y) if sigma2 is None else sigma2
        (self.N, self.D) = self.X.shape
        (self.M, _) = self.Y.shape
        self.tolerance = 0.001 if tolerance is None else tolerance
        self.w = 0.0 if w is None else w
        self.max_iterations = 100 if max_iterations is None else max_iterations
        self.iteration = 0
        self.diff = np.inf
        self.q = np.inf
        self.P = np.zeros((self.M, self.N))
        self.Pt1 = np.zeros((self.N, ))
        self.P1 = np.zeros((self.M, ))
        self.PX = np.zeros((self.M, self.D))
        self.Np = 0

    def register(self, callback=lambda **kwargs: None):
        """
        Perform the EM registration.

        Parameters
        ----------
        callback: function
            A function that will be called after each iteration with the
            keyword arguments ``iteration``, ``error``, ``X`` and ``Y``
            (``Y`` is the current transformed source ``TY``).
            Can be used to visualize the registration process.

        Returns
        -------
        self.TY: numpy array
            MxD array of transformed source points.

        registration_parameters:
            Returned params dependent on registration method used.
        """
        self.transform_point_cloud()
        while self.iteration < self.max_iterations and self.diff > self.tolerance:
            self.iterate()
            if callable(callback):
                kwargs = {'iteration': self.iteration,
                          'error': self.q, 'X': self.X, 'Y': self.TY}
                callback(**kwargs)

        return self.TY, self.get_registration_parameters()

    def get_registration_parameters(self):
        """
        Placeholder for child classes.
        """
        raise NotImplementedError(
            "Registration parameters should be defined in child classes.")

    def update_transform(self):
        """
        Placeholder for child classes.
        """
        raise NotImplementedError(
            "Updating transform parameters should be defined in child classes.")

    def transform_point_cloud(self):
        """
        Placeholder for child classes.
        """
        raise NotImplementedError(
            "Updating the source point cloud should be defined in child classes.")

    def update_variance(self):
        """
        Placeholder for child classes.
        """
        raise NotImplementedError(
            "Updating the Gaussian variance for the mixture model should be defined in child classes.")

    def iterate(self):
        """
        Perform one iteration of the EM algorithm.
        """
        self.expectation()
        self.maximization()
        self.iteration += 1

    def expectation(self):
        """
        Compute the expectation step of the EM algorithm.

        Fills ``P`` (MxN posteriors), ``Pt1``, ``P1``, ``Np`` and ``PX``.
        """
        # Squared Euclidean distance between every transformed source point (rows)
        # and every target point (columns); this materialises an MxNxD temporary.
        P = np.sum((self.X[None, :, :] - self.TY[:, None, :])**2, axis=2)
        P = np.exp(-P/(2*self.sigma2))
        # Weight of the uniform outlier component (zero when w == 0).
        c = (2*np.pi*self.sigma2)**(self.D/2)*self.w/(1. - self.w)*self.M/self.N

        # np.finfo only accepts inexact dtypes; integer point clouds used to crash
        # here, so fall back to float64 machine epsilon for them.
        eps_dtype = self.X.dtype if np.issubdtype(self.X.dtype, np.inexact) else np.float64
        den = np.sum(P, axis=0, keepdims=True)
        # Clip to avoid division by zero for targets far from every source point.
        den = np.clip(den, np.finfo(eps_dtype).eps, None) + c

        self.P = np.divide(P, den)
        self.Pt1 = np.sum(self.P, axis=0)
        self.P1 = np.sum(self.P, axis=1)
        self.Np = np.sum(self.P1)
        self.PX = np.matmul(self.P, self.X)

    def maximization(self):
        """
        Compute the maximization step of the EM algorithm.
        """
        self.update_transform()
        self.transform_point_cloud()
        self.update_variance()
|
|
|
|
|
|
|
|
class DeformableRegistration(EMRegistration):
    """
    Deformable registration.

    Adapted from Pure Numpy Implementation of the Coherent Point Drift Algorithm:
    https://github.com/siavashk/pycpd

    Attributes
    ----------
    alpha: float (positive)
        Represents the trade-off between the goodness of maximum likelihood fit
        and regularization. Defaults to 2.

    beta: float (positive)
        Width of the Gaussian kernel. Defaults to 2.

    low_rank: bool
        Whether to use low rank approximation. Any truthy value enables it.

    num_eig: int
        Number of eigenvectors to use in lowrank calculation.

    W: numpy array
        MxD matrix of deformation coefficients.

    G: numpy array
        MxM Gaussian kernel matrix of the source points.

    Q, S, inv_S: numpy arrays (low rank only)
        Top eigenvectors of G, diagonal matrix of the matching eigenvalues,
        and its inverse.

    E: float (low rank only)
        Running sum of the regularisation term accumulated in update_transform.

    Notes
    -----
    ``transform_point_cloud`` zeroes every column of ``W`` after the first two,
    so only the first two dimensions of the source are deformed; any further
    columns (e.g. appended feature dimensions) are carried over unchanged.
    """

    def __init__(self, alpha=None, beta=None, low_rank=False, num_eig=100, *args, **kwargs):
        """
        Validate the deformation parameters and precompute the kernel matrix.

        Remaining arguments are forwarded to ``EMRegistration.__init__``.

        Raises
        ------
        ValueError
            If alpha or beta is given and is not a positive number.
        """
        super().__init__(*args, **kwargs)
        if alpha is not None and (not isinstance(alpha, numbers.Number) or alpha <= 0):
            raise ValueError(
                "Expected a positive value for regularization parameter alpha. Instead got: {}".format(alpha))

        if beta is not None and (not isinstance(beta, numbers.Number) or beta <= 0):
            raise ValueError(
                "Expected a positive value for the width of the coherent Gaussian kernel. Instead got: {}".format(beta))

        self.alpha = 2 if alpha is None else alpha
        self.beta = 2 if beta is None else beta
        self.W = np.zeros((self.M, self.D))
        self.G = gaussian_kernel(self.Y, self.beta)
        # Normalise to a real bool: identity checks against True/False silently
        # skipped every branch for truthy non-bool values such as numpy.bool_.
        self.low_rank = bool(low_rank)
        self.num_eig = num_eig
        if self.low_rank:
            self.Q, self.S = low_rank_eigen(self.G, self.num_eig)
            # NOTE(review): eigenvalues near zero make inv_S blow up; presumably
            # num_eig is kept well below M in practice — confirm.
            self.inv_S = np.diag(1./self.S)
            self.S = np.diag(self.S)
            self.E = 0.

    def update_transform(self):
        """
        Calculate a new estimate of the deformable transformation.

        See Eq. 22 of https://arxiv.org/pdf/0905.2635.pdf.
        """
        # `P1[:, None] * A` equals `np.diag(P1) @ A` without building the MxM
        # diagonal matrix or paying for a dense O(M^3) multiplication.
        F = self.PX - self.P1[:, None] * self.Y

        if not self.low_rank:
            A = self.P1[:, None] * self.G + \
                self.alpha * self.sigma2 * np.eye(self.M)
            self.W = np.linalg.solve(A, F)
        else:
            # Woodbury-style solve using the low rank factorisation G ~= Q S Q^T.
            dPQ = self.P1[:, None] * self.Q
            self.W = 1 / (self.alpha * self.sigma2) * (F - np.matmul(dPQ, (
                np.linalg.solve((self.alpha * self.sigma2 * self.inv_S + np.matmul(self.Q.T, dPQ)),
                                (np.matmul(self.Q.T, F))))))
            QtW = np.matmul(self.Q.T, self.W)
            self.E = self.E + self.alpha / 2 * np.trace(np.matmul(QtW.T, np.matmul(self.S, QtW)))

    def transform_point_cloud(self, Y=None):
        """
        Update a point cloud using the new estimate of the deformable transformation.

        Before transforming, all columns of ``W`` after the first two are set
        to zero (in place), so only the first two dimensions are deformed.

        Parameters
        ----------
        Y: numpy array, optional
            Array of points to transform - use to predict on new set of points.
            Best for predicting on new points not used to run initial registration.
            If None, self.Y used.

        Returns
        -------
        If Y is None, returns None (``self.TY`` is updated instead).
        Otherwise, returns the transformed Y. The full Gaussian kernel between
        Y and ``self.Y`` is used in that case, even in low rank mode.
        """
        self.W[:, 2:] = 0
        if Y is not None:
            G = gaussian_kernel(X=Y, beta=self.beta, Y=self.Y)
            return Y + np.dot(G, self.W)
        else:
            if not self.low_rank:
                self.TY = self.Y + np.dot(self.G, self.W)
            else:
                self.TY = self.Y + np.matmul(self.Q, np.matmul(self.S, np.matmul(self.Q.T, self.W)))
            return

    def update_variance(self):
        """
        Update the variance of the mixture model using the new estimate of the deformable transformation.

        See the update rule for sigma2 in Eq. 23 of https://arxiv.org/pdf/0905.2635.pdf.

        ``diff`` is set to the absolute change in sigma2, which drives the
        stopping test in ``register``. ``q`` is reset to inf.
        """
        qprev = self.sigma2

        self.q = np.inf

        xPx = np.dot(np.transpose(self.Pt1), np.sum(
            np.multiply(self.X, self.X), axis=1))
        yPy = np.dot(np.transpose(self.P1), np.sum(
            np.multiply(self.TY, self.TY), axis=1))
        trPXY = np.sum(np.multiply(self.TY, self.PX))

        self.sigma2 = (xPx - 2 * trPXY + yPy) / (self.Np * self.D)

        # Guard against a non-positive variance from numerical cancellation.
        if self.sigma2 <= 0:
            self.sigma2 = self.tolerance / 10

        self.diff = np.abs(self.sigma2 - qprev)

    def get_registration_parameters(self):
        """
        Return the current estimate of the deformable transformation parameters.

        Returns
        -------
        self.G: numpy array
            Gaussian kernel matrix.

        self.W: numpy array
            Deformable transformation matrix.
        """
        return self.G, self.W
|
|
|
|
|
|
|
|
|
|
|
def initialize_sigma2(X, Y):
    """
    Compute the starting variance of the Gaussian mixture.

    The value is the mean squared coordinate difference over every
    (source, target) pair and every dimension.

    param
    ----------
    X: numpy array
        NxD array of points for target.

    Y: numpy array
        MxD array of points for source.

    Returns
    -------
    sigma2: float
        Initial variance.
    """
    num_target, dim = X.shape
    num_source, _ = Y.shape
    # Broadcast to an MxNxD array of pairwise coordinate differences.
    squared = (X[None, :, :] - Y[:, None, :]) ** 2
    return np.sum(squared) / (dim * num_source * num_target)
|
|
|
|
|
|
|
|
|
|
|
def gaussian_kernel(X, beta, Y=None):
    """
    Build the Gaussian (RBF) kernel matrix between two sets of row vectors.

    :param X: numpy array of shape (n_samples_X, n_features).
    :param beta: Width of the kernel; plays the role of a standard deviation.
    :param Y: Optional numpy array of shape (n_samples_Y, n_features).
              When omitted, the kernel of `X` with itself (the Gram matrix) is returned.
    :return: numpy array of shape (n_samples_X, n_samples_Y) whose (i, j) entry is
             `exp(-||X[i] - Y[j]||^2 / (2 * beta^2))`.
    """
    other = X if Y is None else Y
    # Pairwise squared Euclidean distances via broadcasting (n_X x n_Y x d temporary).
    sq_dist = np.square(X[:, None, :] - other[None, :, :]).sum(axis=2)
    return np.exp(-sq_dist / (2 * beta ** 2))
|
|
|
|
|
|
|
|
|
|
|
def low_rank_eigen(G, num_eig):
    """
    Return the `num_eig` eigenpairs of a symmetric matrix with the largest
    eigenvalue magnitude, for use in a low-rank approximation.

    :param G: Square symmetric matrix (numpy array); decomposed with `np.linalg.eigh`.
    :param num_eig: How many eigenpairs to keep, ranked by |eigenvalue| (largest first).
    :return: A tuple (Q, S) where
             - Q has shape (n, num_eig); each column is an eigenvector of `G`.
             - S has shape (num_eig,) and holds the matching eigenvalues.
    """
    eigvals, eigvecs = np.linalg.eigh(G)
    # eigh returns ascending eigenvalues; rank by magnitude and reverse for descending order.
    order = np.argsort(np.abs(eigvals))
    top = order[::-1][:num_eig]
    return eigvecs[:, top], eigvals[top]
|
|
|
|
|
|
|
|
|
|
|
def find_homography_translation_rotation(src_points, dst_points):
    """
    Estimate the rigid transform (rotation + translation, no scaling) that maps
    src_points onto dst_points in the least-squares sense, as a 3x3 matrix.

    Both point sets are centred on their centroids. The rotation comes from the
    SVD of their 2x2 cross-covariance (Kabsch procedure), and the translation
    aligns the centroids. If the SVD yields a reflection (negative determinant),
    the last right-singular vector is negated to obtain a proper rotation.

    :param src_points: numpy array of shape (n, 2) containing source coordinates.
    :param dst_points: numpy array of shape (n, 2) containing destination coordinates.
    :return: 3x3 matrix [[R, t], [0, 0, 1]] with dst_i ~= R @ src_i + t for each point.
    """
    assert src_points.shape == dst_points.shape
    assert src_points.shape[1] == 2

    src_mean = np.mean(src_points, axis=0)
    dst_mean = np.mean(dst_points, axis=0)

    cross_cov = np.dot((src_points - src_mean).T, dst_points - dst_mean)
    U, _, Vt = np.linalg.svd(cross_cov)

    rotation = np.dot(Vt.T, U.T)
    if np.linalg.det(rotation) < 0:
        # Reflection: flip the last singular direction to get det(R) = +1.
        Vt[-1, :] *= -1
        rotation = np.dot(Vt.T, U.T)

    transform = np.eye(3)
    transform[0:2, 0:2] = rotation
    transform[0:2, 2] = dst_mean - np.dot(rotation, src_mean)
    return transform
|
|
|
|
|
|
|
|
|
|
|
def apply_homography(coordinates, H):
    """
    Map 2D points through a 3x3 homography.

    Each point is lifted to homogeneous form (x, y, 1), multiplied by H, and
    projected back by dividing by the third homogeneous component.

    :param coordinates: numpy array of shape (n, 2) of 2D points.
    :param H: numpy array of shape (3, 3), the homography matrix.
    :return: numpy array of shape (n, 2) holding the transformed points.
    """
    ones = np.ones((coordinates.shape[0], 1))
    lifted = np.concatenate((coordinates, ones), axis=1)
    # Row-vector convention: (H @ p)^T == p^T @ H^T.
    projected = np.dot(lifted, H.T)
    return projected[:, :2] / projected[:, 2:3]
|
|
|
|
|
|
|
|
|
|
|
def align_tissue(ad_tar_coor, ad_src_coor, pca_comb_features, src_img, alpha=0.5):
    """
    Align source coordinates to target coordinates with Coherent Point Drift (CPD),
    then fit a rigid (rotation + translation) transform to the CPD result and use
    it to warp the source coordinates and the source image.

    Coordinates and features are min-max scaled by a single global minimum and
    maximum per array (not per column). The first ``ad_tar_coor.shape[0]`` rows of
    ``pca_comb_features`` are paired with the target and the remaining rows with
    the source; only the first two feature columns are used.

    :param ad_tar_coor: Numpy array of target coordinates to which the source will be aligned.
    :param ad_src_coor: Numpy array of source coordinates that will be aligned to the target.
    :param pca_comb_features: PCA-combined feature matrix for target rows followed by source rows.
    :param src_img: Source image to be warped based on the alignment.
    :param alpha: Regularization parameter for CPD registration, default is 0.5.
    :return:
        - cpd_coor: The source coordinates after CPD alignment, in target coordinate units.
        - homo_coor: The source coordinates after applying the rigid transformation.
        - aligned_image: The source image warped with cv2.warpPerspective using that
          transformation, at the original image size.
    """
    def _global_min_max(values):
        # Single scalar min/max over the whole array keeps the x/y aspect ratio.
        return (values - values.min()) / (values.max() - values.min())

    tar_scaled = _global_min_max(ad_tar_coor)
    src_scaled = _global_min_max(ad_src_coor)
    features_scaled = _global_min_max(pca_comb_features)

    n_target = ad_tar_coor.shape[0]
    target = np.concatenate((tar_scaled, features_scaled[:n_target, :2]), axis=1)
    source = np.concatenate((src_scaled, features_scaled[n_target:, :2]), axis=1)

    # Effectively unbounded iteration count: stopping is governed by the tolerance.
    registration = DeformableRegistration(X=target, Y=source, low_rank=True,
                                          alpha=alpha,
                                          max_iterations=int(1e9), tolerance=1e-9)
    TY, _ = registration.register()

    # Undo the target scaling on the spatial columns only.
    tar_min = ad_tar_coor.min()
    cpd_coor = TY[:, :2] * (ad_tar_coor.max() - tar_min) + tar_min

    h = find_homography_translation_rotation(ad_src_coor, cpd_coor)
    homo_coor = apply_homography(ad_src_coor, h)

    # NOTE(review): assumes coordinates are in (x, y) pixel order of src_img — confirm.
    output_size = (src_img.shape[1], src_img.shape[0])
    aligned_image = cv2.warpPerspective(src_img, h, output_size)

    return cpd_coor, homo_coor, aligned_image
|
|
|
|
|
|
|
|
|