# exbert/server/data_processing/create_faiss.py

import argparse
from pathlib import Path

import faiss
import numpy as np

from data_processing.corpus_data_wrapper import CorpusDataWrapper
from data_processing.index_wrapper import LAYER_TEMPLATE

# TODO: Get the model from base_dir, use that information to look up the
# model's configuration, and from it fetch the special tokens associated
# with that model. Add a flag to allow the model's special tokens to be
# ignored. Test which items match 'bert-base-cased'.

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--directory",
        help="Path to the directory that contains the 'embeddings' and 'headContext' folders",
    )
    args = parser.parse_args()
    return args


def train_indexes(ce: CorpusDataWrapper, stepsize=100, drop_null=True):
    """
    Parameters:
    ===========
    - ce: Wrapper around the HDF5 file for easy access to the corpus data
    - stepsize: How many sentences to add to the indexes at once
    - drop_null: Zero out the embeddings of special tokens (e.g., [CLS] and [SEP]) whose spacy POS are null so they are never retrieved
    """
    NUM_LAYERS = ce.n_layers  # Accounts for the input layer, which for attentions + contexts is all zeros

    # IndexFlatIP performs exact inner-product search and needs no training,
    # so building an index is just a matter of adding vectors to it.
    embedding_indexes = [faiss.IndexFlatIP(ce.embedding_dim) for i in range(NUM_LAYERS)]
    context_indexes = [faiss.IndexFlatIP(ce.embedding_dim) for i in range(NUM_LAYERS)]

    for ix in range(0, len(ce), stepsize):
        cdata = ce[ix : ix + stepsize]

        if drop_null:
            embeddings = np.concatenate([c.zero_special_embeddings for c in cdata], axis=1)
            contexts = np.concatenate([c.zero_special_contexts for c in cdata], axis=1)
        else:
            embeddings = np.concatenate([c.embeddings for c in cdata], axis=1)
            contexts = np.concatenate([c.contexts for c in cdata], axis=1)

        # Each batch is shaped (n_layers, n_tokens, embedding_dim); add every
        # layer's token vectors to that layer's index.
        for i in range(NUM_LAYERS):
            embedding_indexes[i].add(embeddings[i])
            context_indexes[i].add(contexts[i])

    return embedding_indexes, context_indexes
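
# A minimal usage sketch (the HDF5 path and query vector below are
# illustrative, not fixed by this repo):
#
#   corpus = CorpusDataWrapper("data/data.hdf5")
#   emb_idxs, ctx_idxs = train_indexes(corpus, stepsize=50)
#   query = np.random.rand(1, corpus.embedding_dim).astype("float32")
#   D, I = emb_idxs[0].search(query, 10)  # top-10 inner-product neighbors in layer 0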


def save_indexes(idxs, outdir, base_name=LAYER_TEMPLATE):
    """Save each faiss index in idxs to its own file"""
    base_dir = Path(outdir)
    base_dir.mkdir(exist_ok=True, parents=True)  # exist_ok=True already tolerates an existing directory

    out_name = str(base_dir / base_name)
    for i, idx in enumerate(idxs):
        name = out_name.format(i)
        print(f"Saving to {name}")
        faiss.write_index(idx, name)
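
# Indexes written with faiss.write_index can be reloaded later with
# faiss.read_index(name), using the same LAYER_TEMPLATE-formatted file names.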


def main(basedir):
    base = Path(basedir)
    h5_fname = base / "data.hdf5"
    corpus = CorpusDataWrapper(h5_fname)
    embedding_faiss, context_faiss = train_indexes(corpus)

    context_faiss_dir = base / "context_faiss"
    embedding_faiss_dir = base / "embedding_faiss"
    save_indexes(embedding_faiss, embedding_faiss_dir)
    save_indexes(context_faiss, context_faiss_dir)


if __name__ == "__main__":
    # Create the indexes for both the contexts and the embeddings
    args = parse_args()
    main(args.directory)
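
# Example invocation (the directory name is illustrative):
#
#   python create_faiss.py -d /path/to/processed_corpus
#
# where /path/to/processed_corpus contains data.hdf5; the context_faiss/ and
# embedding_faiss/ folders are created inside it.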