import torch


class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout. Samples one dropout
    mask per sequence and reuses it at every time step, so the same feature
    dimensions of the embedding are dropped throughout the sequence.

    :param dropout_rate: fraction of input units to drop, between 0 and 1.
    :param batch_first: if True, the input is expected as (batch, seq, features);
        otherwise as (seq, batch, features).
    :param inplace: unused in the computation; only reported by extra_repr.
    """

    def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.batch_first = batch_first
        self.inplace = inplace

    def forward(self, x):
        # dropout is a no-op at evaluation time or with a zero rate
        if not self.training or not self.dropout_rate:
            return x

        if not self.batch_first:
            # input is (seq, batch, features): a (1, batch, features) mask
            # is broadcast over the sequence dimension
            m = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate)
        else:
            # input is (batch, seq, features): a (batch, 1, features) mask
            # is broadcast over the sequence dimension
            m = x.new_empty(x.size(0), 1, x.size(2)).bernoulli_(1 - self.dropout_rate)

        # rescale by the keep probability so expected activations are unchanged
        mask = m / (1 - self.dropout_rate)
        return mask.expand_as(x) * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)


class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words
    (or characters) in embedding space: a dropped position has its whole
    embedding vector zeroed. Unlike LockedDropout, the output is not rescaled.

    :param dropout_rate: fraction of positions to drop, between 0 and 1.
    :param inplace: unused in the computation; only reported by extra_repr.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.inplace = inplace

    def forward(self, x):
        # dropout is a no-op at evaluation time or with a zero rate
        if not self.training or not self.dropout_rate:
            return x

        # shape (dim0, dim1, 1): each position keeps or loses its whole
        # embedding vector; the mask broadcasts over the feature dimension.
        # Note there is no rescaling by the keep probability here.
        m = x.new_empty(x.size(0), x.size(1), 1).bernoulli_(1 - self.dropout_rate)
        return m * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)