File size: 805 Bytes
69ad385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch.nn as nn


class Highway(nn.Module):
    """Single fully-connected highway layer.

    Computes ``y = H(x) * T(x) + x * (1 - T(x))`` where

    * ``H(x) = relu(W_H x + b_H)`` is the transform path, and
    * ``T(x) = sigmoid(W_T x + b_T)`` is the transform gate.

    Both linear maps act on the last dimension of ``x``. The carry term
    ``x * (1 - T(x))`` multiplies the raw input element-wise, so the output
    feature size must match the input feature size. For that reason
    ``out_size`` defaults to ``in_size``. Any other value only works where
    PyTorch broadcasting happens to apply.

    References:
        "Highway Networks", Srivastava, Greff & Schmidhuber,
        https://arxiv.org/abs/1505.00387
        (submitted 3 May 2015 (v1), last revised 3 Nov 2015 (v2)).

        Discussion (in Chinese) comparing Highway networks and ResNet:
        https://www.zhihu.com/question/279426970

    Args:
        in_size: Size of the last dimension of the input.
        out_size: Output size of the ``H`` and ``T`` linear layers.
            Defaults to ``in_size``, which is what the carry term requires.
        gate_bias: Initial value of every element of the gate bias ``b_T``.
            Defaults to ``-1.0``. A negative value makes ``sigmoid`` start
            below 0.5, so the layer initially leans toward carrying the
            input through unchanged.
    """

    def __init__(self, in_size, out_size=None, gate_bias=-1.0):
        super().__init__()
        if out_size is None:
            out_size = in_size
        self.H = nn.Linear(in_size, out_size)
        # nn.init.* mutate parameters under no_grad; preferred over `.data`.
        nn.init.zeros_(self.H.bias)
        self.T = nn.Linear(in_size, out_size)
        nn.init.constant_(self.T.bias, gate_bias)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Apply the highway transform to ``inputs``.

        Args:
            inputs: Tensor whose last dimension has size ``in_size``.

        Returns:
            ``H * T + inputs * (1 - T)``, a gated mix of the transformed
            features and the unmodified input.
        """
        H = self.relu(self.H(inputs))
        T = self.sigmoid(self.T(inputs))
        return H * T + inputs * (1.0 - T)


if __name__ == '__main__':
    # Intentionally empty: this module only defines the Highway layer.
    ...