# kirch's picture
# Duplicate from PAIR/Text2Video-Zero
# 508927a
# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from ..utils import ext_loader
# Handle to the '_ext' extension obtained through the project's ext_loader
# (presumably the compiled op library -- confirm against ext_loader).
# The two names requested here are the kernels PSAMaskFunction calls in its
# forward and backward passes.
ext_module = ext_loader.load_ext('_ext',
['psamask_forward', 'psamask_backward'])
class PSAMaskFunction(Function):
    """Autograd function wrapping the ``psamask_forward``/``psamask_backward``
    extension kernels.

    Call it through ``PSAMaskFunction.apply(input, psa_type, mask_size)``.
    """

    @staticmethod
    def symbolic(g, input, psa_type, mask_size):
        """Graph-export hook: emit a single ``mmcv::MMCVPSAMask`` node.

        ``psa_type`` and ``mask_size`` are attached as integer attributes
        (the ``_i`` suffix), ``input`` as the node's only tensor input.
        """
        return g.op(
            'mmcv::MMCVPSAMask',
            input,
            psa_type_i=psa_type,
            mask_size_i=mask_size)

    @staticmethod
    def forward(ctx, input, psa_type, mask_size):
        """Run the forward kernel.

        Args:
            ctx: autograd context; receives ``psa_type``, the paired
                ``mask_size`` and the saved ``input`` for ``backward``.
            input (Tensor): 4-D tensor unpacked as
                ``(batch_size, channels, h_feature, w_feature)``;
                ``channels`` must equal ``h_mask * w_mask``.
            psa_type (int): mode code forwarded unchanged to the kernel.
            mask_size (int | tuple[int, int]): mask height/width; a single
                value is expanded to a pair with ``_pair``.

        Returns:
            Tensor: shape
            ``(batch_size, h_feature * w_feature, h_feature, w_feature)``.

        Raises:
            AssertionError: if ``channels != h_mask * w_mask``.
        """
        ctx.psa_type = psa_type
        ctx.mask_size = _pair(mask_size)
        ctx.save_for_backward(input)

        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()
        if channels != h_mask * w_mask:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; a mismatched shape must never
            # reach the compiled kernel. The exception type is kept as
            # AssertionError for backward compatibility.
            raise AssertionError(
                f'input has {channels} channels but mask_size '
                f'{ctx.mask_size} requires {h_mask * w_mask}')

        # Zero-initialised result buffer handed to the kernel, which is
        # expected to fill it in place.
        output = input.new_zeros(
            (batch_size, h_feature * w_feature, h_feature, w_feature))
        ext_module.psamask_forward(
            input,
            output,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output (Tensor): gradient with respect to the forward
                output.

        Returns:
            tuple: ``(grad_input, None, None)`` -- one entry per ``forward``
            input; ``psa_type`` and ``mask_size`` receive no gradient.
        """
        # The saved input is only used for its shape.
        input = ctx.saved_tensors[0]
        psa_type = ctx.psa_type
        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()

        # Zero-initialised gradient buffer handed to the kernel, which is
        # expected to fill it in place.
        grad_input = grad_output.new_zeros(
            (batch_size, channels, h_feature, w_feature))
        ext_module.psamask_backward(
            grad_output,
            grad_input,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        # Exactly one gradient per forward input (input, psa_type, mask_size).
        return grad_input, None, None
# Functional entry point: call as psa_mask(input, psa_type, mask_size).
psa_mask = PSAMaskFunction.apply
class PSAMask(nn.Module):
    """Module wrapper around ``psa_mask``.

    Args:
        psa_type (str): either ``'collect'`` or ``'distribute'``; stored as
            ``self.psa_type`` and translated to the integer code
            ``self.psa_type_enum`` (0 for 'collect', 1 for 'distribute').
        mask_size: mask size passed to ``psa_mask`` on every forward call.
            NOTE: the default ``None`` is accepted here, but ``psa_mask``
            multiplies the two mask dimensions in its channel check, so a
            real size has to be set before calling ``forward``.

    Raises:
        AssertionError: if ``psa_type`` is not one of the two accepted
            values.
    """

    def __init__(self, psa_type, mask_size=None):
        super().__init__()
        if psa_type not in ('collect', 'distribute'):
            # Raised explicitly instead of via ``assert``: under
            # ``python -O`` a stripped assert would let any invalid value
            # silently fall through to the 'distribute' code. The exception
            # type is kept as AssertionError for backward compatibility.
            raise AssertionError(
                f"psa_type must be 'collect' or 'distribute', "
                f'got {psa_type!r}')
        self.psa_type = psa_type
        self.psa_type_enum = 0 if psa_type == 'collect' else 1
        self.mask_size = mask_size

    def forward(self, input):
        """Apply ``psa_mask`` to ``input`` with the configured type and size."""
        return psa_mask(input, self.psa_type_enum, self.mask_size)

    def __repr__(self):
        """Return ``ClassName(psa_type=..., mask_size=...)``."""
        return (f'{self.__class__.__name__}'
                f'(psa_type={self.psa_type}, '
                f'mask_size={self.mask_size})')