socm / reference_embeddings.py
spencer's picture
add normal files
6df828c
raw history blame
No virus
1.25 kB
import argparse
from tqdm import tqdm
import faiss
from embeddings import FaissIndex
from models import CLIP
def main(file, index_type):
    """Embed the reference strings in *file* and add them to a FAISS index.

    The file is read as text with one reference per line.  Each non-blank
    line is embedded with ``CLIP.get_text_emb`` and added, together with
    its text, to the ``FaissIndex`` located at
    ``faiss_indices/{index_type}.index``.  The index is reset before
    anything is added, so previous contents are presumably discarded
    (confirm against ``FaissIndex.reset``).

    Args:
        file: Path of the text file containing the references.
        index_type: Name of the index; used as the stem of the index file.
    """
    clip = CLIP()

    with open(file) as f:
        # A trailing newline or blank line makes split("\n") yield empty
        # strings; skip them so they are not embedded as references.
        references = [line for line in f.read().split("\n") if line.strip()]

    index = FaissIndex(
        embedding_size=768,
        faiss_index_location=f"faiss_indices/{index_type}.index",
        indexer=faiss.IndexFlatIP,
    )
    index.reset()

    if not references:
        # Nothing to embed; leave the freshly reset index empty.
        return

    if len(references) < 500:
        # Small inputs are embedded in a single call.
        batches = [references]
    else:
        # Larger inputs are embedded 300 at a time, with a progress bar.
        batches = tqdm(
            [
                references[start : start + 300]
                for start in range(0, len(references), 300)
            ]
        )

    for batch in batches:
        ref_embeddings = clip.get_text_emb(batch)
        # presumably a torch tensor (it exposes .detach().numpy());
        # it is converted to a NumPy array before being added to the index.
        index.add(ref_embeddings.detach().numpy(), batch)
if __name__ == "__main__":
    # Command-line entry point: two positional arguments, forwarded to
    # main() under the same names.
    cli = argparse.ArgumentParser()
    cli.add_argument("file", type=str, help="File containing references")
    cli.add_argument("index_type", type=str, choices=["places", "objects"])
    # The parsed namespace holds exactly `file` and `index_type`, which are
    # main()'s parameter names, so it can be expanded as keyword arguments.
    main(**vars(cli.parse_args()))