File size: 1,252 Bytes
8437114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn


class LSTMCellWithZoneOut(nn.Module):
    """
    Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations
    https://arxiv.org/abs/1606.01305

    Wraps ``nn.LSTMCell``. In training mode, each element of the returned
    state keeps its previous value with probability ``prob`` and otherwise
    takes the freshly computed value. In eval mode the deterministic
    expectation ``prob * previous + (1 - prob) * new`` is returned instead.

    Args:
        prob: zoneout probability, i.e. the chance of keeping each previous
            state element. Must lie in ``[0.0, 1.0]``.
        input_size: forwarded to ``nn.LSTMCell``.
        hidden_size: forwarded to ``nn.LSTMCell``.
        bias: forwarded to ``nn.LSTMCell``.

    Raises:
        ValueError: if ``prob`` is outside ``[0.0, 1.0]`` or is NaN.
    """

    def __init__(self, prob: float, input_size: int, hidden_size: int,
                 bias: bool = True):
        super().__init__()
        # Validate before allocating the LSTM weights. The chained comparison
        # also rejects NaN, since every comparison with NaN is False; the
        # previous `prob > 1.0 or prob < 0.0` check let NaN through.
        if not 0.0 <= prob <= 1.0:
            raise ValueError("zoneout probability must be in the range from "
                             "0.0 to 1.0.")
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
        self.prob = prob

    def zoneout(self, h, next_h, prob):
        """Blend the previous state ``h`` with the new state ``next_h``.

        Args:
            h: previous state. Either a tensor or a (possibly nested) tuple
                of tensors, such as the LSTM ``(hidden, cell)`` pair.
            next_h: freshly computed state with the same structure as ``h``.
            prob: probability of keeping each element of ``h``.

        Returns:
            A state with the same tuple structure as ``h``.
        """
        if isinstance(h, tuple):
            # Apply zoneout independently to each component of the state.
            return tuple(
                self.zoneout(prev, new, prob) for prev, new in zip(h, next_h)
            )

        if self.training:
            # Elementwise mask: 1 keeps the previous value, 0 takes the new one.
            mask = h.new_zeros(h.size()).bernoulli_(prob)
            return mask * h + (1 - mask) * next_h

        # At eval time, use the expected value of the stochastic update.
        return prob * h + (1 - prob) * next_h

    def forward(self, x, h=None):
        """Run one ``nn.LSTMCell`` step, then apply zoneout to its state.

        Args:
            x: input for this step, passed unchanged to ``nn.LSTMCell``.
            h: previous ``(hidden, cell)`` state, or ``None`` to start from
                zeros, mirroring ``nn.LSTMCell``'s own default.

        Returns:
            The new ``(hidden, cell)`` tuple after zoneout.
        """
        next_h = self.lstm_cell(x, h)
        if h is None:
            # nn.LSTMCell treats a missing state as zeros; build the same
            # zeros here so zoneout has a concrete previous state to blend.
            h = tuple(t.new_zeros(t.size()) for t in next_h)
        return self.zoneout(h, next_h, self.prob)