TheComputerMan commited on
Commit
5e25351
1 Parent(s): 520a0ed

Upload MultiSequential.py

Browse files
Files changed (1) hide show
  1. MultiSequential.py +33 -0
MultiSequential.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+ import torch
6
+
7
+
8
+ class MultiSequential(torch.nn.Sequential):
9
+ """
10
+ Multi-input multi-output torch.nn.Sequential.
11
+ """
12
+
13
+ def forward(self, *args):
14
+ """
15
+ Repeat.
16
+ """
17
+ for m in self:
18
+ args = m(*args)
19
+ return args
20
+
21
+
22
+ def repeat(N, fn):
23
+ """
24
+ Repeat module N times.
25
+
26
+ Args:
27
+ N (int): Number of repeat time.
28
+ fn (Callable): Function to generate module.
29
+
30
+ Returns:
31
+ MultiSequential: Repeated model instance.
32
+ """
33
+ return MultiSequential(*[fn(n) for n in range(N)])