| from torch import Tensor |
|
|
| import torch |
|
|
class Functional:
    """Namespace class grouping stateless tensor helper functions."""

    @staticmethod
    def slerp(low: Tensor,
              high: Tensor,
              val: float = 0.5
              ) -> Tensor:
        """Spherical linear interpolation (slerp) between two batches.

        Every sample (index along dim 0) is flattened to a vector and
        interpolated independently:

            slerp(q_0, q_1; t) = sin((1 - t) * theta) / sin(theta) * q_0
                                 + sin(t * theta) / sin(theta) * q_1

        where cos(theta) = dot(q_0 / |q_0|, q_1 / |q_1|).

        NumPy reference for a single pair of vectors:

            theta = np.arccos(np.dot(low / np.linalg.norm(low),
                                     high / np.linalg.norm(high)))
            so = np.sin(theta)
            return (np.sin((1.0 - val) * theta) / so * low
                    + np.sin(val * theta) / so * high)

        Args:
            low: Start tensor (returned when ``val == 0``). Must have at
                least one dimension and a floating-point dtype; dim 0 is
                treated as the batch dimension.
            high: End tensor (returned when ``val == 1``), same shape as
                ``low``.
            val: Interpolation factor ``t``.

        Returns:
            Tensor with the same shape as ``low``.

        Raises:
            AssertionError: If ``low`` and ``high`` differ in shape.

        Notes:
            When the two directions are (anti)parallel, ``sin(theta)`` is
            zero and the slerp weights are undefined (0 / 0). For those
            samples the function falls back to plain linear interpolation
            ``(1 - val) * low + val * high`` instead of returning NaN.
        """
        assert tuple(low.shape) == tuple(high.shape), f'low shape({low.shape}) must be same as high shape({high.shape})'
        feature_shape: tuple = tuple(low.shape)

        # Flatten everything except the batch dimension.
        low_1d: Tensor = low.reshape(feature_shape[0], -1)
        high_1d: Tensor = high.reshape(feature_shape[0], -1)

        eps = torch.finfo(low_1d.dtype).eps

        # Clamp the norms so an all-zero sample does not produce 0 / 0.
        low_norm = low_1d / torch.norm(low_1d, dim=1, keepdim=True).clamp_min(eps)
        high_norm = high_1d / torch.norm(high_1d, dim=1, keepdim=True).clamp_min(eps)

        # Rounding can push the cosine marginally outside [-1, 1], which
        # would make acos return NaN.
        dot_product = (low_norm * high_norm).sum(dim=1).clamp(-1.0, 1.0)
        theta = torch.acos(dot_product)
        so = torch.sin(theta)

        # Samples whose directions are (anti)parallel: sin(theta) ~ 0.
        degenerate = so.abs() <= eps
        # Substitute a harmless divisor so the discarded branch of
        # torch.where never evaluates x / 0.
        safe_so = torch.where(degenerate, torch.ones_like(so), so)

        # Linear-interpolation weights, materialised as tensors shaped
        # like `so` so torch.where can select them per sample.
        lerp_low = (1.0 - val) + torch.zeros_like(so)
        lerp_high = val + torch.zeros_like(so)

        weight_low = torch.where(
            degenerate, lerp_low, torch.sin((1.0 - val) * theta) / safe_so
        )
        weight_high = torch.where(
            degenerate, lerp_high, torch.sin(val * theta) / safe_so
        )

        res = weight_low.unsqueeze(1) * low_1d + weight_high.unsqueeze(1) * high_1d
        return res.reshape(feature_shape)