Spaces:
Running
Running
| """ | |
| Generate multivariate von Mises Fisher samples. | |
| PyTorch implementation of the original code from: | |
| https://github.com/clara-labs/spherecluster | |
| """ | |
| import torch | |
| __all__ = ["sample_vMF"] | |
| def vMF_sampler( | |
| net, | |
| batch, | |
| ): | |
| mu, kappa = net(batch) | |
| return sample_vMF(mu.T, kappa.squeeze(1)) | |
| def vMF_mixture_sampler( | |
| net, | |
| batch, | |
| ): | |
| mu_mixture, kappa_mixture, weights = net(batch) | |
| # Sample mixture component indices based on weights | |
| indices = torch.multinomial(weights, num_samples=1).squeeze() | |
| # Select corresponding mu and kappa | |
| mu = mu_mixture[torch.arange(mu_mixture.shape[0]), indices] | |
| kappa = kappa_mixture[torch.arange(kappa_mixture.shape[0]), indices] | |
| return sample_vMF(mu.T, kappa) | |
| def sample_vMF(mu, kappa, num_samples=1): | |
| """Generate N-dimensional samples from von Mises Fisher | |
| distribution around center mu β R^N with concentration kappa. | |
| mu and kappa may be vectors, | |
| mu should have shape (N,) or (N, 1), kappa should be scalar or vector of length N. | |
| """ | |
| if len(mu.shape) == 1: | |
| mu = mu.unsqueeze(1) | |
| if isinstance(kappa, torch.Tensor): | |
| dim = mu.shape[0] | |
| assert mu.shape[1] == kappa.size(0) | |
| else: | |
| dim = mu.shape[0] | |
| mu = mu.repeat(1, num_samples) | |
| kappa = torch.full((num_samples,), kappa, device=mu.device, dtype=mu.dtype) | |
| # sample offset from center (on sphere) with spread kappa | |
| w = _sample_weight(kappa, dim) | |
| # sample a point v on the unit sphere that's orthogonal to mu | |
| v = _sample_orthonormal_to(mu) | |
| # compute new point | |
| result = v * torch.sqrt(1.0 - w**2).unsqueeze(0) + w.unsqueeze(0) * mu | |
| return result.T | |
| def _sample_weight(kappa, dim): | |
| """Rejection sampling scheme for sampling distance from center on | |
| surface of the sphere. | |
| """ | |
| dim = dim - 1 # since S^{n-1} | |
| try: | |
| size = kappa.size(0) | |
| except AttributeError: | |
| size = 1 | |
| b = dim / (torch.sqrt(4.0 * kappa**2 + dim**2) + 2 * kappa) | |
| x = (1.0 - b) / (1.0 + b) | |
| c = kappa * x + dim * torch.log(1 - x**2) | |
| w = torch.zeros_like(kappa) | |
| idx = torch.zeros_like(kappa, dtype=torch.bool) | |
| while True: | |
| where_zero = ~idx | |
| if torch.all(idx): | |
| return w | |
| z = ( | |
| torch.distributions.Beta(dim / 2.0, dim / 2.0) | |
| .sample((size,)) | |
| .to(kappa.device) | |
| ) | |
| _w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) | |
| u = torch.rand(size, device=kappa.device) | |
| _idx = kappa * _w + dim * torch.log(1.0 - x * _w) - c >= torch.log(u) | |
| if not torch.any(_idx): | |
| continue | |
| w[where_zero] = _w[where_zero] | |
| idx[_idx] = True | |
| def _sample_orthonormal_to(mu): | |
| """Sample point on sphere orthogonal to mu.""" | |
| v = torch.randn(mu.shape[0], mu.shape[1], device=mu.device) | |
| proj_mu_v = mu * ((v * mu).sum(dim=0)) / torch.norm(mu, dim=0) ** 2 | |
| orthto = v - proj_mu_v | |
| return orthto / torch.norm(orthto, dim=0) | |