Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
import torch.nn as nn | |
class LayerSelect(nn.Module):
    """Compute samples (from a Gumbel-Sigmoid distribution) which are used as
    either (soft) weighting or (hard) selection of residual connections.

    https://arxiv.org/abs/2009.13102

    Usage: call :meth:`sample` with a logit index to draw one sample per
    layer, then call the module with a layer index to read that layer's
    sample.

    Args:
        num_layers: Number of layers; each row of logits holds one logit per
            layer.
        num_logits: Number of independent rows of logits.
        soft_select: If True, return the soft (sigmoid) samples; otherwise
            return hard 0/1 samples with a straight-through gradient.
        sampling_tau: Temperature the noisy logits are divided by before the
            sigmoid.
    """

    def __init__(self, num_layers, num_logits, soft_select=False, sampling_tau=5.):
        super(LayerSelect, self).__init__()
        # Zero-initialise the logits. ``torch.Tensor(rows, cols)`` would hand
        # back uninitialised memory, so the parameter could start out as
        # arbitrary values (including NaN/inf) and differ from run to run.
        # A zero logit corresponds to a selection probability of 0.5.
        self.layer_logits = torch.nn.Parameter(
            torch.zeros(num_logits, num_layers),
            requires_grad=True,
        )
        self.hard_select = not soft_select
        self.tau = sampling_tau
        # When True, sampling uses a detached copy of the logits so that no
        # gradient reaches ``layer_logits``.
        self.detach_grad = False
        # Most recent samples drawn for each logit index.
        self.layer_samples = [None] * num_logits
        # Samples from the latest ``sample`` call; read by ``forward``.
        self.samples = None

    def sample(self, logit_idx):
        """To leverage the efficiency of distributed training, samples for all
        layers are computed at once for each logit_idx. Logits are parameters
        learnt independent of each other.

        The result is stored in ``self.samples`` and in
        ``self.layer_samples[logit_idx]``; nothing is returned.

        Args:
            logit_idx: The index of logit parameters used for sampling.
        """
        assert logit_idx is not None, "logit_idx must not be None"
        logits = self.layer_logits[logit_idx, :]
        if self.detach_grad:
            logits = logits.detach()
        self.samples = self._gumbel_sigmoid(
            logits,
            dim=-1,
            tau=self.tau,
            hard=self.hard_select,
        )
        self.layer_samples[logit_idx] = self.samples

    def forward(self, i):
        """Return entry ``i`` of the samples drawn by the latest ``sample`` call.

        Raises:
            RuntimeError: If ``sample`` has not been called yet.
        """
        if self.samples is None:
            raise RuntimeError(
                "LayerSelect.sample(logit_idx) must be called before forward()"
            )
        return self.samples[i]

    def _gumbel_sigmoid(
        self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
    ):
        """Draw Gumbel-Sigmoid samples for ``logits``.

        Args:
            logits: Tensor of logits; the result has the same shape.
            tau: Temperature dividing the noisy logits before the sigmoid.
            hard: If True, return values thresholded to 0/1 in the forward
                pass while gradients follow the soft sample.
            eps: Unused; kept so the signature stays unchanged.
            dim: Unused; kept so the signature stays unchanged.
            threshold: Soft samples strictly above this value become 1 when
                ``hard`` is True.

        Returns:
            Tensor of soft samples, or of straight-through hard samples.
        """
        # ~Gumbel(0,1)
        noise_a = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        noise_b = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        # Difference of two gumbels because we apply a sigmoid
        noisy_logits = (logits + noise_a - noise_b) / tau
        y_soft = noisy_logits.sigmoid()
        if not hard:
            # Reparametrization trick.
            return y_soft
        # Straight through: the forward value is y_hard, the gradient is that
        # of y_soft.
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).masked_fill(y_soft > threshold, 1.0)
        return y_hard - y_soft.detach() + y_soft