File size: 4,141 Bytes
007d3b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

class RetMetric(object):
    """Recall@K evaluated on a precomputed similarity matrix.

    Row ``i`` of ``sim_mat`` holds the similarities of query ``i`` against
    every gallery entry.
    """

    def __init__(self, sim_mat, labels):
        """Keep the similarity matrix and split ``labels`` into its two parts.

        Args:
            sim_mat: similarity scores indexed as ``sim_mat[query][gallery]``.
            labels: two-item sequence ``(gallery_labels, query_labels)``.
        """
        gallery, query = labels
        self.sim_mat = sim_mat
        self.gallery_labels = gallery
        self.query_labels = query
        # Off by default; recall_k only takes the second-best-positive path
        # when this flag is switched on after construction.
        self.is_equal_query = False

    def recall_k(self, k=1):
        """Return the fraction of queries that count as a hit at rank ``k``.

        A query is a hit when fewer than ``k`` negatives (gallery entries with
        a different label) score strictly above the positive threshold.

        Args:
            k: number of negatives that must be reached before a query is
                counted as a miss.

        Returns:
            Hits divided by the number of rows in ``sim_mat``, as a float.
        """
        num_queries = len(self.sim_mat)
        hits = 0

        for row in range(num_queries):
            query_label = self.query_labels[row]
            scores = self.sim_mat[row]
            positives = scores[self.gallery_labels == query_label]
            negatives = scores[self.gallery_labels != query_label]

            if self.is_equal_query and len(positives) > 1:
                # Second-largest positive similarity.
                threshold = np.sort(positives)[-2]
            else:
                threshold = np.max(positives)

            outranking = np.sum(negatives > threshold)
            if outranking < k:
                hits += 1

        return float(hits) / num_queries

class Prev_RetMetric(object):
    """Recall@K computed from feature matrices.

    The similarity matrix is the matrix product of the query features with
    the transposed gallery features.
    """

    def __init__(self, feats, labels, cl2cl=True):
        """Build the similarity matrix from features.

        Args:
            feats: either a two-element list ``[gallery_feats, query_feats]``
                or a single feature matrix used as both gallery and query.
            labels: ``[gallery_labels, query_labels]`` when ``feats`` is a
                pair, otherwise the single label array shared by both roles.
            cl2cl: when true, the diagonal of the similarity matrix is zeroed.
                The mask is square (one row/column per query), so the number
                of gallery entries has to match the number of queries.
        """
        if isinstance(feats, list) and len(feats) == 2:
            # feats = [gallery_feats, query_feats]
            # labels = [gallery_labels, query_labels]
            self.is_equal_query = False

            self.gallery_feats, self.query_feats = feats
            self.gallery_labels, self.query_labels = labels

        else:
            self.is_equal_query = True
            self.gallery_feats = self.query_feats = feats
            self.gallery_labels = self.query_labels = labels

        self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))
        if cl2cl:
            # Zero out the diagonal (entry i, i) of the similarity matrix.
            self.sim_mat = self.sim_mat * (1 - np.identity(self.sim_mat.shape[0]))

    def recall_k(self, k=1):
        """Return the fraction of queries that count as a hit at rank ``k``.

        A query is a hit when fewer than ``k`` negatives (gallery entries with
        a different label) score strictly above the positive threshold.

        Args:
            k: number of negatives that must be reached before a query is
                counted as a miss.

        Returns:
            Hits divided by the number of queries, as a float.
        """
        m = len(self.sim_mat)

        match_counter = 0

        for i in range(m):
            same_label = self.gallery_labels == self.query_labels[i]
            pos_sim = self.sim_mat[i][same_label]
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]

            # In single-set mode the second-largest positive is used
            # (presumably to discount the query's own entry; confirm how this
            # interacts with the cl2cl diagonal mask). A label with only one
            # sample has no second positive, so fall back to the maximum
            # instead of raising IndexError, matching RetMetric.recall_k.
            if self.is_equal_query and len(pos_sim) > 1:
                thresh = np.sort(pos_sim)[-2]
            else:
                thresh = np.max(pos_sim)

            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
    
def compute_recall_at_k(similarity_matrix, p_labels, g_labels, k):
    """Return the share of probes with a same-label gallery entry in the top ``k``.

    Args:
        similarity_matrix: per-probe similarity scores; row ``i`` is ranked in
            descending order to obtain the gallery ordering for probe ``i``.
        p_labels: probe labels; ``p_labels.size(0)`` is the number of probes.
        g_labels: gallery labels, indexed by the ranked gallery positions.
        k: how many top-ranked gallery entries to inspect per probe.

    Returns:
        Number of probes with at least one label match among their ``k``
        best-scoring gallery entries, divided by the number of probes.
    """
    total = p_labels.size(0)
    hits = 0.0
    for row in range(total):
        ranking = torch.argsort(similarity_matrix[row], descending=True)
        found = False
        for gallery_idx in ranking[:k]:
            if g_labels[gallery_idx] == p_labels[row]:
                found = True
                break
        hits += found
    return hits / total

def count_parameters(model):
    """Return the total element count of the parameters that require gradients.

    Args:
        model: object exposing ``parameters()``; each yielded parameter must
            provide ``requires_grad`` and ``numel()``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def l2_norm(input):
    """Scale ``input`` by the L2 norm taken along dimension 1.

    Each entry is divided by ``sqrt(sum(input ** 2, dim=1) + 1e-12)``; the
    small constant keeps the divisor non-zero for all-zero rows.

    NOTE(review): assumes a 2-D ``(rows, features)`` tensor, since the norms
    are reshaped to a single column before being expanded. TODO confirm.

    Returns:
        A tensor with the same size as ``input``.
    """
    original_shape = input.size()

    squared_sum = torch.sum(torch.pow(input, 2), 1)
    squared_sum.add_(1e-12)
    column_norms = torch.sqrt(squared_sum).view(-1, 1)

    normalized = torch.div(input, column_norms.expand_as(input))
    return normalized.view(original_shape)

def compute_sharded_cosine_similarity(tensor1, tensor2, shard_size):
    """Average pairwise cosine similarity between the items of two batches.

    For every pair ``(i, j)`` the cosine similarity (along the last dimension)
    between each vector of ``tensor1[i]`` and each vector of ``tensor2[j]`` is
    computed and the mean over all those vector pairs is stored at
    ``[i, j]``. The work is done in blocks of at most ``shard_size`` items per
    side to bound the size of the intermediate tensors.

    The two inputs may differ in batch size and in their second dimension;
    when they share the same shape the result equals the square ``(B, B)``
    matrix normalised by ``T * T``.

    Args:
        tensor1: 3-D tensor of shape ``(B1, T1, D)``.
        tensor2: 3-D tensor of shape ``(B2, T2, D)``.
        shard_size: maximum number of batch items taken from each tensor per
            block.

    Returns:
        A ``(B1, B2)`` tensor allocated with ``torch.zeros`` on
        ``tensor1.device``.
    """
    num_rows, len1, _ = tensor1.shape
    num_cols, len2, _ = tensor2.shape
    average_sim_matrix = torch.zeros((num_rows, num_cols), device=tensor1.device)

    for start_idx1 in tqdm(range(0, num_rows, shard_size)):
        end_idx1 = min(start_idx1 + shard_size, num_rows)

        # (s1, T1, D) -> (s1, 1, T1, 1, D); invariant across the inner loop.
        shard_tensor1_expanded = tensor1[start_idx1:end_idx1].unsqueeze(1).unsqueeze(3)

        for start_idx2 in range(0, num_cols, shard_size):
            end_idx2 = min(start_idx2 + shard_size, num_cols)

            # (s2, T2, D) -> (1, s2, 1, T2, D) so the two shards broadcast
            # against each other over items and over vectors.
            shard_tensor2_expanded = tensor2[start_idx2:end_idx2].unsqueeze(0).unsqueeze(2)

            # Cosine similarity over the feature axis: (s1, s2, T1, T2).
            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)

            # Sum over every vector pair of the two items.
            average_sim_matrix[start_idx1:end_idx1, start_idx2:end_idx2] += torch.sum(shard_cos_sim, dim=[2, 3])

    # Turn the sums into means over the T1 * T2 vector pairs.
    average_sim_matrix /= (len1 * len2)

    return average_sim_matrix