relik-entity-linking / scripts /filter_docs.py
riccorl's picture
Upload models
8197b11
raw history blame
No virus
1.74 kB
from collections import Counter
import json
import torch
from tqdm import tqdm
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
def extract_title(doc: str) -> str:
    """Return the title part of a ``<title> <def> <definition>`` string.

    Everything before the first ``" <def>"`` marker is kept and surrounding
    whitespace is removed. A string without the marker is returned stripped.
    """
    return doc.split(" <def>")[0].strip()


if __name__ == "__main__":
    # Local import: only this script entry point needs it.
    from itertools import islice

    # Build the set of titles that are allowed to stay in the index.
    # Only the first 1,000,000 lines of the frequency file are considered and
    # only their first tab-separated column is used; streaming with ``islice``
    # avoids loading the rest of the file into memory.
    # NOTE(review): UTF-8 is assumed for both text files -- confirm.
    with open("frequency_blink.txt", encoding="utf-8") as f_in:
        allowed_titles = {
            line.strip().split("\t")[0] for line in islice(f_in, 1_000_000)
        }

    # Every title found in the definitions file is always allowed as well.
    with open(
        "/root/golden-retriever-v2/data/dpr-like/el/definitions_only_data.txt",
        encoding="utf-8",
    ) as f_in:
        for line in f_in:
            allowed_titles.add(extract_title(line.strip()))

    document_index = InMemoryDocumentIndex.from_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
    )

    # Maps each kept document string to its new, densely packed position.
    new_doc_index = {}
    # Embedding of each kept document, in the same order as ``new_doc_index``.
    new_embeddings = []

    for i in range(document_index.documents.get_label_size()):
        doc = document_index.documents.get_label_from_index(i)
        if extract_title(doc) not in allowed_titles:
            continue
        # Defensive: should the source index yield the same document string
        # twice, keep only the first occurrence so that label ids and
        # embedding rows cannot get out of step.
        if doc in new_doc_index:
            continue
        new_doc_index[doc] = len(new_doc_index)
        new_embeddings.append(document_index.embeddings[i])

    print(len(new_doc_index))
    print(len(new_embeddings))

    # ``torch.stack`` rejects an empty list; fail with an explicit message.
    if not new_embeddings:
        raise RuntimeError(
            "no document matched the allowed titles; nothing to save"
        )

    # Stack the kept embeddings along a new leading dimension and store them
    # in half precision.
    new_embeddings = torch.stack(new_embeddings, dim=0)
    new_embeddings = new_embeddings.to(torch.float16)

    print(new_embeddings.shape)

    # Wrap the filtered documents and embeddings into a new index and save it
    # next to the original one.
    new_label_index = Labels()
    new_label_index.add_labels(new_doc_index)
    new_document_index = InMemoryDocumentIndex(
        documents=new_label_index,
        embeddings=new_embeddings,
    )

    new_document_index.save_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered"
    )