"""Demo: average-pool a 5x5 feature map down to one vector per sample."""
import torch
from torch import nn


def pool_to_vector(features: torch.Tensor, kernel_size: int = 5) -> torch.Tensor:
    """Average-pool a (N, C, H, W) tensor with a square window and drop the spatial dims.

    Args:
        features: 4-D tensor laid out as (batch, channels, height, width).
        kernel_size: side length of the square pooling window. The stride
            defaults to the same value (``nn.AvgPool2d`` behaviour).

    Returns:
        Tensor of shape (N, C) holding the per-channel window averages.

    Raises:
        ValueError: if ``features`` is not 4-D, or if pooling does not reduce
            the spatial dimensions to exactly 1x1.
    """
    if features.dim() != 4:
        raise ValueError(
            f"expected a 4-D (N, C, H, W) tensor, got {features.dim()}-D"
        )
    # Square window; a non-square one would be e.g. nn.AvgPool2d((3, 2)).
    pooled = nn.AvgPool2d(kernel_size)(features)
    # squeeze() would silently leave a >1 spatial map in place, so fail loudly.
    if pooled.shape[-2:] != (1, 1):
        raise ValueError(
            f"pooling with kernel_size={kernel_size} left spatial size "
            f"{tuple(pooled.shape[-2:])}, expected (1, 1)"
        )
    return pooled.flatten(start_dim=1)


if __name__ == "__main__":
    batch = torch.randn(32, 256, 5, 5)
    print(pool_to_vector(batch).shape)  # torch.Size([32, 256])