TTP / mmpretrain /models /utils /channel_shuffle.py
KyanChen's picture
Upload 1861 files
3b96cb1
raw
history blame
889 Bytes
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    This function enables cross-group information flow for multiple groups
    convolution layers. The channel dimension is split into ``groups``
    groups, the group axis and the per-group channel axis are swapped, and
    the result is flattened back, so that the output channels interleave
    the channels of the input groups.

    Args:
        x (Tensor): The input tensor of shape (N, C, *), where ``*`` is any
            number of trailing dimensions, e.g. (N, C, H, W).
        groups (int): The number of groups to divide the input tensor
            in the channel dimension. ``C`` must be divisible by it.

    Returns:
        Tensor: The output tensor after channel shuffle operation. It has
        the same shape as ``x``.

    Raises:
        AssertionError: If the number of channels is not divisible by
            ``groups``.
    """
    # Only the first two axes matter for the shuffle; every trailing axis
    # is carried through untouched.
    batch_size, num_channels, *trailing_dims = x.size()
    assert (num_channels % groups == 0), ('num_channels should be '
                                          'divisible by groups')
    channels_per_group = num_channels // groups

    # (N, C, *) -> (N, groups, C // groups, *)
    x = x.view(batch_size, groups, channels_per_group, *trailing_dims)
    # Swap the group axis with the per-group channel axis; the transposed
    # tensor is non-contiguous, so materialize it before viewing again.
    x = torch.transpose(x, 1, 2).contiguous()
    # Flatten back to (N, C, *). The channel count is passed explicitly
    # rather than as -1, because -1 cannot be inferred when the tensor has
    # zero elements (e.g. an empty batch).
    x = x.view(batch_size, num_channels, *trailing_dims)

    return x