# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
LayerDrop as described in https://arxiv.org/abs/1909.11556.
"""

import torch
import torch.nn as nn


class LayerDropModuleList(nn.ModuleList):
    """
    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.

    We refresh the choice of which layers to drop every time we iterate
    over the LayerDropModuleList instance. During evaluation (when
    ``self.training`` is False) we always iterate over all layers.

    Usage (in training mode)::

        layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
        for layer in layers:  # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers:  # this might iterate over all layers
            x = layer(x)
        for layer in layers:  # this might not iterate over any layers
            x = layer(x)

    Args:
        p (float): probability of dropping each layer during training
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(self, p, modules=None):
        super().__init__(modules)
        self.p = p

    def __iter__(self):
        # Draw one uniform sample in [0, 1) per layer on every iteration, so
        # each ``for`` loop over this list sees a fresh set of dropped layers.
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, m in enumerate(super().__iter__()):
            # Keep layer i in eval mode, or when its sample exceeds p, so each
            # layer is kept with probability 1 - p during training.
            if not self.training or (dropout_probs[i] > self.p):
                yield m