File size: 328 Bytes
3eb682b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

from torch import nn




def connector(connector_type='linear', **kwargs):
    """Build a list of connector layers of the requested type.

    Args:
        connector_type: Kind of connector to build. Only ``'linear'`` is
            currently supported.
        **kwargs: For ``'linear'``, must provide ``input_dim``,
            ``output_dim`` and ``num_layers``.

    Returns:
        For ``'linear'``: an ``nn.ModuleList`` holding ``num_layers``
        independent ``nn.Linear(input_dim, output_dim)`` layers (each with its
        own parameters).

    Raises:
        KeyError: If a required keyword argument for the chosen type is missing.
        NotImplementedError: If ``connector_type`` is not supported.
    """
    print("Build connector:", connector_type)
    if connector_type == 'linear':
        # The layers are built side by side, not chained; since each maps
        # input_dim -> output_dim, presumably the caller applies one per
        # layer/feature level (TODO confirm against callers).
        return nn.ModuleList(
            [nn.Linear(kwargs['input_dim'], kwargs['output_dim']) for _ in range(kwargs['num_layers'])]
        )
    # NOTE: the previous `raise NotImplemented` actually raised TypeError,
    # because `NotImplemented` is a constant, not an exception class.
    raise NotImplementedError(f"Unsupported connector type: {connector_type!r}")