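"""Build per-layer faiss indexes over the token embeddings and head contexts
stored in a corpus's HDF5 file, writing one index file per layer."""
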
import argparse
from pathlib import Path

import faiss
import numpy as np

from data_processing.corpus_data_wrapper import CorpusDataWrapper
from data_processing.index_wrapper import LAYER_TEMPLATE

# TODO:
# - Get the model name from base_dir, use it to load the model's
#   configuration, and derive the model's special tokens from that config.
# - Add a flag to allow the model's special tokens to be ignored.
# - Test what items match 'bert-base-cased'.

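# A sketch of the special-token lookup described in the TODO above, assuming
# the model name can be recovered from base_dir and that huggingface
# transformers is available (both are assumptions, not part of this script):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
#   special_tokens = set(tokenizer.all_special_tokens)  # e.g. '[CLS]', '[SEP]'
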
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--directory",
        required=True,
        help="Path to the directory that contains the 'embeddings' and 'headContext' folders",
    )
    return parser.parse_args()

def train_indexes(ce: CorpusDataWrapper, stepsize=100, drop_null=True):
    """Train one faiss index per layer for both the embeddings and the contexts.

    Parameters:
    ===========
    - ce: Wrapper around the HDF5 file for easy access to the corpus data
    - stepsize: How many sentences to add to the indexes at once
    - drop_null: Don't index the embeddings of special tokens (e.g., [CLS] and
      [SEP]) whose spaCy POS is null
    """
    # Account for the input layer, which for attentions + contexts is all zeros.
    NUM_LAYERS = ce.n_layers

    embedding_indexes = [faiss.IndexFlatIP(ce.embedding_dim) for _ in range(NUM_LAYERS)]
    context_indexes = [faiss.IndexFlatIP(ce.embedding_dim) for _ in range(NUM_LAYERS)]

    # Add the corpus to the indexes in chunks of `stepsize` sentences.
    for ix in range(0, len(ce), stepsize):
        cdata = ce[ix:ix + stepsize]

        # Arrays are stacked per layer; concatenating on axis=1 joins the
        # token dimension across the sentences in the chunk.
        if drop_null:
            embeddings = np.concatenate([c.zero_special_embeddings for c in cdata], axis=1)
            contexts = np.concatenate([c.zero_special_contexts for c in cdata], axis=1)
        else:
            embeddings = np.concatenate([c.embeddings for c in cdata], axis=1)
            contexts = np.concatenate([c.contexts for c in cdata], axis=1)

        for i in range(NUM_LAYERS):
            embedding_indexes[i].add(embeddings[i])
            context_indexes[i].add(contexts[i])
            
    return embedding_indexes, context_indexes

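# A minimal search sketch (the layer index and query here are hypothetical).
# IndexFlatIP scores by raw inner product, so L2-normalize the query first if
# cosine similarity is what you want:
#
#   embedding_indexes, _ = train_indexes(corpus)
#   query = np.random.rand(1, corpus.embedding_dim).astype("float32")
#   faiss.normalize_L2(query)
#   scores, ids = embedding_indexes[5].search(query, 10)  # top-10 neighbors
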
def save_indexes(idxs, outdir, base_name=LAYER_TEMPLATE):
    """Save the faiss index into a file for each index in idxs"""

    base_dir = Path(outdir)
    base_dir.mkdir(exist_ok=True, parents=True)

    out_name = str(base_dir / base_name)
    for i, idx in enumerate(idxs):
        name = out_name.format(i)
        print(f"Saving to {name}")
        faiss.write_index(idx, name)

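# Loading a saved index back is the mirror of save_indexes, e.g. for layer 0
# (output directory and layer number assumed for illustration):
#
#   idx = faiss.read_index(str(Path(outdir) / LAYER_TEMPLATE.format(0)))
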
def main(basedir):
    base = Path(basedir)
    h5_fname = base / 'data.hdf5'
    corpus = CorpusDataWrapper(h5_fname)
    embedding_faiss, context_faiss = train_indexes(corpus)

    context_faiss_dir = base / "context_faiss"
    embedding_faiss_dir = base / "embedding_faiss"
    save_indexes(embedding_faiss, embedding_faiss_dir)
    save_indexes(context_faiss, context_faiss_dir)

if __name__ == "__main__":
    # Create the indexes for both the contexts and the embeddings
    args = parse_args()

    main(args.directory)
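
# Example invocation (the script filename is assumed):
#
#   python create_indexes.py -d /path/to/base_dir
#
# This expects /path/to/base_dir/data.hdf5 and writes per-layer index files
# into /path/to/base_dir/embedding_faiss and /path/to/base_dir/context_faiss.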