"""Implementation of the hard Concrete distribution. Originally from: https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py """ import math import torch import torch.nn as nn class HardConcrete(nn.Module): """A HarcConcrete module. Use this module to create a mask of size N, which you can then use to perform L0 regularization. To obtain a mask, simply run a forward pass through the module with no input data. The mask is sampled in training mode, and fixed during evaluation mode, e.g.: >>> module = HardConcrete(n_in=100) >>> mask = module() >>> norm = module.l0_norm() """ def __init__( self, n_in: int, init_mean: float = 0.5, init_std: float = 0.01, temperature: float = 2/3, # from CoFi stretch: float = 0.1, eps: float = 1e-6 ) -> None: """Initialize the HardConcrete module. Parameters ---------- n_in : int The number of hard concrete variables in this mask. init_mean : float, optional Initial drop rate for hard concrete parameter, by default 0.5., init_std: float, optional Used to initialize the hard concrete parameters, by default 0.01. temperature : float, optional Temperature used to control the sharpness of the distribution, by default 1.0 stretch : float, optional Stretch the sampled value from [0, 1] to the interval [-stretch, 1 + stretch], by default 0.1. """ super().__init__() self.n_in = n_in self.limit_l = -stretch self.limit_r = 1.0 + stretch self.log_alpha = nn.Parameter(torch.zeros(n_in)) self.beta = temperature self.init_mean = init_mean self.init_std = init_std self.bias = -self.beta * math.log(-self.limit_l / self.limit_r) self.eps = eps self.compiled_mask = None self.reset_parameters() def reset_parameters(self): """Reset the parameters of this module.""" self.compiled_mask = None mean = math.log(1 - self.init_mean) - math.log(self.init_mean) self.log_alpha.data.normal_(mean, self.init_std) def l0_norm(self) -> torch.Tensor: """Compute the expected L0 norm of this mask. Returns ------- torch.Tensor The expected L0 norm. """ return (self.log_alpha + self.bias).sigmoid().sum() def forward(self) -> torch.Tensor: """Sample a hard concrete mask. Returns ------- torch.Tensor The sampled binary mask """ if self.training: # Reset the compiled mask self.compiled_mask = None # Sample mask dynamically u = self.log_alpha.new(self.n_in).uniform_(self.eps, 1 - self.eps) s = torch.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta) s = s * (self.limit_r - self.limit_l) + self.limit_l mask = s.clamp(min=0., max=1.) else: # Compile new mask if not cached if self.compiled_mask is None: # Get expected sparsity expected_num_zeros = self.n_in - self.l0_norm().item() num_zeros = round(expected_num_zeros) # Approximate expected value of each mask variable z; # We use an empirically validated magic number 0.8 soft_mask = torch.sigmoid(self.log_alpha / self.beta * 0.8) # Prune small values to set to 0 _, indices = torch.topk(soft_mask, k=num_zeros, largest=False) soft_mask[indices] = 0. self.compiled_mask = soft_mask mask = self.compiled_mask return mask def extra_repr(self) -> str: return str(self.n_in) def __repr__(self) -> str: return "{}({})".format(self.__class__.__name__, self.extra_repr())