|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from typing import Tuple |
|
|
|
def compute_img_bkg_seg(
    attentions,
    feats,
    featmap_dims,
    th_bkg,
    dim=64,
    epsilon: float = 1e-10,
    apply_weights: bool = True,
) -> torch.Tensor:
    """Compute a per-image binary mask of patches similar to a reference patch.

    The reference patch of each image is the one receiving the least
    (optionally head-weighted) attention from token 0, summed over heads.
    A patch is set in the mask when the cosine similarity between its
    descriptor and the reference patch descriptor exceeds ``th_bkg``.
    (Going by the function name, the mask presumably marks background --
    confirm against the caller.)

    Args:
        attentions: Attention tensor indexed as ``[nb, nh, 0, 1:]``, i.e. at
            least 4-D with ``w * h + 1`` entries on the last axis.
            NOTE(review): token 0 is presumably a CLS token -- confirm.
        feats: Token features; ``feats[:, 1:]`` must be reshapeable to
            ``[nb, w * h, nh, dim]``.
        featmap_dims: Pair ``(w_featmap, h_featmap)`` giving the patch grid.
        th_bkg: Cosine-similarity threshold (strict ``>`` comparison).
        dim: Per-head feature dimension used to split ``feats`` into heads.
        epsilon: Small constant keeping the head-weight logarithm finite when
            a head has no above-threshold patch.
        apply_weights: When True, scale both descriptors and attention maps
            by the per-head weights ``beta`` before use.

    Returns:
        Float tensor of shape ``[nb, w_featmap, h_featmap]`` holding 1.0
        where the similarity to the reference patch exceeds ``th_bkg`` and
        0.0 elsewhere.
    """
    w_featmap, h_featmap = featmap_dims
    n_patches = w_featmap * h_featmap
    nb, nh = attentions.shape[:2]

    # Attention from token 0 to every other token, one row per head.
    att = attentions[:, :, 0, 1:].reshape(nb, nh, n_patches)

    beta = None
    if apply_weights:
        # Per-image threshold: mean attention over all heads and patches.
        threshold = att.reshape(nb, -1).mean(dim=1)
        # Fraction of patches above the threshold, per head.
        q = (att > threshold[:, None, None]).sum(dim=2) / n_patches
        # Heads with a smaller above-threshold fraction get a larger weight.
        q = q + epsilon
        beta = torch.log(q.sum(dim=1, keepdim=True) / q)

    # Patch descriptors (token 0 dropped), split per head so that each head
    # can be scaled independently, then flattened back.
    descs = feats[:, 1:].reshape(nb, -1, nh, dim)
    if beta is not None:
        descs = descs * beta[:, None, :, None]
    descs = F.normalize(descs.reshape(nb, -1, nh * dim), dim=-1, p=2)

    # Pairwise cosine similarity between all patch descriptors.
    cos_sim = torch.bmm(descs, descs.permute(0, 2, 1))
    cos_sim = cos_sim.reshape(nb, -1, n_patches)

    # Reference patch: lowest total (weighted) attention across heads.
    if beta is not None:
        att = att * beta[:, :, None]
    id_pixel_ref = att.sum(dim=1).argmin(dim=-1)

    # Build the batch index on the same device as the tensor being indexed.
    batch_idx = torch.arange(nb, device=cos_sim.device)
    ref_sim = cos_sim[batch_idx, id_pixel_ref, :]
    bkg_mask = ref_sim.reshape(nb, w_featmap, h_featmap) > th_bkg

    return bkg_mask.float()