# riccorl's picture
# Upload models
# 8197b11
from collections import Counter
import json
import torch
from tqdm import tqdm
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
if __name__ == "__main__":
    # Build a filtered copy of a pretrained in-memory document index.
    #
    # A document is kept when its title (the text before " <def>") is in an
    # allow-list made of:
    #   * the first tab-separated column of the first 1,000,000 lines of
    #     frequency_blink.txt, and
    #   * every title found in the definitions-only data file.
    # The kept embeddings are cast to float16 and saved next to the original
    # index as "document_index_filtered".

    # Titles are matched by exact string equality below, so decode both files
    # explicitly instead of relying on the machine's locale default.
    # NOTE(review): assumes both text files are UTF-8 encoded -- confirm.
    with open("frequency_blink.txt", encoding="utf-8") as f_in:
        # Zipping with a bounded range stops after 1,000,000 lines without
        # reading the rest of the file into memory.
        # NOTE(review): presumably the file is sorted by descending frequency,
        # so this keeps the most frequent entries -- confirm.
        frequencies = {
            line.strip().split("\t")[0]
            for _, line in zip(range(1_000_000), f_in)
        }

    with open(
        "/root/golden-retriever-v2/data/dpr-like/el/definitions_only_data.txt",
        encoding="utf-8",
    ) as f_in:
        for line in f_in:
            title = line.strip().split(" <def>")[0].strip()
            frequencies.add(title)

    document_index = InMemoryDocumentIndex.from_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
    )

    # Maps each kept document string to its new, contiguous index; the
    # embedding at the same position in new_embeddings belongs to it.
    new_doc_index = {}
    new_embeddings = []
    for i in range(document_index.documents.get_label_size()):
        doc = document_index.documents.get_label_from_index(i)
        title = doc.split(" <def>")[0].strip()
        if title in frequencies:
            new_doc_index[doc] = len(new_doc_index)
            new_embeddings.append(document_index.embeddings[i])

    # Sanity check: both counts are expected to be equal.
    print(len(new_doc_index))
    print(len(new_embeddings))

    # torch.stack rejects an empty list with an opaque message; fail with an
    # explanation instead (same exception type).
    if not new_embeddings:
        raise RuntimeError(
            "No document title matched the allow-list; nothing to save."
        )

    new_embeddings = torch.stack(new_embeddings, dim=0)
    # Store the filtered embeddings in half precision.
    new_embeddings = new_embeddings.to(torch.float16)
    print(new_embeddings.shape)

    new_label_index = Labels()
    new_label_index.add_labels(new_doc_index)

    new_document_index = InMemoryDocumentIndex(
        documents=new_label_index,
        embeddings=new_embeddings,
    )
    new_document_index.save_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered"
    )