GPT-K / knowledge /retrieve.py
cwkuo
code clean up
9051af7
raw
history blame
3.66 kB
import h5py
import numpy as np
from tqdm import tqdm
import torch
from knowledge import TextDB
class ImageCropsIdx:
    """In-memory, per-image view of a pre-computed retrieval index stored in HDF5.

    The file is expected to carry ``fdim`` and ``file_hash`` attributes and
    numbered top-level groups ``0 .. len(file) - 1``; each group holds an
    ``image_ids`` dataset plus, per crop type, ``index``, ``score`` and
    ``query`` datasets.

    Attributes set by ``__init__``:
        topk: crop-type name -> number of retrieved entries kept per crop;
            only crop types with a positive count are retained.
        knowledge_idx: image id -> {crop type -> {"index", "score", "query"}}.
        fdim: the ``fdim`` attribute read from the file.
        file_hash: the ``file_hash`` attribute read from the file.
    """

    def __init__(self, knowledge_idx, topk_w, topk_f, topk_n):
        """Load the index at path ``knowledge_idx``.

        Args:
            knowledge_idx: path (or file-like) handed to ``h5py.File``.
            topk_w: entries to keep for the "whole" crop type.
            topk_f: entries to keep for the "five" crop type.
            topk_n: entries to keep for the "nine" crop type.

        A crop type whose count is not positive is skipped entirely.
        """
        requested = {"whole": topk_w, "five": topk_f, "nine": topk_n}
        self.topk = {name: count for name, count in requested.items() if count > 0}
        loaded = self.load(knowledge_idx, self.topk)
        self.knowledge_idx, self.fdim, self.file_hash = loaded

    def load(self, knowledge_idx, topk):
        """Read the HDF5 index and regroup it so it is keyed by image id.

        Args:
            knowledge_idx: path (or file-like) handed to ``h5py.File``.
            topk: crop-type name -> number of entries to keep along the last
                axis of the ``index`` and ``score`` datasets.

        Returns:
            A ``(per_image, fdim, file_hash)`` tuple, where ``per_image`` maps
            each value found in a group's ``image_ids`` to its per-crop-type
            ``index`` / ``score`` / ``query`` rows.
        """
        fields = ("index", "score", "query")

        with h5py.File(knowledge_idx, "r") as h5:
            fdim = h5.attrs["fdim"]
            file_hash = h5.attrs["file_hash"]

            # Phase 1: pull every numbered group into memory, truncating the
            # retrieved entries to the requested top-k.
            groups = []
            progress = tqdm(
                range(len(h5)),
                desc="Load sentence idx",
                dynamic_ncols=True,
                mininterval=1.0,
            )
            for group_no in progress:
                entry = {"image_ids": h5[f"{group_no}/image_ids"][:]}
                for crop, keep in topk.items():
                    entry[crop] = {
                        "index": h5[f"{group_no}/{crop}/index"][:, :, :keep],
                        "score": h5[f"{group_no}/{crop}/score"][:, :, :keep],
                        "query": h5[f"{group_no}/{crop}/query"][:],
                    }
                groups.append(entry)

        # Phase 2: split each group's stacked arrays into one record per image.
        # If an image id appears more than once, the last occurrence wins.
        per_image = {}
        for entry in groups:
            for row, image_id in enumerate(entry["image_ids"]):
                per_image[image_id] = {
                    crop: {field: entry[crop][field][row] for field in fields}
                    for crop in topk
                }

        return per_image, fdim, file_hash

    def __getitem__(self, image_id):
        """Return the per-crop-type record stored for ``image_id``."""
        return self.knowledge_idx[image_id]
class KnowAugImageCrops:
    """Callable that gathers the retrieved knowledge entries for one image.

    It pairs a knowledge database (indexable by an array of entry indices,
    returning ``(embeddings, texts)``) with an ``ImageCropsIdx`` that stores,
    per image and crop type, which entries were retrieved.
    """

    def __init__(self, knowledge_db: TextDB, knowledge_idx: ImageCropsIdx, return_txt=False):
        """Bind a database to its index.

        Args:
            knowledge_db: source of embeddings and texts.
            knowledge_idx: per-image retrieval index built for the same database.
            return_txt: when true, ``__call__`` also returns the retrieved texts.

        The two objects must report the same ``file_hash``.
        """
        self.knowledge_db = knowledge_db
        self.knowledge_idx = knowledge_idx
        assert knowledge_db.file_hash == knowledge_idx.file_hash

        self.return_txt = return_txt
        self.fdim = knowledge_idx.fdim
        self.topk = knowledge_idx.topk
        # Crop-type name -> number of crops (presumably per image).
        self.ncrop = {"whole": 1, "five": 5, "nine": 9}

    def __call__(self, image_id):
        """Collect embeddings, queries, crop positions and scores for ``image_id``.

        Returns:
            A dict keyed by crop type. Each value holds the tensors ``embed``,
            ``query``, ``pos`` (crop number of each retrieved entry) and
            ``score``, plus ``text`` when ``return_txt`` is set.
        """
        out = {}
        for crop, per_crop in self.topk.items():
            hits = self.knowledge_idx[image_id][crop]

            flat_index = hits["index"].flatten()
            embeds, texts = self.knowledge_db[flat_index]

            # Crop number repeated once per retrieved entry: 0,0,..,1,1,..
            crop_pos = np.repeat(np.arange(self.ncrop[crop]), per_crop)

            item = {
                "embed": torch.FloatTensor(embeds),
                "query": torch.FloatTensor(hits["query"]),
                "pos": torch.LongTensor(crop_pos),
                "score": torch.FloatTensor(hits["score"].flatten()),
            }
            if self.return_txt:
                item["text"] = texts
            out[crop] = item

        return out
class KnowAugImageCropsCombined:
    """Bundle three ``KnowAugImageCrops`` sources into a single callable.

    The sources are stored under the names "obj", "attr" and "act" in the
    returned structure.
    """

    def __init__(
        self,
        knwl_aug_obj: KnowAugImageCrops,
        knwl_aug_attr: KnowAugImageCrops,
        knwl_aug_act: KnowAugImageCrops
    ):
        """Store the three sources; ``fdim`` is taken from ``knwl_aug_obj``."""
        self.knwl_aug_obj = knwl_aug_obj
        self.knwl_aug_attr = knwl_aug_attr
        self.knwl_aug_act = knwl_aug_act
        self.fdim = knwl_aug_obj.fdim

    def __call__(self, image_id):
        """Query all three sources for ``image_id`` and regroup by crop type.

        Returns:
            A dict keyed by the crop types returned by the "obj" source; each
            value is ``{"obj": ..., "attr": ..., "act": ...}`` holding the
            corresponding per-crop-type result of each source.
        """
        from_obj = self.knwl_aug_obj(image_id)
        from_attr = self.knwl_aug_attr(image_id)
        from_act = self.knwl_aug_act(image_id)

        # The "obj" result decides which crop types appear in the output.
        return {
            crop: {
                "obj": obj_part,
                "attr": from_attr[crop],
                "act": from_act[crop],
            }
            for crop, obj_part in from_obj.items()
        }