import torch

# Currently selected interpolation function; None until
# set_tensor_interpolation_method() has been called.
tensor_interpolation = None


def get_tensor_interpolation_method():
    """Return the interpolation function chosen by ``set_tensor_interpolation_method``.

    Returns ``None`` if no method has been selected yet.
    """
    return tensor_interpolation


def set_tensor_interpolation_method(is_slerp):
    """Select the module-wide interpolation function.

    Args:
        is_slerp: If truthy, select ``slerp``; otherwise select ``linear``.
    """
    global tensor_interpolation
    tensor_interpolation = slerp if is_slerp else linear


def linear(v1, v2, t):
    """Linearly interpolate between ``v1`` (at ``t == 0``) and ``v2`` (at ``t == 1``)."""
    return (1.0 - t) * v1 + t * v2


def slerp(
    v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    """Spherically interpolate between ``v0`` (at ``t == 0``) and ``v1`` (at ``t == 1``).

    Each tensor is treated as a single flat vector: the norm and dot product
    are taken over all of its elements.

    Args:
        v0: Start tensor.
        v1: End tensor; expected to be broadcast-compatible with ``v0``.
        t: Interpolation factor.
        DOT_THRESHOLD: If the absolute cosine between the inputs exceeds this
            value, the inputs are treated as (anti)parallel and linear
            interpolation is used instead.

    Returns:
        The interpolated tensor.
    """
    norm0 = v0.norm()
    norm1 = v1.norm()
    # A zero vector has no direction, so the angle is undefined and the
    # normalisation below would produce 0/0 = NaN. Fall back to linear.
    if norm0 == 0 or norm1 == 0:
        return linear(v0, v1, t)

    dot = ((v0 / norm0) * (v1 / norm1)).sum()
    # Nearly (anti)parallel inputs make sin(omega) ~ 0 in the denominator,
    # which is numerically unstable; linear interpolation is a close enough
    # approximation there.
    if dot.abs() > DOT_THRESHOLD:
        return linear(v0, v1, t)

    # Clamp so rounding can never push the cosine outside acos's domain
    # (relevant if a caller passes DOT_THRESHOLD >= 1).
    omega = dot.clamp(-1.0, 1.0).acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()