from torch import nn


def MLP(channels: list) -> nn.Sequential:
    """Build a multi-layer perceptron out of pointwise 1-D convolutions.

    For every consecutive pair ``(channels[i - 1], channels[i])`` a
    ``nn.Conv1d`` with ``kernel_size=1`` and a bias term is appended.
    Every convolution except the final one is followed by
    ``nn.BatchNorm1d`` and ``nn.ReLU``; the final convolution is left
    without normalization or activation.

    Args:
        channels: Channel sizes, input size first. A sequence with fewer
            than two entries produces an empty ``nn.Sequential``.

    Returns:
        An ``nn.Sequential`` holding the layers in construction order.
    """
    layers = []
    # Position of the final convolution; no BatchNorm/ReLU is added after it.
    last = len(channels) - 1
    for idx, (c_in, c_out) in enumerate(zip(channels, channels[1:]), start=1):
        layers.append(nn.Conv1d(c_in, c_out, kernel_size=1, bias=True))
        if idx < last:
            layers.append(nn.BatchNorm1d(c_out))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)