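# Spaces: Runtime error
# NOTE(review): the line above appears to be page-status text captured along
# with this file; it is kept as a comment because it is not part of the program.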
import h5py
import numpy as np
import torch
from tqdm import tqdm

from knowledge import TextDB
class ImageCropsIdx:
    """Per-image lookup table of retrieval results read from an HDF5 index file.

    The whole index is read into memory at construction time.  Afterwards
    ``obj[image_id]`` returns, for every enabled crop set, the ``"index"``,
    ``"score"`` and ``"query"`` arrays stored for that image.

    Attributes:
        topk: ``{crop_set_name: k}`` for the crop sets whose ``k`` is positive.
        knowledge_idx: ``{image_id: {crop_set_name: {"index", "score", "query"}}}``.
        fdim: value of the file's ``"fdim"`` attribute.
        file_hash: value of the file's ``"file_hash"`` attribute.
    """

    def __init__(self, knowledge_idx, topk_w, topk_f, topk_n):
        """Load the index file and keep only the requested top-k entries.

        Args:
            knowledge_idx: HDF5 file to open (handed to ``h5py.File``).
            topk_w: entries to keep for the ``"whole"`` crop set.
            topk_f: entries to keep for the ``"five"`` crop set.
            topk_n: entries to keep for the ``"nine"`` crop set.

        A crop set whose top-k is not positive is dropped entirely.
        """
        requested = {"whole": topk_w, "five": topk_f, "nine": topk_n}
        self.topk = {name: k for name, k in requested.items() if k > 0}
        self.knowledge_idx, self.fdim, self.file_hash = self.load(knowledge_idx, self.topk)

    def load(self, knowledge_idx, topk):
        """Read the HDF5 index and regroup it by image id.

        The file is expected to carry ``"fdim"`` and ``"file_hash"`` root
        attributes and root groups named ``"0"`` .. ``str(len(f) - 1)``.  Each
        group holds ``image_ids`` plus, per crop set, ``index``, ``score`` and
        ``query`` datasets whose first axis is aligned with ``image_ids``.

        Args:
            knowledge_idx: HDF5 file to open (handed to ``h5py.File``).
            topk: ``{crop_set_name: k}``; ``index`` and ``score`` are truncated
                to the first ``k`` entries along their third axis.

        Returns:
            ``(per_image, fdim, file_hash)`` where ``per_image`` maps every
            image id to ``{crop_set_name: {"index", "score", "query"}}``.
            If an image id occurs more than once, the last occurrence wins.
        """
        per_image = {}
        with h5py.File(knowledge_idx, "r") as f:
            fdim = f.attrs["fdim"]
            file_hash = f.attrs["file_hash"]

            for i in tqdm(range(len(f)), desc="Load sentence idx", dynamic_ncols=True, mininterval=1.0):
                image_ids = f[f"{i}/image_ids"][:]
                # Slicing an h5py dataset materialises it as a NumPy array, so
                # everything needed is in memory before the file is closed.
                # NOTE(review): index/score look like (image, crop, rank) -- confirm.
                crop_arrays = {
                    name: {
                        "index": f[f"{i}/{name}/index"][:, :, :k],
                        "score": f[f"{i}/{name}/score"][:, :, :k],
                        "query": f[f"{i}/{name}/query"][:],
                    }
                    for name, k in topk.items()
                }

                # Row ``row`` of every array belongs to ``image_ids[row]``.
                for row, image_id in enumerate(image_ids):
                    per_image[image_id] = {
                        name: {field: values[row] for field, values in arrays.items()}
                        for name, arrays in crop_arrays.items()
                    }

        return per_image, fdim, file_hash

    def __getitem__(self, image_id):
        """Return the per-crop-set entry for ``image_id`` (``KeyError`` if unknown)."""
        return self.knowledge_idx[image_id]
class KnowAugImageCrops:
    """Fetch the retrieved knowledge entries of an image as torch tensors.

    Combines a knowledge database (indexable with an array of entry indices,
    returning an ``(embeddings, texts)`` pair) with a per-image retrieval
    index that stores which entries were retrieved for each image crop.
    """

    def __init__(self, knowledge_db: "TextDB", knowledge_idx: "ImageCropsIdx", return_txt=False):
        """Bind a knowledge database to a retrieval index built from it.

        Args:
            knowledge_db: database exposing ``file_hash`` and ``db[indices]``.
            knowledge_idx: retrieval index exposing ``file_hash``, ``topk``,
                ``fdim`` and ``idx[image_id]``.
            return_txt: when true, each result also carries the entry texts.

        Raises:
            AssertionError: if the two objects report different ``file_hash``
                values, i.e. the index was not built for this database.
        """
        self.knowledge_db = knowledge_db
        self.knowledge_idx = knowledge_idx

        # Raised explicitly instead of via ``assert`` so that the check is not
        # stripped under ``python -O``; the exception type is kept unchanged.
        hashes_match = knowledge_db.file_hash == knowledge_idx.file_hash
        if not hashes_match:
            raise AssertionError(
                "knowledge_db and knowledge_idx do not match: "
                f"file_hash {knowledge_db.file_hash!r} != {knowledge_idx.file_hash!r}"
            )

        # Number of crops that make up each crop set.
        self.ncrop = {"whole": 1, "five": 5, "nine": 9}
        self.topk = knowledge_idx.topk
        self.fdim = knowledge_idx.fdim
        self.return_txt = return_txt

    def __call__(self, image_id):
        """Collect the knowledge retrieved for ``image_id``.

        Returns:
            ``{crop_set_name: result}`` for every enabled crop set, where
            ``result`` holds

            * ``"embed"``: float tensor of the database embeddings of all
              retrieved entries (stored indices flattened, then looked up),
            * ``"query"``: float tensor of the stored query array,
            * ``"pos"``: long tensor giving, for each flattened entry, the
              number of the crop it was retrieved for,
            * ``"score"``: float tensor of the flattened stored scores,
            * ``"text"``: the database texts (only if ``return_txt`` is set).
        """
        ret = {}
        for name, k in self.topk.items():
            hits = self.knowledge_idx[image_id][name]

            embeds, texts = self.knowledge_db[hits["index"].flatten()]
            # Crop number of every flattened entry: 0 repeated k times, then 1, ...
            crop_pos = np.repeat(np.arange(self.ncrop[name]), k)

            ret[name] = {
                "embed": torch.FloatTensor(embeds),
                "query": torch.FloatTensor(hits["query"]),
                "pos": torch.LongTensor(crop_pos),
                "score": torch.FloatTensor(hits["score"].flatten()),
            }
            if self.return_txt:
                ret[name]["text"] = texts

        return ret
class KnowAugImageCropsCombined:
    """Bundle three knowledge augmenters (object, attribute, action) into one call."""

    def __init__(
        self,
        knwl_aug_obj: "KnowAugImageCrops",
        knwl_aug_attr: "KnowAugImageCrops",
        knwl_aug_act: "KnowAugImageCrops",
    ):
        """Store the three augmenters.

        Args:
            knwl_aug_obj: augmenter whose results are reported under ``"obj"``.
            knwl_aug_attr: augmenter whose results are reported under ``"attr"``.
            knwl_aug_act: augmenter whose results are reported under ``"act"``.

        ``fdim`` is taken from ``knwl_aug_obj``.
        """
        self.knwl_aug_obj = knwl_aug_obj
        self.knwl_aug_act = knwl_aug_act
        self.knwl_aug_attr = knwl_aug_attr
        self.fdim = knwl_aug_obj.fdim

    def __call__(self, image_id):
        """Run all three augmenters on ``image_id`` and regroup by crop set.

        Returns:
            ``{crop_set_name: {"obj": ..., "attr": ..., "act": ...}}`` with one
            entry per crop set returned by the object augmenter.

        Raises:
            KeyError: if the attribute or action augmenter lacks a crop set
                that the object augmenter returned.
        """
        knwl_obj = self.knwl_aug_obj(image_id)
        knwl_attr = self.knwl_aug_attr(image_id)
        knwl_act = self.knwl_aug_act(image_id)

        # The object augmenter's crop sets drive the layout of the result.
        return {
            name: {
                "obj": obj_entry,
                "attr": knwl_attr[name],
                "act": knwl_act[name],
            }
            for name, obj_entry in knwl_obj.items()
        }