Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
import torch | |
import torch.nn as nn | |
import pytorch_lightning as pl | |
class MLP(pl.LightningModule):
    """Stack of kernel-size-1 ``nn.Conv1d`` layers applied along the last axis.

    Every layer except the last is followed by an optional normalization and a
    LeakyReLU; the last layer's output is passed through ``last_op`` if given.

    Args:
        filter_channels: sequence of channel counts ``[C_in, C_1, ..., C_out]``;
            one conv layer is built for each consecutive pair.
        name: optional label; stored on ``self.name``, not used by this class.
        res_layers: indices of layers whose input is the previous activation
            concatenated (along dim 1) with the original input ``feature``.
            Defaults to no such layers.
        norm: ``'group'`` (``nn.GroupNorm`` with 32 groups, so each hidden
            channel count must be divisible by 32), ``'batch'``,
            ``'instance'``, or ``'weight'`` (weight normalization on the hidden
            convs). Any other value, including ``None``, applies no
            normalization.
        last_op: optional callable applied to the final layer's output.
    """

    def __init__(self,
                 filter_channels,
                 name=None,
                 res_layers=None,
                 norm='group',
                 last_op=None):
        super().__init__()

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        # Fresh list per instance: avoids the shared mutable-default pitfall and
        # decouples the module from later changes to the caller's sequence.
        self.res_layers = [] if res_layers is None else list(res_layers)
        self.norm = norm
        self.last_op = last_op
        self.name = name
        self.activate = nn.LeakyReLU(inplace=True)

        num_layers = len(filter_channels) - 1
        for idx in range(num_layers):
            in_ch = filter_channels[idx]
            if idx in self.res_layers:
                # Residual layers also receive the original input channels.
                in_ch += filter_channels[0]
            out_ch = filter_channels[idx + 1]
            self.filters.append(nn.Conv1d(in_ch, out_ch, 1))

            # The final layer is never normalized.
            if idx == num_layers - 1:
                continue
            if norm == 'group':
                self.norms.append(nn.GroupNorm(32, out_ch))
            elif norm == 'batch':
                self.norms.append(nn.BatchNorm1d(out_ch))
            elif norm == 'instance':
                self.norms.append(nn.InstanceNorm1d(out_ch))
            elif norm == 'weight':
                self.filters[idx] = nn.utils.weight_norm(self.filters[idx],
                                                         name='weight')

    def forward(self, feature):
        """Run the MLP over every position of the last axis.

        Args:
            feature: tensor of shape ``[B, C_in, N]``, where ``C_in`` should
                equal ``filter_channels[0]``; may include multiple view inputs.

        Returns:
            Tensor of shape ``[B, C_out, N]`` (prediction).
        """
        y = feature
        tmpy = feature
        last = len(self.filters) - 1
        # Only these modes populate ``self.norms`` in __init__.
        use_norm = self.norm in ('batch', 'group', 'instance')
        for i, f in enumerate(self.filters):
            if i in self.res_layers:
                y = torch.cat([y, tmpy], 1)
            y = f(y)
            if i != last:
                if use_norm:
                    y = self.norms[i](y)
                y = self.activate(y)
        if self.last_op is not None:
            y = self.last_op(y)
        return y