|
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
class Highway(nn.Module):
    """Single fully-connected highway layer.

    Computes ``y = H(x) * T(x) + x * (1 - T(x))``, where
    ``H(x) = relu(W_H x + b_H)`` is the transform branch and
    ``T(x) = sigmoid(W_T x + b_T)`` is the transform gate; ``1 - T(x)``
    acts as the carry gate that passes the input through unchanged.

    References:
        Highway Networks: https://arxiv.org/abs/1505.00387
        (submitted 3 May 2015 (v1), last revised 3 Nov 2015 (v2)).

        Discussion comparing Highway networks and ResNet:
        https://www.zhihu.com/question/279426970
    """

    def __init__(self, in_size, out_size, *, gate_bias=-1.0):
        """Build the transform branch and the transform gate.

        Args:
            in_size: size of the last dimension of the input.
            out_size: size of the last dimension of the output. Because the
                carry term ``inputs * (1 - T)`` is added element-wise to the
                transform term, this must equal ``in_size`` (up to
                broadcasting) for ``forward`` to succeed.
            gate_bias: initial value of the transform-gate bias ``b_T``.
                A negative value puts ``T(x)`` below 0.5 at initialization,
                so the layer initially favours carrying its input through.
                Defaults to -1.0.
        """
        super().__init__()
        self.H = nn.Linear(in_size, out_size)
        # nn.init.* performs the fill under no_grad; preferred over .data mutation.
        nn.init.zeros_(self.H.bias)
        self.T = nn.Linear(in_size, out_size)
        nn.init.constant_(self.T.bias, gate_bias)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Apply the highway transform.

        Args:
            inputs: tensor whose last dimension is ``in_size``.

        Returns:
            ``H * T + inputs * (1 - T)``, with ``H = relu(self.H(inputs))``
            and ``T = sigmoid(self.T(inputs))``.
        """
        transformed = self.relu(self.H(inputs))
        gate = self.sigmoid(self.T(inputs))
        # Convex mix per element: gate weights the transform, (1 - gate) the carry.
        return transformed * gate + inputs * (1.0 - gate)
|
|
|
|
|
if __name__ == "__main__":
    # This module only provides the Highway layer; running it directly
    # intentionally does nothing.
    ...
|
|