# Provenance (recovered from a Hugging Face file-view page header):
#   author: olamidegoriola
#   commit 8e6ed3b: "Create FAISS index using omdena qna dataset (#4)"
#   file size: 2.32 kB; repository scan status: "No virus"
# -*- coding: utf-8 -*-
"""create_faiss_index.py
"""
import pandas as pd
import numpy as np
import faiss
from sentence_transformers import InputExample, SentenceTransformer
# CSV of FAQ question/answer pairs; the code below reads its "Questions" and
# "Answers" columns.
DATA_FILE_PATH: str = "omdena_qna_dataset/omdena_faq_training_data.csv"
# Model identifier handed to SentenceTransformer(...).
TRANSFORMER_MODEL_NAME: str = "all-distilroberta-v1"
# Passed as SentenceTransformer's cache_folder argument.
CACHE_DIR_PATH: str = "../working/cache/"
# Destination for the saved model.
# NOTE(review): SentenceTransformer.save writes a directory at this path, so
# the ".pkl" suffix is misleading — confirm before relying on it.
MODEL_SAVE_PATH: str = "all-distilroberta-v1-model.pkl"
# Destination file for the serialized FAISS index.
FAISS_INDEX_FILE_PATH: str = "index.faiss"
def load_data(file_path):
    """Read the Q&A CSV and keep only rows that have an answer.

    An ``id`` column is appended holding each row's original index label.
    It is added *before* filtering, so ids still refer to positions in the
    source file after unanswered rows are removed.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        A fresh DataFrame (not a view) of the rows whose ``Answers`` is not
        missing, with the original index preserved.
    """
    frame = pd.read_csv(file_path)
    frame = frame.assign(id=frame.index)
    answered = frame["Answers"].notna()
    return frame.loc[answered].copy()
def create_input_examples(qna_dataset):
    """Build one single-text ``InputExample`` per row from its question and answer.

    Side effect: adds (or overwrites) a ``QNA`` column on ``qna_dataset`` in
    place, holding the combined ``"Question: <q>, Answer: <a>"`` text.

    Args:
        qna_dataset: DataFrame with ``Questions`` and ``Answers`` columns.

    Returns:
        A list of ``InputExample`` objects in row order, each constructed
        with ``texts=[<QNA string>]``.
    """
    # Walk the two columns directly; DataFrame.apply(axis=1) would build a
    # pandas Series for every row, twice over.
    qna_texts = [
        f"Question: {question}, Answer: {answer}"
        for question, answer in zip(qna_dataset["Questions"], qna_dataset["Answers"])
    ]
    qna_dataset["QNA"] = qna_texts
    return [InputExample(texts=[text]) for text in qna_texts]
def load_transformer_model(model_name, cache_folder):
    """Construct a ``SentenceTransformer`` for ``model_name``.

    Args:
        model_name: Model identifier understood by ``SentenceTransformer``.
        cache_folder: Forwarded as the ``cache_folder`` keyword argument.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name, cache_folder=cache_folder)
def save_transformer_model(transformer_model, model_file):
    """Persist ``transformer_model`` by delegating to its own ``save`` method.

    NOTE(review): for a SentenceTransformer, ``save`` writes a directory of
    model files at ``model_file`` rather than a single pickle — confirm
    before treating the path as a file.
    """
    persist = transformer_model.save
    persist(model_file)
def create_faiss_index(transformer_model, qna_dataset):
    """Embed every answer and return an inner-product FAISS index keyed by row id.

    Args:
        transformer_model: Object exposing ``encode(list_of_str)`` that returns
            one embedding vector per input string (here a SentenceTransformer).
        qna_dataset: DataFrame with ``Answers`` and ``id`` columns, as
            produced by ``load_data``.

    Returns:
        A ``faiss.IndexIDMap`` wrapping an ``IndexFlatIP``. Vectors are
        L2-normalised before insertion, so inner-product search ranks by
        cosine similarity. Each vector's FAISS id is the row's ``id`` value,
        not its position within ``qna_dataset``.
    """
    # Only the answers are embedded; the combined "QNA" text is not indexed.
    answers = qna_dataset["Answers"].tolist()
    embeddings = transformer_model.encode(answers)

    # faiss.normalize_L2 works in place and requires a C-contiguous float32
    # array, so take an explicit float32 copy instead of trusting the
    # encoder's output dtype/layout.
    normalized_embeddings = np.array(embeddings, dtype="float32", order="C")
    faiss.normalize_L2(normalized_embeddings)

    # FAISS ids are 64-bit; astype("int") is only 32 bits on some platforms
    # (e.g. Windows with NumPy < 2).
    ids = qna_dataset["id"].to_numpy(dtype="int64")

    # shape[1] also works for a (0, d) array, where indexing row 0 would fail.
    dimension = normalized_embeddings.shape[1]
    faiss_index = faiss.IndexIDMap(faiss.IndexFlatIP(dimension))
    faiss_index.add_with_ids(normalized_embeddings, ids)
    return faiss_index
def save_faiss_index(faiss_index, filename):
    """Serialize ``faiss_index`` to ``filename`` with ``faiss.write_index``."""
    faiss.write_index(faiss_index, filename)
    return None
def load_faiss_index(filename):
    """Deserialize and return the FAISS index stored at ``filename``."""
    index = faiss.read_index(filename)
    return index
def main():
    """Build the FAQ FAISS index, save model and index, and verify the saved index.

    Returns:
        The index re-read from ``FAISS_INDEX_FILE_PATH``. Returning it is
        backward-compatible: previously ``None`` was returned.

    Raises:
        RuntimeError: if the re-read index does not hold as many vectors as
            the one that was written.
    """
    qna_dataset = load_data(DATA_FILE_PATH)
    # Called for its side effect of adding the "QNA" column; the returned
    # InputExample list is not needed to build the index.
    create_input_examples(qna_dataset)
    transformer_model = load_transformer_model(TRANSFORMER_MODEL_NAME, CACHE_DIR_PATH)
    save_transformer_model(transformer_model, MODEL_SAVE_PATH)
    faiss_index = create_faiss_index(transformer_model, qna_dataset)
    save_faiss_index(faiss_index, FAISS_INDEX_FILE_PATH)
    # Round-trip check: the original re-read the file but discarded the result.
    reloaded_index = load_faiss_index(FAISS_INDEX_FILE_PATH)
    if reloaded_index.ntotal != faiss_index.ntotal:
        raise RuntimeError(
            f"{FAISS_INDEX_FILE_PATH}: reloaded index has {reloaded_index.ntotal} "
            f"vectors, expected {faiss_index.ntotal}"
        )
    return reloaded_index
# Run the index-building pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()