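"""Streamlit chatbot for hydraulic fracture stimulation Q&A.

Opens a persisted Chroma vector store, wraps an OpenAI LLM in a
ConversationalRetrievalChain with a domain-specific prompt and conversation
memory, and renders the conversation with the HTML templates from
htmlTemplates.py.
"""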
from dotenv import load_dotenv, find_dotenv
from langchain.chains import LLMChain
import streamlit as st
from decouple import config
from langchain.llms import OpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.evaluation.qa import QAGenerateChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import CSVLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.prompts import ChatPromptTemplate
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
import time
from htmlTemplates import css, bot_template, user_template
from pathlib import Path
import pathlib
import platform

# Work around pathlib.WindowsPath not being instantiable on Linux
# (e.g. when loading artifacts that were created on Windows).
if platform.system() == 'Linux':
    pathlib.WindowsPath = pathlib.PosixPath

_ = load_dotenv(find_dotenv())  # read the local .env file (e.g. OPENAI_API_KEY for the OpenAI models)


def timeit(func):
    """Decorator that prints how long the wrapped function takes to run."""
    def wrapper(*args, **kwargs):
        start_time = time.time()  # Start time
        result = func(*args, **kwargs)  # Function execution
        end_time = time.time()  # End time
        print(
            f"Function {func.__name__} took {end_time - start_time} seconds to execute.")
        return result
    return wrapper


@timeit
def get_llm():
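    """Return a low-temperature OpenAI completion LLM."""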
    return OpenAI(temperature=0.1)


@timeit
def get_memory():
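    """Return conversation memory that keeps the chat history as messages."""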
    return ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True
    )


@timeit
def generate_response(question, vectordb, llm, memory, chat_history):
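    """Build a ConversationalRetrievalChain with an MMR retriever over the vector
    store and a domain-specific prompt, run the user's question through it, and
    pass the result to handle_userinput() for rendering."""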
    template = """Use the provided context to answer the user's question.
    You are an honest petroleum engineer specializing in hydraulic fracture stimulation and reservoir engineering.
    If you don't know the answer, respond with "Sorry Sir, I do not know".
    Context: {context}
    Question: {question}
    Answer:
    """

    prompt = PromptTemplate(
        template=template,
        input_variables=['question', 'context'])

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectordb.as_retriever(
            search_type="mmr", search_kwargs={"k": 5, "fetch_k": 10}),
        memory=memory,
        combine_docs_chain_kwargs={"prompt": prompt}
    )

    handle_userinput(
        qa_chain({"question": question, "chat_history": chat_history}))


@timeit
def create_embedding_function():
    """Return the SentenceTransformer embedding function used for the vector store."""
    # Alternative embedding models, kept commented out for reference:
    # embedding_func_all_mpnet_base_v2 = SentenceTransformerEmbeddings(
    #     model_name="all-mpnet-base-v2")
    # embedding_func_all_MiniLM_L6_v2 = SentenceTransformerEmbeddings(
    #     model_name="all-MiniLM-L6-v2")
    # embedding_func_jina_embeddings_v2_base_en = SentenceTransformerEmbeddings(
    #     model_name="jinaai/jina-embeddings-v2-base-en")
    # embedding_func_jina_embeddings_v2_small_en = SentenceTransformerEmbeddings(
    #     model_name="jinaai/jina-embeddings-v2-small-en")
    embedding_func_gte_large = SentenceTransformerEmbeddings(
        model_name="thenlper/gte-large")
    return embedding_func_gte_large


@timeit
def get_vector_db(embedding_function):
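    """Open the Chroma collection persisted in the 'gte_large' directory."""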
    vector_db = Chroma(persist_directory=str(Path('gte_large')),
                       embedding_function=embedding_function)
    return vector_db
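

# NOTE: get_vector_db() expects a Chroma collection that was already persisted
# to the 'gte_large' directory; this script does not build it. A minimal,
# hypothetical build step could look like the sketch below, reusing the
# PyPDFLoader and RecursiveCharacterTextSplitter imported above. The PDF path
# and chunking parameters are illustrative assumptions, not part of this app.
#
#     loader = PyPDFLoader("docs/hydraulic_fracturing.pdf")  # hypothetical file
#     chunks = RecursiveCharacterTextSplitter(
#         chunk_size=1000, chunk_overlap=150).split_documents(loader.load())
#     Chroma.from_documents(chunks, create_embedding_function(),
#                           persist_directory="gte_large").persist()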


def handle_userinput(response):
    """Render the chain response's chat history using the user/bot HTML templates."""
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    st.session_state.chat_history = response['chat_history']

    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            st.write(user_template.replace(
                "{{MSG}}", message.content), unsafe_allow_html=True)
        else:
            st.write(bot_template.replace(
                "{{MSG}}", message.content), unsafe_allow_html=True)


if __name__ == "__main__":

    st.set_page_config(
        page_title="Hydraulic Fracture Stimulation Chat", page_icon=":books:")
    st.write(css, unsafe_allow_html=True)
    st.title("Hydraulic Fracture Stimulation Chat")
    st.write(
        "This is a chatbot that can answer questions related to petroleum engineering, especially hydraulic fracture stimulation.")

    # get embedding function
    embedding_function = create_embedding_function()
    # get vector db
    vector_db = get_vector_db(embedding_function)
    # get llm
    llm = get_llm()

    # get memory
    if 'memory' not in st.session_state:
        st.session_state['memory'] = get_memory()
    memory = st.session_state['memory']

    # chat history
    chat_history = []

    prompt_question = st.chat_input("Please ask a question:")
    if prompt_question:
        generate_response(question=prompt_question, vectordb=vector_db,
                          llm=llm, memory=memory, chat_history=chat_history)