File size: 2,720 Bytes
b91a08b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#!/usr/bin/env python
# coding=utf-8
import torch

class ffnMLP(torch.nn.Module):
    """Two-layer feed-forward network with a sigmoid output.

    Computes ``sigmoid(fc2(relu(fc1(x))))``. Both linear layers act on the
    last dimension of ``x``, so an input of shape ``(*, input_dim)`` yields an
    output of shape ``(*, output_dim)`` whose values are squashed into the
    unit interval by the sigmoid.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output.
        bias: Whether the linear layers learn an additive bias. Defaults to
            ``False``, which reproduces the original bias-free layers (and so
            the original ``state_dict`` keys ``fc1.weight`` / ``fc2.weight``).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, bias=False):
        super().__init__()
        # Fully qualified so the class does not depend on a separate ``nn``
        # import; the module itself only binds ``torch``.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=bias)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x):
        """Map ``x`` of shape ``(*, input_dim)`` to sigmoid scores of shape ``(*, output_dim)``."""
        hidden = self.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(hidden))