|
import torch |
|
from torch import nn |
|
|
|
|
|
def make_encoder(input_dim, enc_dec_dims):
    """Build a symmetric (encoder, decoder) pair of MLPs.

    The encoder maps ``input_dim`` down through ``enc_dec_dims[:-1]`` hidden
    widths to a bottleneck of ``enc_dec_dims[-1]``; the decoder mirrors those
    widths back up to ``input_dim``. Hidden transitions use SELU activations;
    the final layer of each stack is a bare Linear (no activation).

    Args:
        input_dim: width of the input (and of the decoder's reconstruction).
        enc_dec_dims: hidden widths ending with the bottleneck width.

    Returns:
        Tuple of ``(encoder, decoder)`` as ``nn.Sequential`` modules.
    """

    def _stack(dims):
        # Linear + SELU for every transition, then drop the trailing SELU so
        # the last layer of the stack has no activation.
        layers = []
        for width_in, width_out in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.SELU())
        layers.pop()
        return nn.Sequential(*layers)

    widths = list(enc_dec_dims)
    encoder = _stack([input_dim] + widths)
    decoder = _stack(list(reversed(widths)) + [input_dim])
    return encoder, decoder
|
|
|
|
|
class FsrFgModel(nn.Module):
    """Autoencoder-backed multi-task predictor over FG/MFG feature vectors.

    The input source is selected by ``method``:
      - ``'FG'``: functional-group features only.
      - ``'MFG'``: MFG features only.
      - ``'FGR'`` (and any other value): concatenation of FG and MFG.
      - ``'FGR_desc'``: same concatenated input, with ``num_features``
        appended to the latent code before the prediction head.

    ``forward`` returns both the task predictions and the decoder's
    reconstruction of the autoencoder input (for a reconstruction loss).
    """

    def __init__(self, fg_input_dim, mfg_input_dim, num_input_dim, enc_dec_dims, output_dims,
                 num_tasks, dropout, method):
        super().__init__()

        self.method = method
        # Autoencoder input width depends on which feature source is used.
        if method == 'FG':
            input_dim = fg_input_dim
        elif method == 'MFG':
            input_dim = mfg_input_dim
        else:
            # 'FGR' and 'FGR_desc' (or any other value) concatenate FG + MFG.
            input_dim = fg_input_dim + mfg_input_dim

        # 'FGR_desc' widens the head input by the numeric-descriptor width.
        latent_dim = enc_dec_dims[-1]
        if method != 'FGR_desc':
            fcn_input_dim = latent_dim
        else:
            fcn_input_dim = num_input_dim + latent_dim

        self.encoder, self.decoder = make_encoder(input_dim, enc_dec_dims)
        self.dropout = nn.Dropout(dropout)
        self.predict_out_dim = num_tasks
        # NOTE(review): registered but never applied in forward() — confirm
        # whether this batch norm was meant to normalize the head input.
        self.batch_norm = nn.BatchNorm1d(fcn_input_dim)

        # Prediction head: Linear -> SELU -> BatchNorm per hidden width,
        # then dropout and a final projection to the task outputs.
        head = []
        for width in output_dims:
            head += [nn.Linear(fcn_input_dim, width), nn.SELU(), nn.BatchNorm1d(width)]
            fcn_input_dim = width
        head += [self.dropout, nn.Linear(fcn_input_dim, num_tasks)]

        self.predictor = nn.Sequential(*head)

    def forward(self, fg=None, mfg=None, num_features=None):
        """Encode the selected features, decode a reconstruction, predict.

        Args:
            fg: FG feature batch (required unless method is 'MFG').
            mfg: MFG feature batch (required unless method is 'FG').
            num_features: numeric descriptors, used only for 'FGR_desc'.

        Returns:
            Tuple ``(output, v_d_hat)`` of task predictions and the decoder's
            reconstruction of the autoencoder input.
        """
        if self.method == 'FG':
            features = fg
        elif self.method == 'MFG':
            features = mfg
        else:
            # 'FGR', 'FGR_desc', and any other value share the joint input.
            features = torch.cat([fg, mfg], dim=1)

        z_d = self.encoder(features)
        # Reconstruction is taken from the latent code alone, before any
        # descriptor concatenation.
        v_d_hat = self.decoder(z_d)

        head_input = z_d
        if self.method == 'FGR_desc':
            head_input = torch.cat([z_d, num_features], dim=1)

        output = self.predictor(head_input)
        return output, v_d_hat
|
|