File size: 379 Bytes
9b9b1dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
import torch.nn as nn
class ChannelCompression(nn.Module):
    """Compress the channel axis to two maps: channel-wise max and channel-wise mean.

    Both reductions run over dimension 1, which is the channel axis for
    PyTorch's channel-first layout. The two single-channel results are
    concatenated along that same dimension, in this order:

    * index 0 -- maximum over channels
    * index 1 -- mean over channels

    Shapes (assuming channel-first input, e.g. NCHW):
        In:  (N, C, H, W)
        Out: (N, 2, H, W)

    This is a per-position pooling across channels, not a global (spatial)
    pooling.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # keepdim=True keeps the reduced axis as size 1, so the two maps can
        # be concatenated back along dim 1 without an explicit unsqueeze.
        # `.max(dim=...)` (rather than `amax`) keeps the original gradient
        # routing: the gradient flows to a single argmax element.
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        return torch.cat((channel_max, channel_mean), dim=1)
|