import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Point-wise MLP built from 1x1 Conv1d layers, with a skip connection
    that re-injects the input features at the middle layer."""

    def __init__(self, dims, last_op=None):
        super().__init__()
        self.dims = dims
        # Re-inject the original input at the middle layer.
        self.skip_layer = [len(dims) // 2]
        self.last_op = last_op
        # nn.ModuleList registers the sub-modules, replacing the original
        # plain-list + add_module bookkeeping.
        self.layers = nn.ModuleList()
        for l in range(len(dims) - 1):
            if l in self.skip_layer:
                # The skip layer also receives the input, so its
                # in-channels grow by dims[0].
                self.layers.append(nn.Conv1d(dims[l] + dims[0], dims[l + 1], 1))
            else:
                self.layers.append(nn.Conv1d(dims[l], dims[l + 1], 1))
    def forward(self, latent_code, return_all=False):
        y = latent_code
        tmpy = latent_code  # original input, kept for the skip connection
        y_list = []
        for l, f in enumerate(self.layers):
            if l in self.skip_layer:
                y = f(torch.cat([y, tmpy], 1))
            else:
                y = f(y)
            if l != len(self.layers) - 1:
                y = F.leaky_relu(y)
            elif self.last_op is not None:
                # Final activation (e.g. sigmoid) on the last layer only.
                y = self.last_op(y)
            y_list.append(y)
        if return_all:
            return y_list
        return y
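

# A minimal usage sketch, not from the original source: the 1x1 Conv1d
# layers act point-wise, so the MLP expects features shaped
# (batch, channels, num_points). The dims, tensor shapes, and sigmoid
# head below are illustrative assumptions.
if __name__ == "__main__":
    mlp = MLP(dims=[256, 512, 256, 128, 1], last_op=torch.sigmoid)
    latent = torch.randn(2, 256, 1024)    # (B, C=dims[0], N) point features
    out = mlp(latent)                     # (2, 1, 1024), squashed to [0, 1]
    outs = mlp(latent, return_all=True)   # per-layer outputs, len(dims) - 1 entries
    print(out.shape, len(outs))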