Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
# Copyright 2019 Shigeki Karita | |
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
"""Repeat the same layer definition.""" | |
import torch | |
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential.

    Each contained module is called with the previous module's output
    unpacked as positional arguments. This lets modules that return a tuple
    (for example ``(x, mask)``) be chained. A module that returns a single
    ``torch.Tensor`` has that tensor forwarded to the next module as one
    argument instead of being unpacked.
    """

    def forward(self, *args):
        """Run the contained modules in order.

        :param args: positional inputs for the first module
        :return: output of the last module, returned as-is; when the
            container holds no modules, the ``args`` tuple itself
        """
        outputs = args
        for module in self:
            # ``*`` on a bare tensor would iterate over its first dimension.
            # Each slice would then reach the next module as a separate
            # positional argument. Wrap it so it is passed as one input.
            if isinstance(outputs, torch.Tensor):
                outputs = (outputs,)
            outputs = module(*outputs)
        return outputs
def repeat(N, fn):
    """Build a MultiSequential from ``N`` freshly created modules.

    :param int N: number of modules to create
    :param function fn: factory invoked once per index as ``fn(n)`` for
        ``n = 0, 1, ..., N - 1``; each call should return a module
    :return: the created modules, chained in the order they were made
    :rtype: MultiSequential
    """
    # ``map`` calls the factory with each index in ascending order.
    return MultiSequential(*map(fn, range(N)))