# GPT-K/knowledge/utils.py
# cwkuo
# code clean up
# 9051af7
import numpy as np
import hashlib
import torch
def file_hash(file, chunk_size=8192):
    """Return the hexadecimal BLAKE2b digest of a file's contents.

    The file is opened in binary mode and consumed in fixed-size chunks, so
    the whole file never has to be held in memory at once.

    Args:
        file: Path (or anything accepted by ``open``) of the file to hash.
        chunk_size: Number of bytes to read per iteration. Must be positive.

    Returns:
        The digest as a hexadecimal string.

    Raises:
        ValueError: If ``chunk_size`` is not a positive number.
    """
    # Ref: https://stackoverflow.com/a/59056837
    if chunk_size <= 0:
        # read(0) would return b"" immediately and silently hash nothing.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    hash_fn = hashlib.blake2b()
    with open(file, "rb") as f:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hash_fn.update(chunk)
    return hash_fn.hexdigest()
@torch.no_grad()
def refine_cosine(Xa, Xq, I, device, k=None):
    """Re-rank retrieved candidates by their exact inner product with each query.

    For every query row in ``Xq`` the candidate vectors ``Xa[I[row]]`` are
    gathered, their inner product with the query is computed on ``device``,
    and the candidates are reordered from highest to lowest score.

    NOTE(review): the score is a plain dot product; it equals the cosine
    similarity only if the rows of ``Xa`` and ``Xq`` are L2-normalized --
    confirm against the callers.

    Args:
        Xa: Array of all candidate vectors, indexable by an integer array
            (n x d).
        Xq: Array of query vectors (bs x d).
        I: Integer array of candidate indices into ``Xa`` (bs x K).
        device: Torch device on which the scores are computed.
        k: Number of top candidates to keep per query. Defaults to ``K``,
            i.e. all candidates are kept and only reordered.

    Returns:
        A tuple ``(S_refined, I_refined)`` of NumPy arrays, both bs x k:
        the sorted scores and the matching candidate indices.

    Raises:
        ValueError: If ``k`` is larger than the number of candidates ``K``.
    """
    I = np.asarray(I)
    num_candidates = I.shape[1]
    if k is None:
        k = num_candidates
    elif k > num_candidates:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError(
            f"k={k} exceeds the number of candidates per query ({num_candidates})"
        )

    Xi = torch.tensor(Xq, device=device).unsqueeze(1)  # bs x 1 x d
    Xj = torch.tensor(Xa[I.flatten()], device=device)  # bs * K x d
    Xj = Xj.reshape(*I.shape, Xq.shape[-1])  # bs x K x d

    sim = torch.sum(Xi * Xj, dim=-1)  # bs x K
    # One sort yields both the ordered scores and the permutation.
    sim_sorted, order = torch.sort(sim, dim=1, descending=True)

    top_order = order[:, :k].cpu().numpy()  # bs x k
    S_refined = sim_sorted[:, :k].contiguous().cpu().numpy()
    # Vectorized gather; unlike stacking per-row results, this also works
    # when the batch is empty.
    I_refined = np.take_along_axis(I, top_order, axis=1)
    return S_refined, I_refined