import os

from dotenv import load_dotenv
from langchain.vectorstores import FAISS
from langchain.llms import GooglePalm
from langchain.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter

vector_index_path = "assets/vectordb/faiss_index"


def load_env_variables():
    load_dotenv()  # take environment variables from .env


def create_vector_db(filename, instructor_embeddings):

    # Pick a loader based on the file extension (case-insensitive)
    lower_name = filename.lower()
    if lower_name.endswith(".pdf"):
        loader = PyPDFLoader(file_path=filename)
    elif lower_name.endswith((".doc", ".docx")):
        loader = Docx2txtLoader(filename)
    elif lower_name.endswith(".txt"):
        loader = TextLoader(filename)
    else:
        raise ValueError(f"Unsupported file type: {filename}")

    # Split the loaded documents into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=10)
    splits = text_splitter.split_documents(loader.load())

    # Build a FAISS vector store from the document chunks
    vectordb = FAISS.from_documents(documents=splits,
                                    embedding=instructor_embeddings)

    # Save vector database locally
    vectordb.save_local(vector_index_path)


def get_qa_chain(instructor_embeddings, llm):

    # Load the vector database from the local folder
    # (newer LangChain releases also require allow_dangerous_deserialization=True here)
    vectordb = FAISS.load_local(vector_index_path, instructor_embeddings)

    # Create a retriever for querying the vector database
    retriever = vectordb.as_retriever(search_type="similarity")

    prompt_template = """
    You are a question-answering agent and you must strictly follow the prompt template below.
    Given the following context and a question, generate an answer based on this context only.
    In the answer, reuse as much text as possible from the "response" section of the source document context, with minimal changes.
    Keep answers brief and well-structured. Do not give one-word answers.
    If the answer is not found in the context, state "I don't know." Do not try to make up an answer.

    CONTEXT: {context}

    QUESTION: {question}"""

    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    chain = RetrievalQA.from_chain_type(llm=llm,
                                        chain_type="stuff",  # stuff all retrieved chunks into one prompt ("map_reduce" is an alternative)
                                        retriever=retriever,
                                        input_key="query",
                                        return_source_documents=True,  # include the retrieved chunks in the result
                                        chain_type_kwargs={"prompt": PROMPT},
                                        verbose=True)

    return chain


def load_model_params():

    load_env_variables()
    # Create the Google PaLM LLM (requires GOOGLE_API_KEY in the environment)
    llm = GooglePalm(google_api_key=os.environ["GOOGLE_API_KEY"], temperature=0.1)
    # Initialize instructor embeddings using the Hugging Face model
    instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")

    return llm, instructor_embeddings


def document_parser(instructor_embeddings, llm):
    """Thin wrapper that builds and returns the retrieval QA chain."""
    chain = get_qa_chain(instructor_embeddings=instructor_embeddings, llm=llm)

    return chain
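

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): "sample.pdf" and the question
    # below are hypothetical; any PDF/DOC/DOCX/TXT file works. This shows the
    # intended order: load models, index a document, then query the chain.
    llm, instructor_embeddings = load_model_params()
    create_vector_db("sample.pdf", instructor_embeddings)
    chain = document_parser(instructor_embeddings=instructor_embeddings, llm=llm)

    result = chain({"query": "What does the document say about refunds?"})
    print(result["result"])

    # Because return_source_documents=True, the retrieved chunks are also available
    for doc in result["source_documents"]:
        print(doc.metadata)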