# Copyright 2022 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F


def compute_img_bkg_seg(
    attentions,
    feats,
    featmap_dims,
    th_bkg,
    dim=64,
    epsilon: float = 1e-10,
    apply_weights: bool = True,
) -> torch.Tensor:
    """
    Compute a coarse background mask from ViT self-attention maps.

    Args:
        attentions: [B, nh, 1 + N, 1 + N] self-attention maps
            (CLS token first, N = w_featmap * h_featmap patch tokens).
        feats: [B, 1 + N, nh * dim] corresponding token features.
        featmap_dims: (w_featmap, h_featmap) spatial size of the patch grid.
        th_bkg: cosine-similarity threshold; patches more similar than
            th_bkg to the least-attended patch are marked as background.
        dim: feature dimension per attention head.
        epsilon: small constant for numerical stability.
        apply_weights: if True, weight heads by their attention sparsity.

    Returns:
        bkg_mask: [B, w_featmap, h_featmap] float mask, 1 on the background.
    """
    w_featmap, h_featmap = featmap_dims

    nb, nh, _ = attentions.shape[:3]
    # Keep only the attention of the CLS token over the patch tokens
    att = attentions[:, :, 0, 1:].reshape(nb, nh, -1)
    att = att.reshape(nb, nh, w_featmap, h_featmap)

    # -----------------------------------------------
    # Per-head sparsity weighting, inspired by the CroW channel weighting
    # of Kalantidis et al.: heads whose attention concentrates on few
    # patches receive a larger weight beta.
    threshold = torch.mean(att.reshape(nb, -1), dim=1)  # one threshold per image
    Q = torch.sum(
        att.reshape(nb, nh, w_featmap * h_featmap) > threshold[:, None, None], dim=2
    ) / (w_featmap * h_featmap)
    beta = torch.log(torch.sum(Q + epsilon, dim=1)[:, None] / (Q + epsilon))

    # Weight the patch features (CLS token dropped) by the head sparsity
    descs = feats[:, 1:, :]
    if apply_weights:
        descs = (descs.reshape(nb, -1, nh, dim) * beta[:, None, :, None]).reshape(
            nb, -1, nh * dim
        )

    # -----------------------------------------------
    # Compute cosine-similarities
    descs = F.normalize(descs, dim=-1, p=2)
    cos_sim = torch.bmm(descs, descs.permute(0, 2, 1))
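    # cos_sim[b, i, j]: cosine similarity between patches i and j of image b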

    # -----------------------------------------------
    # Find the pixel with the least amount of attention, summed over heads
    if apply_weights:
        att = att * beta[:, :, None, None]
    id_pixel_ref = torch.argmin(torch.sum(att, dim=1).reshape(nb, -1), dim=-1)

    # -----------------------------------------------
    # Mask of definitely-background pixels: 1 on the background. Patches
    # similar to the least-attended reference patch are marked background.
    bkg_mask = (
        cos_sim[torch.arange(cos_sim.size(0)), id_pixel_ref, :].reshape(
            nb, w_featmap, h_featmap
        )
        > th_bkg
    )

    return bkg_mask.float()
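

# ---------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original
# code): the shapes below assume a ViT with 6 heads of 64 dims each on
# a 28x28 patch grid; random tensors stand in for real attention maps
# and features, and th_bkg=0.3 is an arbitrary example value.
if __name__ == "__main__":
    B, nh, dim = 2, 6, 64
    w_featmap, h_featmap = 28, 28
    n_tokens = 1 + w_featmap * h_featmap  # CLS token + patch tokens

    attentions = torch.rand(B, nh, n_tokens, n_tokens)
    feats = torch.rand(B, n_tokens, nh * dim)

    bkg_mask = compute_img_bkg_seg(
        attentions, feats, (w_featmap, h_featmap), th_bkg=0.3, dim=dim
    )
    print(bkg_mask.shape)  # -> torch.Size([2, 28, 28])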