# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
LayerDrop as described in https://arxiv.org/abs/1909.11556.
"""

from typing import Iterable, Iterator, Optional

import torch
import torch.nn as nn


class LayerDropModuleList(nn.ModuleList):
    """
    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.

    We refresh the choice of which layers to drop every time we iterate
    over the LayerDropModuleList instance. During evaluation we always
    iterate over all layers.

    Usage::

        layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
        for layer in layers:  # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers:  # this might iterate over all layers
            x = layer(x)
        for layer in layers:  # this might not iterate over any layers
            x = layer(x)

    Args:
        p (float): probability of dropping out each layer
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(
        self, p: float, modules: Optional[Iterable[nn.Module]] = None
    ) -> None:
        super().__init__(modules)
        self.p = p

    def __iter__(self) -> Iterator[nn.Module]:
        """Yield the contained layers, randomly skipping some while training.

        In evaluation mode every layer is yielded, in order, and no random
        numbers are drawn. In training mode one uniform sample is drawn per
        layer and the layer is yielded only when its sample exceeds ``p``.
        The mode is read once, when iteration starts.
        """
        if not self.training:
            # Nothing is ever dropped in eval mode, so skip sampling entirely;
            # this also leaves the global RNG state untouched by evaluation.
            yield from super().__iter__()
            return

        # One vectorized comparison gives a plain list of bools instead of
        # indexing the tensor (and creating a 0-dim tensor) for each layer.
        keep_mask = (torch.empty(len(self)).uniform_() > self.p).tolist()
        for keep, module in zip(keep_mask, super().__iter__()):
            if keep:
                yield module