import torch

def distance_to_similarity(distances, temperature=1.0):
    """
    Turns a distance matrix into a similarity matrix so it works with distribution-based metrics.
    """
    similarities = torch.exp(-distances / temperature)
    similarities = torch.clamp(similarities, min=1e-8)
    return similarities
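
# Quick sanity check of the mapping above (illustrative values only):
# a distance of 0 maps to similarity 1, and a larger temperature flattens
# the differences between distances.
#
#   >>> d = torch.tensor([0.0, 1.0, 2.0])
#   >>> distance_to_similarity(d, temperature=1.0)
#   tensor([1.0000, 0.3679, 0.1353])
#   >>> distance_to_similarity(d, temperature=4.0)
#   tensor([1.0000, 0.7788, 0.6065])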

#################################
##   "New Object" Detection    ##
#################################

def detect_newness_two_sided(distances, k=3, quantile=0.97):
    """
    Two-sided top-k matching combined with a distance threshold.

    A target patch is flagged as new if (a) it has no mutual top-k match
    with any source patch and (b) its minimum distance to the source set
    exceeds the chosen quantile of all per-target minimum distances.

    distances: [N_src, N_tgt]. Returns a float mask [N_tgt] (1.0 = new).
    """
    distances = distances.clone()  # avoid mutating the caller's tensor
    device = distances.device
    N_src, N_tgt = distances.shape

    # Top-k source matches per target column and top-k target matches per
    # source row (smallest distances = best matches)
    topk_src_idx_t = torch.topk(distances, k, dim=0, largest=False).indices  # [k, N_tgt]
    topk_tgt_idx_s = torch.topk(distances, k, dim=1, largest=False).indices  # [N_src, k]

    src_to_tgt_mask = torch.zeros((N_src, N_tgt), device=device)
    tgt_to_src_mask = torch.zeros((N_src, N_tgt), device=device)

    # Mark the top-k source matches of every target
    row_indices = topk_src_idx_t  # [k, N_tgt]
    col_indices = torch.arange(N_tgt, device=device).unsqueeze(0).repeat(k, 1)  # [k, N_tgt]
    src_to_tgt_mask[row_indices, col_indices] = 1.0

    # Mark the top-k target matches of every source
    row_indices = torch.arange(N_src, device=device).unsqueeze(1).repeat(1, k)  # [N_src, k]
    col_indices = topk_tgt_idx_s  # [N_src, k]
    tgt_to_src_mask[row_indices, col_indices] = 1.0

    # A target column with at least one mutual top-k match is considered matched
    overlap_mask = (src_to_tgt_mask * tgt_to_src_mask).sum(dim=0) > 0  # [N_tgt]

    # Zero out matched columns so they cannot exceed the distance threshold
    distances[:, overlap_mask] = 0.0

    two_sided_mask = (~overlap_mask).float()

    min_distances, _ = distances.min(dim=0)
    threshold = torch.quantile(min_distances, quantile)
    threshold_mask = (min_distances > threshold).float()

    combined_mask = two_sided_mask * threshold_mask
    return combined_mask
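
# Minimal usage sketch (the feature tensors are hypothetical, not part of
# this module): build a pairwise distance matrix from per-patch features
# of a source and a target frame, then run the detector.
#
#   dists = torch.cdist(src_feats, tgt_feats)   # [N_src, N_tgt]
#   new_mask = detect_newness_two_sided(dists)  # [N_tgt], 1.0 = likely new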

def detect_newness_distance(min_distances, quantile=0.97):
    """
    Old approach: threshold on min distance at a chosen percentile.
    """
    threshold = torch.quantile(min_distances, quantile)
    newness_mask = (min_distances > threshold).float()
    return newness_mask
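
# Caveat: a quantile threshold flags roughly a fixed fraction (1 - quantile)
# of target patches by construction, whether or not anything new is present.
# The same caveat applies to the quantile-thresholded detectors below.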

def detect_newness_topk_margin(distances, top_k=2, quantile=0.03):
    """
    Top-k margin approach in distance space.
    distances: [N_src, N_tgt]
    Sort each column ascending => best match is index 0, second best is index 1, etc.
    A smaller margin => ambiguous => likely new.
    We threshold the margin at some percentile.
    """
    sorted_dists, _ = torch.sort(distances, dim=0)  
    best = sorted_dists[0]                        # [N_tgt]
    second_best = sorted_dists[1] if top_k >= 2 else sorted_dists[0]  # [N_tgt]
    margin = second_best - best  # [N_tgt]

    # If margin < threshold => ambiguous => "new"
    # We'll pick threshold as a quantile of margin
    threshold = torch.quantile(margin, quantile)
    newness_mask = (margin < threshold).float()
    return newness_mask

def detect_newness_entropy(distances, temperature=1.0, quantile=0.97):
    """
    Entropy-based approach. First convert distance->similarity with an exponential.
    Then normalize to get a distribution for each target patch, compute Shannon entropy.
    High entropy => new object (no strong match).
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)  # [N_src, N_tgt]
    # Shannon Entropy: -sum(p log p)
    entropy = -torch.sum(probs * torch.log(probs), dim=0)  # [N_tgt]

    # Flag the highest-entropy target patches as new
    threshold = torch.quantile(entropy, quantile)
    newness_mask = (entropy > threshold).float()
    return newness_mask
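
# The entropy above is bounded by log(N_src), attained by a uniform
# distribution; since the threshold is a quantile of the observed entropies,
# no explicit normalization by log(N_src) is needed for the mask.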

def detect_newness_gini(distances, temperature=1.0, quantile=0.97):
    """
    Gini impurity-based approach. Convert distances to similarities,
    get a distribution, compute Gini.
    High Gini => wide distribution => new object.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)
    # Gini: sum(p_i*(1-p_i)) => high if spread out
    gini = torch.sum(probs * (1.0 - probs), dim=0)  # [N_tgt]

    threshold = torch.quantile(gini, quantile)
    newness_mask = (gini > threshold).float()
    return newness_mask
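
# Since the probabilities sum to 1, sum(p * (1 - p)) = 1 - sum(p^2): the Gini
# score is one minus the distribution's collision probability, and is maximal
# (1 - 1/N_src) for a uniform distribution.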

def detect_newness_kl(distances, temperature=1.0, quantile=0.97):
    """
    KL-based approach. Compare distribution to uniform => if close to uniform => new object.
    1) Convert distances -> similarities
    2) p(x) = similarities / sum(similarities)
    3) KL(p || uniform) => sum p(x) log (p(x)/(1/N_src))
    4) If p is near uniform => KL small => new object.
       We'll invert it => newness ~ 1/KL.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    N_src = distances.shape[0]
    probs = similarities / similarities.sum(dim=0, keepdim=True)

    uniform_val = 1.0 / float(N_src)
    kl_vals = torch.sum(probs * torch.log(probs / uniform_val), dim=0)  # [N_tgt]
    inv_kl = 1.0 / (kl_vals + 1e-8)  # big => distribution is near uniform => new

    threshold = torch.quantile(inv_kl, quantile)
    newness_mask = (inv_kl > threshold).float()
    return newness_mask
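
# Algebraic link to the entropy detector: with a uniform reference,
# KL(p || U) = log(N_src) - H(p), so ranking targets by 1/KL is (up to the
# 1e-8 stabilizer) the same as ranking them by entropy, and both detectors
# should flag nearly the same patches at the same quantile.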

def detect_newness_variation_ratio(distances, temperature=1.0, quantile=0.97):
    """
    Variation Ratio: 1 - max(prob).
    1) Convert distance->similarity
    2) p(x) = sim(x) / sum_x'(sim(x'))
    3) var_ratio = 1 - max(p)
    High var_ratio => new object.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)
    max_prob, _ = torch.max(probs, dim=0)  # [N_tgt]
    var_ratio = 1.0 - max_prob

    threshold = torch.quantile(var_ratio, quantile)
    newness_mask = (var_ratio > threshold).float()
    return newness_mask
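
# Up to the 1e-8 clamp in distance_to_similarity, `probs` here equals
# torch.softmax(-distances / temperature, dim=0), so the variation ratio is
# simply one minus the largest softmax weight in each target column.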


def detect_newness_two_sided_ratio(
    distances,
    top_k_ratio_quantile=0.03,
    two_sided=True
):
    """
    Two-sided matching + ratio test in distance space.

    Ratio test: For each t, let d0 = best distance, d1 = second best.
        ratio = d0 / (d1 + 1e-8).
        If ratio < ratio_threshold => ambiguous => new.
        (Typically a smaller ratio means a better match, but we invert logic:
        a patch can be "new" if the ratio is extremely small or ambiguous.)  
    """

    N_src, N_tgt = distances.shape

    # Target → Source: best match
    min_vals_t, best_s_for_t = torch.min(distances, dim=0)

    # Source → Target: best match
    min_vals_s, best_t_for_s = torch.min(distances, dim=1)

    # Two-sided consistency check
    twosided_mask = torch.zeros(N_tgt, device=distances.device)
    if two_sided:
        for t in range(N_tgt):
            s = best_s_for_t[t]
            if best_t_for_s[s] != t:
                twosided_mask[t] = 1.0

    # Ratio test: ambiguous if best match is not clearly better than second best
    sorted_dists, _ = torch.sort(distances, dim=0)
    d0 = sorted_dists[0]
    d1 = sorted_dists[1]
    ratio = d0 / (d1 + 1e-8)
    ratio_threshold = torch.quantile(ratio, top_k_ratio_quantile)
    ratio_mask = (ratio < ratio_threshold).float()

    # Combine checks (currently using only two-sided result)
    newness_mask = twosided_mask

    return newness_mask
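

if __name__ == "__main__":
    # Smoke test on synthetic data. This is only a sketch: the features and
    # distances below are random, so the patches flagged as "new" are
    # arbitrary; in a real pipeline `dists` would come from actual source
    # and target patch embeddings.
    torch.manual_seed(0)
    src_feats = torch.randn(64, 16)   # hypothetical source patch features
    tgt_feats = torch.randn(80, 16)   # hypothetical target patch features
    dists = torch.cdist(src_feats, tgt_feats)  # [N_src=64, N_tgt=80]

    detectors = {
        "two_sided": detect_newness_two_sided,
        "distance": lambda d: detect_newness_distance(d.min(dim=0).values),
        "topk_margin": detect_newness_topk_margin,
        "entropy": detect_newness_entropy,
        "gini": detect_newness_gini,
        "kl": detect_newness_kl,
        "variation_ratio": detect_newness_variation_ratio,
        "two_sided_ratio": detect_newness_two_sided_ratio,
    }
    for name, fn in detectors.items():
        mask = fn(dists)
        print(f"{name:>16}: {int(mask.sum())}/{mask.numel()} flagged as new")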