from torch import nn
import torch.nn.functional as F


class LinLayers(nn.Module):
    """Stack of fully connected blocks: Linear -> (BatchNorm) -> (activation) -> (Dropout)."""

    def __init__(self, args):
        super(LinLayers, self).__init__()
        in_dim = args.in_dim                # e.g. 16
        hidden_layers = args.hidden_layers  # e.g. [512, 256, 128, 2]
        activations = args.activations      # e.g. [nn.LeakyReLU(0.2), nn.LeakyReLU(0.2), nn.LeakyReLU(0.2)]
        batchnorms = args.bns               # e.g. [True, True, True]
        dropouts = args.dropouts            # e.g. [None, 0.2, 0.2]

        layers = nn.ModuleList()
        if hidden_layers:
            # One configuration entry per hidden layer is required.
            assert len(hidden_layers) == len(activations) == len(batchnorms) == len(dropouts), 'dimensions mismatch!'
            old_dim = in_dim
            for idx, layer_dim in enumerate(hidden_layers):
                sub_layers = [nn.Linear(old_dim, layer_dim)]
                if batchnorms[idx]:
                    sub_layers.append(nn.BatchNorm1d(num_features=layer_dim))
                if activations[idx]:
                    sub_layers.append(activations[idx])
                if dropouts[idx]:
                    sub_layers.append(nn.Dropout(p=dropouts[idx]))
                old_dim = layer_dim
                layers.append(nn.Sequential(*sub_layers))
        else:
            # Single-layer case: activations/bns/dropouts are single values here, not lists.
            out_dim = args.out_dim  # assumed to be provided on args alongside in_dim
            layers.append(nn.Linear(in_dim, out_dim))
            if batchnorms:
                layers.append(nn.BatchNorm1d(num_features=out_dim))
            if activations:
                layers.append(activations)
            if dropouts:
                layers.append(nn.Dropout(p=dropouts))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.layers(x)
        return x
'''
def _check_dimensions(self):
    if isinstance(self.hidden_layers, list):
        assert len(self.hidden_layers) == len(self.activations)
        assert len(self.hidden_layers) == len(self.batchnorms)
        assert len(self.hidden_layers) == len(self.dropouts)
'''
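

# Minimal usage sketch (not part of the original file): builds the module with the example
# hyperparameters noted in the comments above, padded so all four lists have one entry per
# hidden layer as the assert requires. `SimpleNamespace` stands in for whatever args/config
# object the project actually passes to LinLayers.
if __name__ == '__main__':
    import torch
    from types import SimpleNamespace

    args = SimpleNamespace(
        in_dim=16,
        hidden_layers=[512, 256, 128, 2],
        activations=[nn.LeakyReLU(0.2), nn.LeakyReLU(0.2), nn.LeakyReLU(0.2), None],
        bns=[True, True, True, False],
        dropouts=[None, 0.2, 0.2, None],  # no activation/bn/dropout on the output layer
    )
    model = LinLayers(args)
    out = model(torch.randn(8, 16))  # batch of 8 feature vectors of size in_dim
    print(out.shape)                 # torch.Size([8, 2])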