|
|
import faiss |
|
|
import numpy as np |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
from torch.nn import DataParallel |
|
|
from transformers import AutoTokenizer, AutoModel, T5EncoderModel |
|
|
import torch |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from multiprocessing import Pool |
|
|
|
|
|
|
|
|
print("Loading data...") |
|
|
with open("merged_triple_processed_new_withID.json", "r") as fi: |
|
|
data = json.load(fi) |
|
|
|
|
|
sentences = [_['contents'] for _ in data] |
|
|
print(f"Chunks nums: {len(sentences)}") |
|
|
|
|
|
model_path = 'sentence-transformers/gtr-t5-large' |
|
|
|
|
|
def encode_sentences_on_gpu(params): |
|
|
sentences_chunk, device_id = params |
|
|
device = torch.device(f'cuda:{device_id}') |
|
|
model = SentenceTransformer(model_path, device=device) |
|
|
embeddings = model.encode( |
|
|
sentences_chunk, |
|
|
batch_size=512, |
|
|
show_progress_bar=True, |
|
|
convert_to_numpy=True, |
|
|
normalize_embeddings=True, |
|
|
desc=f'GPU {device_id} encoding' |
|
|
) |
|
|
return embeddings |
|
|
|
|
|
num_gpus = torch.cuda.device_count() |
|
|
print(f"Number of GPUs: {num_gpus}") |
|
|
|
|
|
sentences_chunks = np.array_split(sentences, num_gpus) |
|
|
params = [(sentences_chunks[i], i) for i in range(num_gpus)] |
|
|
|
|
|
print("Starting encoding process...") |
|
|
with Pool(processes=num_gpus) as pool: |
|
|
embeddings_list = list(tqdm( |
|
|
pool.imap(encode_sentences_on_gpu, params), |
|
|
total=num_gpus, |
|
|
desc='Overall progress' |
|
|
)) |
|
|
|
|
|
print("Concatenating embeddings...") |
|
|
sentence_embeddings = np.concatenate(embeddings_list, axis=0) |
|
|
|
|
|
|
|
|
print("Creating FAISS index...") |
|
|
dim = sentence_embeddings.shape[1] |
|
|
faiss_index = faiss.IndexFlatIP(dim) |
|
|
|
|
|
|
|
|
print("Adding embeddings to FAISS index...") |
|
|
faiss_index.add(sentence_embeddings) |
|
|
|
|
|
|
|
|
print("Saving FAISS index...") |
|
|
faiss_index_file = 'faiss_index.bin' |
|
|
faiss.write_index(faiss_index, faiss_index_file) |
|
|
print(f"FAISS index saved to {faiss_index_file}") |
|
|
|
|
|
print("Saving embeddings...") |
|
|
embeddings_file = 'document_embeddings.npy' |
|
|
np.save(embeddings_file, sentence_embeddings) |
|
|
print(f"Document embeddings saved to {embeddings_file}") |