import torch
import torch.nn as nn


class ChannelCompression(nn.Module):
    """Compress the channel axis (dimension 1) down to exactly 2 channels.

    For every position, the module takes the maximum and the mean across
    dimension 1. It concatenates them along that same dimension, with the max
    in channel 0 and the mean in channel 1. This is a per-position statistic
    over channels, not a global (spatial) pooling.

    Shapes:
        In:  (N, C, *)  -- any tensor of rank >= 2; dimension 1 is reduced.
        Out: (N, 2, *)

    Note: ``x`` must be a floating-point tensor, because ``torch.mean``
    rejects integer dtypes.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Use torch.max (not amax) so gradients flow only to the selected
        # max element, exactly as in the original implementation.
        max_over_channels = torch.max(x, dim=1, keepdim=True).values
        mean_over_channels = torch.mean(x, dim=1, keepdim=True)
        return torch.cat((max_over_channels, mean_over_channels), dim=1)