Spaces:
Runtime error
Runtime error
| # Modified from https://github.com/hszhao/semseg/blob/master/lib/psa | |
| from torch import nn | |
| from torch.autograd import Function | |
| from torch.nn.modules.utils import _pair | |
| from ..utils import ext_loader | |
# Handle to the compiled '_ext' extension, loaded through ext_loader; the
# list names the two kernels this module calls (psamask_forward/backward).
ext_module = ext_loader.load_ext(
    '_ext',
    ['psamask_forward', 'psamask_backward'],
)
class PSAMaskFunction(Function):
    """Autograd function wrapping the compiled ``psamask`` kernels.

    ``forward`` takes a 4-D ``(N, h_mask * w_mask, H, W)`` tensor and fills a
    new ``(N, H * W, H, W)`` tensor via ``ext_module.psamask_forward``;
    ``backward`` maps the output gradient back to the input's shape via
    ``ext_module.psamask_backward``.

    Every hook is a ``@staticmethod``: ``Function.apply`` in current PyTorch
    rejects autograd functions whose ``forward``/``backward`` are not static,
    and ``symbolic`` is the static hook used for ONNX export.
    """

    @staticmethod
    def _kernel_args(psa_type, mask_size, batch_size, h_feature, w_feature):
        """Build the keyword arguments shared by the forward/backward kernels.

        Args:
            psa_type (int): Mode enum passed through to the kernel.
            mask_size (tuple[int, int]): ``(h_mask, w_mask)``.
            batch_size (int): ``N`` of the feature map.
            h_feature (int): Feature-map height ``H``.
            w_feature (int): Feature-map width ``W``.

        Returns:
            dict: Keyword arguments for ``psamask_forward`` /
            ``psamask_backward``.
        """
        h_mask, w_mask = mask_size
        return dict(
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)

    @staticmethod
    def symbolic(g, input, psa_type, mask_size):
        """Emit the custom ``mmcv::MMCVPSAMask`` node during ONNX export.

        ``psa_type`` and ``mask_size`` are attached as integer attributes
        (the ``_i`` suffix); ``mask_size`` is forwarded exactly as given.
        """
        return g.op(
            'mmcv::MMCVPSAMask',
            input,
            psa_type_i=psa_type,
            mask_size_i=mask_size)

    @staticmethod
    def forward(ctx, input, psa_type, mask_size):
        """Run the PSA mask forward kernel.

        Args:
            ctx: Autograd context; stores the mode, mask size and input
                shape for ``backward``.
            input (Tensor): 4-D tensor ``(N, C, H, W)`` with
                ``C == h_mask * w_mask``.
            psa_type (int): Mode enum forwarded to the kernel (``PSAMask``
                passes 0 for 'collect' and 1 for 'distribute').
            mask_size (int | tuple[int, int]): Mask size; an int is expanded
                to ``(mask_size, mask_size)`` by ``_pair``.

        Returns:
            Tensor: ``(N, H * W, H, W)`` tensor with the same dtype/device as
            ``input``, written by ``ext_module.psamask_forward``.
        """
        ctx.psa_type = psa_type
        ctx.mask_size = _pair(mask_size)
        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()
        assert channels == h_mask * w_mask, (
            f'expected {h_mask * w_mask} input channels '
            f'(mask {h_mask}x{w_mask}), got {channels}')
        # backward only needs the input's shape, not its values, so keep the
        # shape rather than the tensor to avoid holding the activation alive.
        ctx.input_shape = input.size()
        output = input.new_zeros(
            (batch_size, h_feature * w_feature, h_feature, w_feature))
        ext_module.psamask_forward(
            input,
            output,
            **PSAMaskFunction._kernel_args(
                psa_type, ctx.mask_size, batch_size, h_feature, w_feature))
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back through the PSA mask kernel.

        Returns:
            tuple: ``(grad_input, None, None)``, one entry per ``forward``
            input; ``psa_type`` and ``mask_size`` get no gradient.
        """
        batch_size, _, h_feature, w_feature = ctx.input_shape
        grad_input = grad_output.new_zeros(ctx.input_shape)
        ext_module.psamask_backward(
            grad_output,
            grad_input,
            **PSAMaskFunction._kernel_args(
                ctx.psa_type, ctx.mask_size, batch_size, h_feature,
                w_feature))
        return grad_input, None, None


# Functional entry point: psa_mask(input, psa_type, mask_size).
psa_mask = PSAMaskFunction.apply
class PSAMask(nn.Module):
    """Module that applies ``psa_mask`` with a fixed mode and mask size.

    Args:
        psa_type (str): Either ``'collect'`` or ``'distribute'``; stored as
            given and also mapped to the integer enum (0 or 1) handed to
            ``psa_mask``.
        mask_size (int | tuple[int, int] | None): Forwarded unchanged to
            ``psa_mask``. NOTE(review): the default ``None`` looks unusable
            in ``forward`` (``_pair(None)`` gives ``(None, None)``); confirm
            callers always supply a size.
    """

    def __init__(self, psa_type, mask_size=None):
        super().__init__()
        assert psa_type in ('collect', 'distribute')
        # The kernel receives the mode as an int: 0 = collect, 1 = distribute.
        self.psa_type_enum = 0 if psa_type == 'collect' else 1
        self.mask_size = mask_size
        self.psa_type = psa_type

    def forward(self, input):
        """Apply ``psa_mask`` to ``input`` using this module's settings."""
        return psa_mask(input, self.psa_type_enum, self.mask_size)

    def __repr__(self):
        name = type(self).__name__
        return (f'{name}(psa_type={self.psa_type}, '
                f'mask_size={self.mask_size})')