import torch.nn as nn | |
class Sequential2(nn.Sequential): | |
"""An alternative sequential container to nn.Sequential, | |
which accepts an arbitrary number of input arguments. | |
""" | |
def forward(self, *inputs): | |
for module in self._modules.values(): | |
if isinstance(inputs, tuple): | |
inputs = module(*inputs) | |
else: | |
inputs = module(inputs) | |
return inputs | |