|
r""" |
|
The ``distributions`` package contains parameterizable probability distributions |
|
and sampling functions. This allows the construction of stochastic computation |
|
graphs and stochastic gradient estimators for optimization. This package |
|
generally follows the design of the `TensorFlow Distributions`_ package. |
|
|
|
.. _`TensorFlow Distributions`: |
|
https://arxiv.org/abs/1711.10604 |
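
A distribution object is constructed from tensor-valued parameters and can then
be queried for samples and log-densities. A minimal usage sketch (the parameter
and sample-shape values here are arbitrary)::

    import torch
    from torch.distributions import Normal

    # Standard normal distribution; the parameters may also be batched tensors.
    normal = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
    samples = normal.sample((5,))          # draw 5 samples, shape (5,)
    log_probs = normal.log_prob(samples)   # log-density of each sample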
|
|
|
It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through. These are the score function estimator/likelihood ratio
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
seen as the basis for policy gradient methods in reinforcement learning, and the
pathwise derivative estimator is commonly seen in the reparameterization trick
in variational autoencoders. Whilst the score function only requires the value
of samples :math:`f(x)`, the pathwise derivative requires the derivative
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_.

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
    https://arxiv.org/abs/1506.05254
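
The practical difference between the two estimators shows up in whether
gradients flow through the drawn sample itself. A small sketch of this
distinction, using a ``Normal`` distribution with a learnable location
parameter::

    import torch
    from torch.distributions import Normal

    loc = torch.tensor(0.0, requires_grad=True)
    m = Normal(loc, torch.tensor(1.0))

    x = m.sample()     # detached from the graph: x.requires_grad is False
    y = m.rsample()    # reparameterized: gradients flow back to loc
    y.backward()       # populates loc.grad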
|
|
|
Score function
^^^^^^^^^^^^^^
|
When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.
|
In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()
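
The names ``policy_network``, ``env``, and ``state`` above are left abstract. A
self-contained sketch of a single REINFORCE update, with a toy linear policy, a
dummy observation, and a constant reward standing in for a real network and
environment (all of these are illustrative placeholders only)::

    import torch
    from torch import nn
    from torch.distributions import Categorical

    policy_network = nn.Linear(4, 2)     # toy policy: 4 features -> 2 actions
    optimizer = torch.optim.SGD(policy_network.parameters(), lr=1e-2)

    state = torch.randn(4)               # dummy observation
    probs = torch.softmax(policy_network(state), dim=-1)
    m = Categorical(probs)
    action = m.sample()

    reward = torch.tensor(1.0)           # dummy reward from the environment

    optimizer.zero_grad()
    loss = -m.log_prob(action) * reward  # surrogate (score function) loss
    loss.backward()
    optimizer.step()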
|
|
|
Pathwise derivative
^^^^^^^^^^^^^^^^^^^
|
The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()
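
Again treating ``policy_network`` and the environment as placeholders, a
self-contained sketch of one pathwise-derivative update, where a simple
differentiable function stands in for the environment's reward so the gradient
can flow back through the reparameterized action::

    import torch
    from torch import nn
    from torch.distributions import Normal

    policy_network = nn.Linear(4, 2)     # toy policy: outputs (mean, log_std)
    optimizer = torch.optim.SGD(policy_network.parameters(), lr=1e-2)

    state = torch.randn(4)               # dummy observation
    mean, log_std = policy_network(state)
    m = Normal(mean, log_std.exp())

    action = m.rsample()                 # differentiable sample
    reward = -(action - 1.0) ** 2        # differentiable stand-in for env reward

    optimizer.zero_grad()
    loss = -reward
    loss.backward()                      # gradients reach the policy via rsample
    optimizer.step()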
|
""" |
|
|
|
from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .kl import kl_divergence, register_kl, _add_kl_info
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import *
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart
from . import transforms
|
_add_kl_info()
del _add_kl_info
|
__all__ = [
    'Bernoulli',
    'Beta',
    'Binomial',
    'Categorical',
    'Cauchy',
    'Chi2',
    'ContinuousBernoulli',
    'Dirichlet',
    'Distribution',
    'Exponential',
    'ExponentialFamily',
    'FisherSnedecor',
    'Gamma',
    'Geometric',
    'Gumbel',
    'HalfCauchy',
    'HalfNormal',
    'Independent',
    'Kumaraswamy',
    'LKJCholesky',
    'Laplace',
    'LogNormal',
    'LogisticNormal',
    'LowRankMultivariateNormal',
    'MixtureSameFamily',
    'Multinomial',
    'MultivariateNormal',
    'NegativeBinomial',
    'Normal',
    'OneHotCategorical',
    'OneHotCategoricalStraightThrough',
    'Pareto',
    'RelaxedBernoulli',
    'RelaxedOneHotCategorical',
    'StudentT',
    'Poisson',
    'Uniform',
    'VonMises',
    'Weibull',
    'Wishart',
    'TransformedDistribution',
    'biject_to',
    'kl_divergence',
    'register_kl',
    'transform_to',
]
|
__all__.extend(transforms.__all__) |
|
|