File size: 2,720 Bytes
b91a08b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
#!/usr/bin/env python
# coding=utf-8
import torch
class ffnMLP(torch.nn.Module):
    """Two-layer feed-forward network with a sigmoid output.

    The forward pass computes ``sigmoid(fc2(relu(fc1(x))))``. Both linear
    layers are created with ``bias=False``, so the only learnable parameters
    are the two weight matrices.

    ``torch.nn`` is referenced through the ``torch`` package because this
    file imports only ``torch``; a bare ``nn`` name is never bound here and
    would raise ``NameError`` when the class statement is executed.

    Args:
        input_dim: Number of input features consumed by ``fc1``.
        hidden_dim: Number of features between ``fc1`` and ``fc2``.
        output_dim: Number of features produced by ``fc2``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        # Attribute names are kept as-is so state_dict keys
        # ("fc1.weight", "fc2.weight") stay compatible with saved checkpoints.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=False)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, ReLU, ``fc2`` and an element-wise sigmoid to ``x``.

        Args:
            x: Input tensor; ``torch.nn.Linear`` acts on the last dimension,
                which must therefore have size ``input_dim``.

        Returns:
            A tensor whose last dimension has size ``output_dim``, with the
            sigmoid applied element-wise to the output of ``fc2``.
        """
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.sigmoid()
        return x
|