# GPT-K / knowledge / text_db.py
# Last commit 40e6e04 by cwkuo: "save memory usage"
import h5py
from tqdm import tqdm
import numpy as np
import codecs
from knowledge.utils import file_hash
class TextDB:
    """Text database loaded fully into memory from an HDF5 file.

    The file is read as top-level groups named ``"0"``, ``"1"``, ...,
    ``str(len(f) - 1)``. Each group holds a ``feature`` dataset (rows along
    the first axis) and a ``text`` dataset. All features are stacked into a
    single float16 array, and all texts are decoded into one list of ``str``,
    both in group order. ``__getitem__`` indexes both with the same ``idx``,
    so texts are presumably one entry per feature row (not checked here).
    """

    def __init__(self, text_db):
        """Load the HDF5 database at ``text_db`` and store ``file_hash(text_db)``."""
        self.feature, self.text = self.load(text_db)
        self.file_hash = file_hash(text_db)

    def load(self, text_db):
        """Read every group of ``text_db`` into memory.

        Returns:
            ``(feature, text)``. ``feature`` is a float16 array of shape
            ``(total_rows, d)`` with all ``<i>/feature`` datasets stacked in
            group order. ``text`` is the list of ``<i>/text`` entries decoded
            with ``codecs.decode`` (UTF-8 by default), in the same order.

        Raises:
            ValueError: if the file contains no groups.
        """
        with h5py.File(text_db, 'r') as f:
            num_groups = len(f)
            if num_groups == 0:
                raise ValueError(f"text DB {text_db!r} contains no groups")

            # Sizing pass: only dataset lengths are read. This lets the output
            # buffer be allocated once instead of concatenating chunks later.
            db_size = sum(len(f[f"{i}/feature"]) for i in range(num_groups))
            _, d = f["0/feature"].shape

            # float16 keeps the in-memory footprint small.
            feature = np.zeros((db_size, d), dtype=np.float16)
            text = []
            offset = 0
            for i in tqdm(range(num_groups), desc="Load text DB", dynamic_ncols=True, mininterval=1.0):
                chunk = f[f"{i}/feature"][:]
                feature[offset:offset + len(chunk)] = chunk
                offset += len(chunk)
                # Decode per group, so only one group's raw bytes are alive at
                # a time (rather than a full bytes list beside the str list).
                text.extend(codecs.decode(t) for t in f[f"{i}/text"][:])
        return feature, text

    def __getitem__(self, idx):
        """Return ``(feature, text)`` selected by ``idx``.

        ``feature`` is ``self.feature[idx]`` (NumPy indexing). For a scalar
        index or a slice, ``text`` is ``self.text[idx]`` (a ``str`` or a
        list). For a 1-D sequence/array of integer positions, or a 1-D
        boolean mask, ``text`` is a list of ``str`` aligned with the selected
        feature rows.
        """
        f = self.feature[idx]
        if isinstance(idx, slice) or np.ndim(idx) == 0:
            t = self.text[idx]
        else:
            positions = np.asarray(idx)
            if positions.dtype == np.bool_ and positions.ndim == 1:
                # NumPy treats bools as a row mask. Translate the mask into the
                # selected row numbers so the texts match the returned features.
                # (Iterating the bools directly would index text[0]/text[1].)
                positions = np.flatnonzero(positions)
            t = [self.text[i] for i in positions]
        return f, t