|
|
import faiss |
|
|
import numpy as np |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
from torch.nn import DataParallel |
|
|
from transformers import AutoTokenizer, AutoModel, T5EncoderModel |
|
|
import torch |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from multiprocessing import Pool |
|
|
import time |
|
|
|
|
|
# Build a FAISS inner-product index over Contriever embeddings of the passages
# stored in merged_triple_processed_new_withID.json.
start_time = time.time()

# JSON text is UTF-8 (RFC 8259); pass the encoding explicitly rather than
# relying on the platform locale default, which may not be UTF-8 (e.g. cp1252).
with open("merged_triple_processed_new_withID.json", "r", encoding="utf-8") as fi:
    data = json.load(fi)

# `data` is assumed to be a list of records, each carrying its passage text
# under the 'contents' key.
sentences = [record['contents'] for record in data]

print("Chunks nums: ", len(sentences))
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoint that supplies both the tokenizer and the encoder.
model_path = 'facebook/contriever'

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Wrap the encoder so each forward pass can be split across the visible GPUs.
model = DataParallel(AutoModel.from_pretrained(model_path))
model.eval()  # inference only: put every submodule in eval mode
model.to(device)

# Number of passages tokenized and encoded per forward pass.
batch_size = 1024
|
|
|
|
|
def mean_pooling(token_embeddings, mask):
    """Average token embeddings over the positions selected by ``mask``.

    Args:
        token_embeddings: Tensor shaped (batch, seq_len, hidden), e.g. the
            encoder's last hidden state.
        mask: Tensor shaped (batch, seq_len); non-zero entries mark real
            tokens and zero entries mark padding.

    Returns:
        Tensor shaped (batch, hidden) with the mean of the unmasked token
        vectors of each row. A row whose mask is entirely zero yields a zero
        vector instead of NaN.
    """
    keep = mask.bool()
    # Zero out padding positions so they contribute nothing to the sum.
    summed = token_embeddings.masked_fill(~keep[..., None], 0.).sum(dim=1)
    # Clamp the token count so an all-padding row divides by 1 instead of 0.
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts[..., None]
|
|
|
|
|
def process_in_batches(sentences, batch_size):
    """Encode ``sentences`` chunk by chunk and return their pooled embeddings.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. Each chunk of
    at most ``batch_size`` sentences is tokenized (padded to the longest item
    in the chunk, truncated to the tokenizer's limit), moved to ``device`` and
    run through ``model`` without gradient tracking. It is then reduced with
    ``mean_pooling`` over the attention mask and copied back to the CPU.

    Returns:
        A CPU tensor with one row per input sentence, in input order.
    """
    pooled_chunks = []
    for start in tqdm(range(0, len(sentences), batch_size)):
        inputs = tokenizer(
            sentences[start:start + batch_size],
            padding=True,
            truncation=True,
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            hidden_states = model(**inputs)[0]
            pooled = mean_pooling(hidden_states, inputs['attention_mask'])

        # Move results off the accelerator right away so device memory only
        # ever holds a single chunk.
        pooled_chunks.append(pooled.cpu())

    return torch.cat(pooled_chunks, dim=0)
|
|
|
|
|
|
|
|
sentence_embeddings = process_in_batches(sentences, batch_size)

# FAISS only accepts C-contiguous float32 arrays. `.float()` is a no-op for an
# fp32 encoder, but it upcasts fp16/bf16 output (bf16 tensors cannot be
# converted to NumPy directly), so index.add() below never rejects the data.
sentence_embeddings = np.ascontiguousarray(
    sentence_embeddings.cpu().float().numpy()
)

# Exact (brute-force) inner-product index over the raw, un-normalized vectors.
dim = sentence_embeddings.shape[1]
faiss_index = faiss.IndexFlatIP(dim)
faiss_index.add(sentence_embeddings)

faiss_index_file = 'faiss_index.bin'
faiss.write_index(faiss_index, faiss_index_file)
print(f"FAISS index saved to {faiss_index_file}")

# Also persist the embedding matrix itself; row i corresponds to sentences[i].
embeddings_file = 'document_embeddings.npy'
np.save(embeddings_file, sentence_embeddings)
print(f"Document embeddings saved to {embeddings_file}")

end_time = time.time()
execution_time_hours = (end_time - start_time) / 3600
print(f"Total execution time: {execution_time_hours:.2f} hours")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|