from typing import Dict

import torch
from torch.distributions import Categorical, constraints
from torch.distributions.distribution import Distribution

__all__ = ["MixtureSameFamily"]


class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...           torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...           torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)
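
        >>> # Evaluate log-densities and draw samples from the batched GMM
        >>> # above; its batch_shape is (3,) and event_shape is (2,)
        >>> gmm.log_prob(torch.randn(3, 2)).shape
        torch.Size([3])
        >>> gmm.sample((4,)).shape
        torch.Size([4, 3, 2])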

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`.
        component_distribution: `torch.distributions.Distribution`-like
            instance. The rightmost batch dimension indexes the components.
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                "The mixture distribution needs to be an "
                "instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The component distribution needs to be an "
                "instance of torch.distributions.Distribution"
            )

        # Check that the batch shapes are compatible
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
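        # Compare trailing batch dimensions right-to-left; a size of 1 in
        # either shape is allowed to broadcast against the other.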
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
                    "compatible with `component_distribution."
                    f"batch_shape` ({cdbs})"
                )

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                f"`mixture_distribution` components ({km}) do not "
                "equal `component_distribution.batch_shape[-1]` "
                f"({kc})"
            )
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
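        # Below, `mean_cond_var` is E[Var(Y|X)] (the mixture of component
        # variances) and `var_cond_mean` is Var(E[Y|X]) (the spread of the
        # component means around the mixture mean).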
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs
        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        if self._validate_args:
            self._validate_sample(x)
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
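        # Marginalize over the k components in a numerically stable way:
        # log p(x) = logsumexp_k(log pi_k + log p_k(x))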
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
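            # gather_dim is the position of the component (k) dimension in
            # comp_samples, i.e. just after the sample and batch dimensions.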
            es = self.event_shape

            # mixture samples [n, B]
            mix_sample = self.mixture_distribution.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1))
            )
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
            )

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)
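
    # _pad inserts a singleton component dimension into a value of shape
    # [S, B, E] so that it broadcasts against component statistics of
    # shape [B, k, E].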
    def _pad(self, x):
        return x.unsqueeze(-1 - self._event_ndims)
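
    # _pad_mixture_dimensions reshapes the mixture probabilities
    # (shape [..., k]) by inserting singleton dimensions so that they
    # broadcast against component statistics of shape [B, k, E].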
    def _pad_mixture_dimensions(self, x):
        dist_batch_ndims = self.batch_shape.numel()
        cat_batch_ndims = self.mixture_distribution.batch_shape.numel()
        pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims
        xs = x.shape
        x = x.reshape(
            xs[:-1]
            + torch.Size(pad_ndims * [1])
            + xs[-1:]
            + torch.Size(self._event_ndims * [1])
        )
        return x

    def __repr__(self):
        args_string = (
            f"\n {self.mixture_distribution},\n {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"