Spaces:
Build error
Build error
File size: 557 Bytes
7eb6194 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch.nn as nn
class IntermediateSequential(nn.Sequential):
    """A ``nn.Sequential`` that can optionally expose each child's output.

    When ``return_intermediate`` is false the container behaves exactly like
    its parent class. When it is true, ``forward`` returns the final output
    together with a dict of every child's output.
    """

    def __init__(self, *args, return_intermediate=False):
        """Build the container.

        Args:
            *args: Forwarded unchanged to ``nn.Sequential.__init__``.
            return_intermediate: Keyword-only flag; when true, ``forward``
                also returns the per-child outputs.
        """
        super().__init__(*args)
        self.return_intermediate = return_intermediate

    def forward(self, input):
        """Run the children in order on ``input``.

        Returns:
            The parent class's ``forward`` result when
            ``return_intermediate`` is false. Otherwise a tuple
            ``(final_output, outputs_by_name)`` where ``outputs_by_name``
            maps each name yielded by ``named_children()`` to that child's
            output, in execution order.
        """
        if self.return_intermediate:
            outputs_by_name = {}
            current = input
            for child_name, child in self.named_children():
                # Each child consumes the previous child's output.
                current = child(current)
                outputs_by_name[child_name] = current
            return current, outputs_by_name

        # Plain mode: defer entirely to the stock Sequential behaviour.
        return super().forward(input)
|