Spaces:
Running
Running
"""Implementation of the hard Concrete distribution. | |
Originally from: | |
https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py | |
""" | |
import math | |
import torch | |
import torch.nn as nn | |
class HardConcrete(nn.Module):
    """A HardConcrete module.

    Use this module to create a mask of size N, which you can
    then use to perform L0 regularization.

    To obtain a mask, simply run a forward pass through the module
    with no input data. The mask is sampled in training mode, and
    fixed (computed once and cached) during evaluation mode, e.g.:

    >>> module = HardConcrete(n_in=100)
    >>> mask = module()
    >>> norm = module.l0_norm()
    """

    def __init__(
        self,
        n_in: int,
        init_mean: float = 0.5,
        init_std: float = 0.01,
        temperature: float = 2/3,  # from CoFi
        stretch: float = 0.1,
        eps: float = 1e-6
    ) -> None:
        """Initialize the HardConcrete module.

        Parameters
        ----------
        n_in : int
            The number of hard concrete variables in this mask.
        init_mean : float, optional
            Initial drop rate for the hard concrete parameters; must lie
            strictly inside (0, 1). By default 0.5.
        init_std : float, optional
            Standard deviation of the normal distribution used to
            initialize the hard concrete parameters, by default 0.01.
        temperature : float, optional
            Temperature used to control the sharpness of the
            distribution, by default 2/3.
        stretch : float, optional
            Stretch the sampled value from [0, 1] to the interval
            [-stretch, 1 + stretch]; must be positive. By default 0.1.
        eps : float, optional
            Training-mode uniform noise is drawn from [eps, 1 - eps] so
            the logit of the noise stays finite, by default 1e-6.

        Raises
        ------
        ValueError
            If ``init_mean`` is not strictly between 0 and 1, or
            ``stretch`` is not positive.
        """
        super().__init__()
        # Both conditions would otherwise surface later as an opaque
        # "math domain error" from math.log.
        if not 0.0 < init_mean < 1.0:
            raise ValueError(f"init_mean must be in (0, 1), got {init_mean}")
        if stretch <= 0.0:
            raise ValueError(f"stretch must be positive, got {stretch}")
        self.n_in = n_in
        self.limit_l = -stretch
        self.limit_r = 1.0 + stretch
        self.log_alpha = nn.Parameter(torch.zeros(n_in))
        self.beta = temperature
        self.init_mean = init_mean
        self.init_std = init_std
        # Shift used by l0_norm: P(z > 0) = sigmoid(log_alpha + bias).
        self.bias = -self.beta * math.log(-self.limit_l / self.limit_r)
        self.eps = eps
        # Cached evaluation-mode mask; None means "recompute on next use".
        self.compiled_mask = None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters of this module and drop any cached mask."""
        self.compiled_mask = None
        # Center log_alpha at logit(1 - init_mean).
        mean = math.log(1 - self.init_mean) - math.log(self.init_mean)
        with torch.no_grad():
            self.log_alpha.normal_(mean, self.init_std)

    def l0_norm(self) -> torch.Tensor:
        """Compute the expected L0 norm of this mask.

        Returns
        -------
        torch.Tensor
            Scalar tensor: the sum over all variables of
            ``sigmoid(log_alpha + bias)``, i.e. the expected number of
            non-zero mask entries. Differentiable w.r.t. ``log_alpha``.
        """
        return (self.log_alpha + self.bias).sigmoid().sum()

    def forward(self) -> torch.Tensor:
        """Sample a hard concrete mask.

        In training mode a fresh stochastic mask is drawn on every call;
        it depends differentiably on ``log_alpha`` (except where clamped).
        In evaluation mode a deterministic mask is computed once without
        autograd tracking, cached, and returned on subsequent calls.

        Returns
        -------
        torch.Tensor
            A mask of shape ``(n_in,)`` with values in [0, 1].
        """
        if self.training:
            # Parameters are being trained, so any cached eval mask is stale.
            self.compiled_mask = None
            # Reparameterized sample: logistic noise from u ~ U(eps, 1 - eps),
            # stretched to (limit_l, limit_r), then clamped back to [0, 1].
            u = torch.empty_like(self.log_alpha).uniform_(self.eps, 1 - self.eps)
            s = torch.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta)
            s = s * (self.limit_r - self.limit_l) + self.limit_l
            return s.clamp(min=0., max=1.)

        # Compile new mask if not cached
        if self.compiled_mask is None:
            self.compiled_mask = self._compile_mask()
        return self.compiled_mask

    def _compile_mask(self) -> torch.Tensor:
        """Build the deterministic evaluation-mode mask (no autograd graph)."""
        # No grad: the in-place pruning below would break backward through
        # sigmoid anyway, and the cached tensor should not pin a graph.
        with torch.no_grad():
            # Get expected sparsity
            expected_num_zeros = self.n_in - self.l0_norm().item()
            # Clamp so float round-off can never give topk an invalid k.
            num_zeros = min(max(round(expected_num_zeros), 0), self.n_in)
            # Approximate expected value of each mask variable z;
            # we use an empirically validated magic number 0.8
            soft_mask = torch.sigmoid(self.log_alpha / self.beta * 0.8)
            # Prune the smallest values to 0
            _, indices = torch.topk(soft_mask, k=num_zeros, largest=False)
            soft_mask[indices] = 0.
        return soft_mask

    def train(self, mode: bool = True) -> "HardConcrete":
        """Set the training mode and drop any cached evaluation mask."""
        # log_alpha may have changed while in the other mode, so the
        # eval mask must be rebuilt on the next evaluation forward.
        self.compiled_mask = None
        return super().train(mode)

    def _load_from_state_dict(self, *args, **kwargs) -> None:
        # Loading new log_alpha values invalidates the cached eval mask.
        self.compiled_mask = None
        super()._load_from_state_dict(*args, **kwargs)

    def extra_repr(self) -> str:
        return str(self.n_in)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.extra_repr()})"