Spaces:
Runtime error
Runtime error
import numpy as np | |
import hashlib | |
import torch | |
def file_hash(file):
    """Return the hexadecimal BLAKE2b digest of the file at path ``file``.

    The file is opened in binary mode and fed to the hash in 8192-byte
    chunks, so large files are hashed without being read fully into memory.

    Ref: https://stackoverflow.com/a/59056837
    """
    digest = hashlib.blake2b()
    with open(file, "rb") as handle:
        # iter() with a b"" sentinel stops at end of file, like `while chunk:`.
        for block in iter(lambda: handle.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()
def refine_cosine(Xa, Xq, I, device, k=None):
    """Re-rank each query's candidate ids by exact dot-product similarity.

    For every query row ``Xq[b]``, the candidate ids ``I[b]`` are scored
    against the matching rows of ``Xa``. The candidates are then sorted by
    descending score and truncated to the first ``k``.

    NOTE(review): the score is a plain dot product. It equals cosine
    similarity only if the rows of ``Xa`` and ``Xq`` are L2-normalised,
    presumably by the caller — confirm.

    Args:
        Xa: Database vectors indexed by the ids in ``I``; rows have the same
            dimension ``d`` as ``Xq``. Presumably a NumPy array (N, d) —
            TODO confirm.
        Xq: Query vectors, one row per row of ``I``; shape (bs, d).
        I: Integer candidate ids, shape (bs, K); assumed to be a NumPy
            array. NOTE(review): negative ids (e.g. ``-1`` padding) are not
            filtered and would index ``Xa`` from the end.
        device: Torch device on which the similarities are computed.
        k: Number of candidates to keep per query. Defaults to all ``K``.

    Returns:
        ``(S_refined, I_refined)``: NumPy arrays of shape (bs, k) holding
        the sorted similarities and the matching candidate ids, best first.
    """
    if k is not None:
        assert k <= I.shape[1]
    else:
        k = I.shape[1]
    Xi = torch.tensor(Xq, device=device).unsqueeze(1)  # bs x 1 x d
    Xj = torch.tensor(Xa[I.flatten()], device=device)  # K * bs x d
    Xj = Xj.reshape(*I.shape, Xq.shape[-1])  # bs x K x d
    sim = torch.sum(Xi * Xj, dim=-1)  # bs x K
    # Best-first column order per row, truncated once to k. This is
    # equivalent to the former per-row `row[sort][:k]`.
    top = torch.argsort(sim, dim=1, descending=True).cpu().numpy()[:, :k]
    # Vectorised gather replaces a Python loop + np.stack. It also handles an
    # empty batch (bs == 0), where np.stack([]) raised ValueError.
    I_refined = np.take_along_axis(I, top, axis=1)
    S_refined = np.take_along_axis(sim.cpu().numpy(), top, axis=1)
    return S_refined, I_refined