Spaces:
Runtime error
Runtime error
File size: 694 Bytes
2cb106d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
# Written by Shigeki Karita, 2019
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux, 2021
import torch
class MultiSequential(torch.nn.Sequential):
    """
    Multi-input multi-output torch.nn.Sequential.

    Each contained module is called with the output of the previous module
    unpacked as positional arguments, so several values can be handed from
    one module to the next.
    """

    def forward(self, *args):
        """
        Pass the inputs through every contained module in order.

        Args:
            *args: Positional inputs for the first module.

        Returns:
            The output of the last module, exactly as that module returned
            it. If the container holds no modules, the tuple of inputs is
            returned unchanged.
        """
        for m in self:
            # A module that returned a single value (neither tuple nor list)
            # must be re-wrapped before unpacking; otherwise ``*`` would
            # iterate over the value itself, e.g. splitting a tensor along
            # its first dimension into separate arguments.
            if not isinstance(args, (tuple, list)):
                args = (args,)
            args = m(*args)
        return args
def repeat(N, fn):
    """
    Build a MultiSequential out of N generated modules.

    Args:
        N (int): How many modules to generate.
        fn (Callable): Factory that is called once per index ``0 .. N-1``
            with that index as its only argument; each result becomes one
            entry of the container, in index order.

    Returns:
        MultiSequential: Container holding the generated modules.
    """
    layers = (fn(idx) for idx in range(N))
    return MultiSequential(*layers)
|