# -*- coding: utf-8 -*-
"""create_faiss_index.py

Build a FAISS similarity index over the Omdena FAQ dataset: answers are
embedded with a SentenceTransformer model, L2-normalized, and stored in an
inner-product index keyed by dataset row ids.
"""

import pandas as pd
import numpy as np
import faiss
from sentence_transformers import InputExample, SentenceTransformer

DATA_FILE_PATH = "omdena_qna_dataset/omdena_faq_training_data.csv"
TRANSFORMER_MODEL_NAME = "all-distilroberta-v1"
CACHE_DIR_PATH = "../working/cache/"
# Note: SentenceTransformer.save() writes a directory at this path, despite the
# ".pkl" extension.
MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
FAISS_INDEX_FILE_PATH = "index.faiss"

def load_data(file_path):
    """Load the FAQ CSV, add a stable integer id per row, and drop rows with no answer."""
    qna_dataset = pd.read_csv(file_path)
    qna_dataset["id"] = qna_dataset.index
    return qna_dataset.dropna(subset=["Answers"]).copy()

def create_input_examples(qna_dataset):
    """Build sentence-transformers InputExample objects from combined Q&A strings.

    These are only needed for fine-tuning; this script builds them but does not train on them.
    """
    qna_dataset["QNA"] = qna_dataset.apply(
        lambda row: f"Question: {row['Questions']}, Answer: {row['Answers']}", axis=1
    )
    return qna_dataset.apply(lambda x: InputExample(texts=[x["QNA"]]), axis=1).tolist()

def load_transformer_model(model_name, cache_folder):
    transformer_model = SentenceTransformer(model_name, cache_folder=cache_folder)
    return transformer_model

def save_transformer_model(transformer_model, model_file):
    transformer_model.save(model_file)
    
def create_faiss_index(transformer_model, qna_dataset):
    """Embed the answers, L2-normalize them, and index them by dataset id."""
    faiss_embeddings = transformer_model.encode(qna_dataset.Answers.values.tolist())
    qna_dataset_indexed = qna_dataset.set_index(["id"], drop=False)
    # faiss expects 64-bit integer ids for add_with_ids.
    id_index_array = np.array(qna_dataset_indexed.id.values).flatten().astype("int64")
    normalized_embeddings = faiss_embeddings.copy()
    faiss.normalize_L2(normalized_embeddings)
    # Inner product over L2-normalized vectors is equivalent to cosine similarity.
    faiss_index = faiss.IndexIDMap(faiss.IndexFlatIP(faiss_embeddings.shape[1]))
    faiss_index.add_with_ids(normalized_embeddings, id_index_array)
    return faiss_index

def save_faiss_index(faiss_index, filename):
    faiss.write_index(faiss_index, filename)

def load_faiss_index(filename):
    return faiss.read_index(filename)
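
# A minimal sketch of how the saved model and index might be queried. The
# helper name `search_faq` and the default k=3 are illustrative assumptions,
# not part of the original pipeline; it assumes the artifacts produced by
# create_faiss_index() and save_transformer_model() above.
def search_faq(query, transformer_model, faiss_index, k=3):
    # Embed the query and apply the same L2 normalization used at index time,
    # so the inner-product scores correspond to cosine similarity.
    query_embedding = transformer_model.encode([query])
    faiss.normalize_L2(query_embedding)
    scores, ids = faiss_index.search(query_embedding, k)
    # The returned ids map back to the "id" column passed to add_with_ids.
    return scores[0], ids[0]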

def main():
    qna_dataset = load_data(DATA_FILE_PATH)
    # Input examples are prepared for optional fine-tuning but are not used further here.
    input_examples = create_input_examples(qna_dataset)
    transformer_model = load_transformer_model(TRANSFORMER_MODEL_NAME, CACHE_DIR_PATH)
    save_transformer_model(transformer_model, MODEL_SAVE_PATH)
    faiss_index = create_faiss_index(transformer_model, qna_dataset)
    save_faiss_index(faiss_index, FAISS_INDEX_FILE_PATH)
    # Reload the index from disk as a round-trip check.
    faiss_index = load_faiss_index(FAISS_INDEX_FILE_PATH)

if __name__ == "__main__":
    main()