Deploy_Restoration / net /IntmdSequential.py
AlexZou's picture
Upload 17 files
7eb6194
raw
history blame
No virus
557 Bytes
import torch.nn as nn
class IntermediateSequential(nn.Sequential):
    """``nn.Sequential`` that can optionally also return every sub-module's output.

    With ``return_intermediate=False`` (the default) it behaves exactly like
    ``nn.Sequential``. With ``return_intermediate=True``, ``forward`` returns a
    tuple ``(output, intermediate_outputs)``. ``intermediate_outputs`` maps each
    registered child name (``"0"``, ``"1"``, ... for positional construction)
    to the output of that child.
    """

    def __init__(self, *args, return_intermediate=False):
        """Build the container.

        Args:
            *args: Modules (or an ``OrderedDict``), forwarded to ``nn.Sequential``.
            return_intermediate: If True, ``forward`` also returns a dict of
                per-module outputs.
        """
        super().__init__(*args)
        self.return_intermediate = return_intermediate

    def forward(self, input):
        """Run the modules in order.

        Args:
            input: Input passed to the first module.

        Returns:
            The final output when ``return_intermediate`` is False. Otherwise a
            tuple ``(output, intermediate_outputs)`` where
            ``intermediate_outputs`` is a dict from child name to that child's
            output, in execution order.
        """
        if not self.return_intermediate:
            return super().forward(input)

        intermediate_outputs = {}
        output = input
        # Iterate the raw module registry, exactly as nn.Sequential.forward does.
        # named_children() de-duplicates repeated module instances, which would
        # silently skip a weight-shared module that appears more than once and
        # make this path disagree with the plain Sequential path.
        for name, module in self._modules.items():
            output = module(output)
            intermediate_outputs[name] = output
        return output, intermediate_outputs