File size: 1,101 Bytes
7962ed0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import hashlib
import torch


def file_hash(file):
    """Return the BLAKE2b hex digest of the file at path ``file``.

    The file is opened in binary mode and fed to the hash in 8192-byte
    chunks, so arbitrarily large files are hashed without being read into
    memory all at once.
    """
    # Ref: https://stackoverflow.com/a/59056837
    digest = hashlib.blake2b()
    with open(file, "rb") as handle:
        # iter() with a b"" sentinel stops as soon as read() hits EOF.
        for block in iter(lambda: handle.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()


@torch.no_grad()
def refine_cosine(Xa, Xq, I, device, k=None):
    """Re-rank candidate neighbours of each query by exact dot-product score.

    For every query row ``Xq[b]`` the candidate rows ``Xa[I[b]]`` are scored
    with an inner product on ``device``. They are then ordered by descending
    score, and the best ``k`` are kept.

    The score is a plain dot product (no normalisation is applied here), so it
    equals cosine similarity only when the rows of ``Xa`` and ``Xq`` are
    already L2-normalised. NOTE(review): confirm callers normalise.

    Args:
        Xa: Database vectors; ``Xa[I.flatten()]`` must yield ``bs * K`` rows of
            dimension ``d``.
        Xq: Query vectors, shape ``(bs, d)``.
        I: Integer array of candidate indices into ``Xa``, shape ``(bs, K)``.
            NOTE(review): negative entries (e.g. -1 "no result" padding) are
            not masked and will wrap around to the end of ``Xa``.
        device: Torch device on which the similarities are computed.
        k: Number of re-ranked candidates to keep per query; defaults to ``K``.

    Returns:
        ``(S_refined, I_refined)``: NumPy arrays of shape ``(bs, k)`` holding,
        per query, the scores in descending order and the matching entries
        of ``I``.

    Raises:
        ValueError: If ``k`` is not in ``[0, K]``.
    """
    K = I.shape[1]
    if k is None:
        k = K
    elif not 0 <= k <= K:
        # Explicit check instead of ``assert`` so it is not stripped by -O.
        raise ValueError(f"k must be in [0, {K}], got {k}")

    # as_tensor skips the redundant copy (and warning) torch.tensor makes
    # when an input is already a tensor.
    Xi = torch.as_tensor(Xq, device=device).unsqueeze(1)  # bs x 1 x d
    Xj = torch.as_tensor(Xa[I.flatten()], device=device)  # bs * K x d
    Xj = Xj.reshape(*I.shape, Xq.shape[-1])  # bs x K x d

    sim = torch.sum(Xi * Xj, dim=-1)  # bs x K
    # Column positions of the k best candidates per row, best first.
    order = torch.argsort(sim, dim=1, descending=True)[:, :k]  # bs x k
    # Vectorised gather instead of a per-row loop; this also works for an
    # empty batch (bs == 0), where stacking an empty list would raise.
    S_refined = torch.gather(sim, 1, order).cpu().numpy()
    I_refined = np.take_along_axis(np.asarray(I), order.cpu().numpy(), axis=1)

    return S_refined, I_refined