Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn.functional as F | |
# https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923 | |
def mask_pool(x, mask): | |
""" | |
Args: | |
x: [B, C, H, W] | |
mask: [B, Q, H, W] | |
""" | |
if not x.shape[-2:] == mask.shape[-2:]: | |
# reshape mask to x | |
mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False) | |
with torch.no_grad(): | |
mask = mask.detach() | |
mask = (mask > 0).to(mask.dtype) | |
denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8 | |
mask_pooled_x = torch.einsum( | |
"bchw,bqhw->bqc", | |
x, | |
mask / denorm, | |
) | |
return mask_pooled_x | |