# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext',
                                 ['psamask_forward', 'psamask_backward'])


class PSAMaskFunction(Function):
    """Autograd function around the ``psamask_forward`` and
    ``psamask_backward`` kernels loaded from the ``_ext`` extension module.
    """

    @staticmethod
    def symbolic(g, input, psa_type, mask_size):
        """Emit the ``mmcv::MMCVPSAMask`` graph node for this op.

        ``psa_type`` and ``mask_size`` are attached to the node as integer
        attributes (``psa_type_i`` / ``mask_size_i``).
        """
        return g.op(
            'mmcv::MMCVPSAMask',
            input,
            psa_type_i=psa_type,
            mask_size_i=mask_size)

    @staticmethod
    def forward(ctx, input, psa_type, mask_size):
        """Run the forward kernel.

        Args:
            input (Tensor): 4-D tensor of size
                ``(batch_size, channels, h_feature, w_feature)`` where
                ``channels == h_mask * w_mask``.
            psa_type (int): Mode flag forwarded to the kernel. ``PSAMask``
                passes 0 for ``'collect'`` and 1 for ``'distribute'``.
            mask_size (int | tuple[int, int]): Mask size; a single value is
                expanded to ``(h_mask, w_mask)`` with ``_pair``.

        Returns:
            Tensor: Tensor of size
            ``(batch_size, h_feature * w_feature, h_feature, w_feature)``.

        Raises:
            TypeError: If ``mask_size`` is (or contains) ``None``.
            AssertionError: If ``channels != h_mask * w_mask``.
        """
        mask_size = _pair(mask_size)
        # ``PSAMask`` defaults ``mask_size`` to None; without this check the
        # multiplication below fails with an opaque NoneType TypeError.
        if any(size is None for size in mask_size):
            raise TypeError(
                'mask_size must be an int or a pair of ints, but got '
                f'{mask_size}')

        ctx.psa_type = psa_type
        ctx.mask_size = mask_size
        ctx.save_for_backward(input)

        h_mask, w_mask = mask_size
        batch_size, channels, h_feature, w_feature = input.size()
        assert channels == h_mask * w_mask, (
            f'input has {channels} channels but mask_size {mask_size} '
            f'requires {h_mask * w_mask}')

        # Allocated zero-filled and handed to the kernel as its output
        # buffer.
        output = input.new_zeros(
            (batch_size, h_feature * w_feature, h_feature, w_feature))
        ext_module.psamask_forward(
            input,
            output,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward kernel.

        Args:
            grad_output (Tensor): Gradient with respect to the output of
                :meth:`forward`.

        Returns:
            tuple: ``(grad_input, None, None)`` - one entry per argument of
            :meth:`forward`. ``grad_input`` has the same size as the saved
            input; ``psa_type`` and ``mask_size`` receive no gradient.
        """
        input = ctx.saved_tensors[0]
        psa_type = ctx.psa_type
        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()

        # Zero-filled buffer handed to the kernel as its output argument.
        grad_input = grad_output.new_zeros(
            (batch_size, channels, h_feature, w_feature))
        ext_module.psamask_backward(
            grad_output,
            grad_input,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        # forward() takes (input, psa_type, mask_size): exactly three slots.
        return grad_input, None, None


psa_mask = PSAMaskFunction.apply


class PSAMask(nn.Module):
    """Module wrapper around :func:`psa_mask`.

    Args:
        psa_type (str): Either ``'collect'`` or ``'distribute'``. Stored as
            ``psa_type`` and, mapped to 0 / 1 respectively, as
            ``psa_type_enum``.
        mask_size (int | tuple[int, int] | None): Mask size forwarded to
            :func:`psa_mask`. It must be set to a non-None value before the
            module is called; ``None`` makes the forward pass raise
            ``TypeError``.
    """

    def __init__(self, psa_type, mask_size=None):
        super().__init__()
        assert psa_type in ['collect', 'distribute'], (
            "psa_type must be 'collect' or 'distribute', but got "
            f'{psa_type!r}')
        # Integer flag handed to the extension kernels.
        self.psa_type_enum = 0 if psa_type == 'collect' else 1
        self.mask_size = mask_size
        self.psa_type = psa_type

    def forward(self, input):
        """Apply :func:`psa_mask` to ``input`` with the stored settings."""
        return psa_mask(input, self.psa_type_enum, self.mask_size)

    def __repr__(self):
        return (f'{self.__class__.__name__}(psa_type={self.psa_type}, '
                f'mask_size={self.mask_size})')