Spaces:
Runtime error
Runtime error
from datasets import Dataset, load_from_disk | |
import faiss | |
import numpy as np | |
from transformers import RagTokenizer, RagSequenceForGeneration | |
def create_and_save_faiss_index(dataset_path, index_path, batch_size=32):
    """Embed the passages of a saved dataset and write a flat L2 FAISS index.

    The dataset is loaded with ``load_from_disk`` and its ``"text"`` column is
    tokenized and encoded in batches. Each passage embedding is the mean of
    the encoder's last hidden states over the passage's non-padding tokens.
    The embeddings are added to a ``faiss.IndexFlatL2`` which is then written
    to ``index_path``.

    Args:
        dataset_path: Location of a dataset previously saved to disk.
        index_path: Destination file for the serialized FAISS index.
        batch_size: Number of passages encoded per forward pass.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or the ``"text"``
            column contains no passages.
    """
    # Local import: torch is only needed inside this function and the file's
    # import block does not bring it into scope.
    import torch

    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    dataset = load_from_disk(dataset_path)
    passages = list(dataset["text"])
    if not passages:
        raise ValueError(
            f"No passages found in the 'text' column of {dataset_path!r}"
        )

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
    encoder = model.get_encoder()
    # NOTE(review): confirm that the token ids produced by calling the
    # RagTokenizer directly use the same vocabulary as the encoder returned by
    # get_encoder(), and that these embeddings are what the consumer of this
    # index expects.

    batch_embeddings = []
    # Inference only: skip autograd bookkeeping to keep memory bounded.
    with torch.no_grad():
        for start in range(0, len(passages), batch_size):
            batch = passages[start:start + batch_size]
            inputs = tokenizer(
                batch, return_tensors="pt", padding=True, truncation=True
            )
            attention_mask = inputs["attention_mask"]
            # Pass the tensors by name instead of handing over the whole
            # tokenizer output: the encoder expects tensors, and the output
            # may carry extra fields the encoder does not accept.
            hidden = encoder(
                input_ids=inputs["input_ids"],
                attention_mask=attention_mask,
            ).last_hidden_state

            # Average over real tokens only, so an embedding does not depend
            # on how much padding its batch happened to need.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            batch_embeddings.append(pooled.cpu().numpy())

    # FAISS works on C-contiguous float32 matrices.
    passage_embeddings = np.ascontiguousarray(
        np.concatenate(batch_embeddings, axis=0), dtype=np.float32
    )

    index = faiss.IndexFlatL2(passage_embeddings.shape[1])
    index.add(passage_embeddings)
    faiss.write_index(index, str(index_path))
if __name__ == "__main__":
    # Local import: argparse is only needed when the file runs as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Build a FAISS index from a dataset saved to disk."
    )
    # Both arguments are optional; the defaults are the placeholder paths the
    # script used before it accepted command-line arguments.
    parser.add_argument(
        "dataset_path",
        nargs="?",
        default="path/to/your/hf_dataset",
        help="directory of the dataset saved to disk",
    )
    parser.add_argument(
        "index_path",
        nargs="?",
        default="path/to/your/hf_index",
        help="file the FAISS index is written to",
    )
    args = parser.parse_args()

    dataset_path = args.dataset_path
    index_path = args.index_path
    create_and_save_faiss_index(dataset_path, index_path)