import torch.nn as nn class IntermediateSequential(nn.Sequential): def __init__(self, *args, return_intermediate=False): super().__init__(*args) self.return_intermediate = return_intermediate def forward(self, input): if not self.return_intermediate: return super().forward(input) intermediate_outputs = {} output = input for name, module in self.named_children(): output = intermediate_outputs[name] = module(output) return output, intermediate_outputs