import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import broadcast_all, clamp_probs


__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]


class ExpRelaxedCategorical(Distribution):
    r"""
    Creates an ExpRelaxedCategorical parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    Returns the log of a point in the simplex. Based on the interface to
    :class:`OneHotCategorical`.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        """Wrap a :class:`Categorical` built from ``probs``/``logits`` and store ``temperature``."""
        # The wrapped Categorical owns the probs/logits parameters; this class
        # only adds the temperature on top of it.
        self._categorical = Categorical(probs, logits)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        # The event is the trailing (category) dimension of the parameters.
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return an instance whose wrapped Categorical is expanded to ``batch_shape``.

        The temperature is carried over as-is (it is not expanded).
        """
        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        # Initialise the Distribution base with validation switched off, then
        # copy the original validation flag onto the new instance.
        super(ExpRelaxedCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Forward to the wrapped Categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        """Parameter shape of the wrapped Categorical."""
        return self._categorical.param_shape

    @property
    def logits(self):
        """Logits of the wrapped Categorical."""
        return self._categorical.logits

    @property
    def probs(self):
        """Probabilities of the wrapped Categorical."""
        return self._categorical.probs

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample: Gumbel-perturbed logits, tempered and log-normalized.

        Args:
            sample_shape: extra leading sample dimensions.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape`` that is
            normalized so its ``logsumexp`` over the last dimension is zero.
        """
        shape = self._extended_shape(sample_shape)
        # Clamp the uniforms before taking the two logs below.
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        # Gumbel(0, 1) noise via the inverse CDF: -log(-log(u)).
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        return scores - scores.logsumexp(dim=-1, keepdim=True)

    def log_prob(self, value):
        """Return the log density of ``value`` (a log-simplex point).

        Computes ``lgamma(K) + (K - 1) * log(temperature)`` plus the sum over
        categories of the log-normalized ``logits - temperature * value``,
        where ``K`` is the number of categories.

        Args:
            value (Tensor): point(s) to score; the last dimension indexes categories.

        Returns:
            Tensor with the event dimension summed out.
        """
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        temperature = self.temperature
        if not isinstance(temperature, torch.Tensor):
            # A plain Python number is accepted by __init__ and rsample, but
            # full_like()/.log() below need a tensor. Promote it here so
            # log_prob works too; tensor temperatures are left untouched.
            temperature = torch.as_tensor(
                temperature, dtype=logits.dtype, device=logits.device
            )
        # lgamma(K) == log((K - 1)!); the second term equals +(K - 1) * log(temperature).
        log_scale = torch.full_like(
            temperature, float(K)
        ).lgamma() - temperature.log().mul(-(K - 1))
        score = logits - value.mul(temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale


class RelaxedOneHotCategorical(TransformedDistribution):
    r"""
    Creates a RelaxedOneHotCategorical distribution parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
    its samples are on simplex, and are reparametrizable.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> m.sample()
        tensor([ 0.1294,  0.2324,  0.3859,  0.2523])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        """Build the distribution as ``exp`` applied to an :class:`ExpRelaxedCategorical`."""
        base_dist = ExpRelaxedCategorical(
            temperature, probs, logits, validate_args=validate_args
        )
        # Exponentiating the log-simplex base samples maps them onto the simplex.
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return an instance expanded to ``batch_shape`` via the parent implementation."""
        new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        """Temperature of the underlying base distribution."""
        return self.base_dist.temperature

    @property
    def logits(self):
        """Logits of the underlying base distribution."""
        return self.base_dist.logits

    @property
    def probs(self):
        """Probabilities of the underlying base distribution."""
        return self.base_dist.probs