File size: 805 Bytes
69ad385 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch.nn as nn
class Highway(nn.Module):
    """Single highway layer: a gated mix of a transformed input and the raw input.

    Computes ``y = H(x) * T(x) + x * (1 - T(x))`` where

    * ``H`` is a linear projection followed by ReLU (the candidate output), and
    * ``T`` is a linear projection followed by a sigmoid (the transform gate).

    The gate bias starts at -1, so a freshly built layer leans towards carrying
    the input through unchanged (``sigmoid(-1)`` is roughly 0.27).

    Because the carry term multiplies ``inputs`` element-wise with a tensor
    whose last dimension is ``out_size``, the last dimension of ``inputs``
    has to be broadcast-compatible with ``out_size`` (normally
    ``in_size == out_size``).

    References:
        Highway Networks, https://arxiv.org/abs/1505.00387
        (submitted 3 May 2015, v2 revised 3 Nov 2015).
        Discussion of Highway networks vs. ResNet:
        https://www.zhihu.com/question/279426970

    Args:
        in_size: number of input features of both linear projections.
        out_size: number of output features of both linear projections.
    """

    def __init__(self, in_size, out_size):
        super().__init__()

        # Candidate branch; created first so parameter order and RNG
        # consumption match H-then-T.
        self.H = nn.Linear(in_size, out_size)
        nn.init.zeros_(self.H.bias)

        # Gate branch; the negative bias biases the layer towards "carry".
        self.T = nn.Linear(in_size, out_size)
        nn.init.constant_(self.T.bias, -1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Return the gated combination of the transformed and the raw ``inputs``."""
        candidate = self.relu(self.H(inputs))
        gate = self.sigmoid(self.T(inputs))
        transformed = candidate * gate
        carried = inputs * (1.0 - gate)
        return transformed + carried
# Script entry point: intentionally a no-op; running this file directly does nothing.
if __name__ == '__main__':
    pass
|