File size: 6,142 Bytes
b88b2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328d783
b34c8cb
 
 
b88b2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bfd411
b88b2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bfd411
 
 
9bcac6c
b88b2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bfd411
328d783
8534ef7
b88b2e6
8534ef7
 
 
 
 
 
 
 
 
b88b2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8534ef7
b88b2e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# model under development
import os
import streamlit as st

from langchain.document_loaders.csv_loader import CSVLoader
#from langchain.text_splitter import RecursiveCharacterTextSplitter
#from langchain_text_splitters import CharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
#from langchain_core.prompts import ChatPromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import ConversationalRetrievalChain, LLMChain

#from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
#from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate

# below 3 libraries are for loading remote models
#from transformers import LlamaForCausalLM
#from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM

# adding separator
def add_vertical_space(spaces=1):
    for _ in range(spaces):
        st.sidebar.markdown("---")

# main method
def main():
    # page title
    st.set_page_config(page_title="Chatbot for NHPC", layout="wide")
    st.title("Chatbot for NHPC")
    st.write("##### 🚧 Under development 🚧")

    # faiss db directory
    DB_FAISS_PATH = "vectorstore/db_faiss"
    TEMP_DIR = "temp"

    # embedding model path
    #EMBEDDING_MODEL_PATH = "embeddings/MiniLM-L6-v2"
    
    # creating faiss db direcoty if it doesnot exist already
    if not os.path.exists(TEMP_DIR):
        os.makedirs(TEMP_DIR)

    # uploading csv file
    uploaded_file = st.sidebar.file_uploader("Upload CSV file", type=['csv'], help="Upload a CSV file")

    # adding vertical space
    add_vertical_space(1)

    # creating faiss vectorstore
    if uploaded_file is not None:
        file_path = os.path.join(TEMP_DIR, uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getvalue())
        
        st.write(f"Uploaded file: {uploaded_file.name}")
        st.write("Processing CSV file...")
        st.sidebar.markdown('##### The model may sometime generate excessive or incorrect response.')
        # calling CSVLoader for loading CSV file
        loader = CSVLoader(file_path=file_path, encoding="utf-8", csv_args={'delimiter': ','})

        # loading the CSV file data
        data = loader.load()

        # creating embeddings using huggingface
        #embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

        # loading remote embedding model
        embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
        
        # creating chunks from CSV file
        #text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=50)
        text_splitter = SemanticChunker(embeddings, breakpoint_threshold_type="interquartile")
        #text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        text_chunks = text_splitter.split_documents(data)

        # chunks message to output
        st.write(f"Total text chunks: {len(text_chunks)}")
        st.write("---")
        
        # creating vectorstore from the text chunks
        docsearch = FAISS.from_documents(text_chunks, embeddings)

        # saving the vector store to local directory
        docsearch.save_local(DB_FAISS_PATH)

        # loading remote llama model
        #llm = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        #llm = AutoModelForCausalLM.from_pretrained("google/gemma-1.1-2b-it")

        token = os.environ["HF_TOKEN"]
        llm = AutoModelForCausalLM.from_pretrained(
        "google/gemma-7b-it",
        # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        torch_dtype=torch.float16,
        token=token,
        )

        
        # custom prompt
        custom_template="""
You are a smart personal assistant and your task is to provide the answer of the given question based only on the given context. \n
If you can't find the answer in the context, just say that "I don't know, please look up the policy." and don't try to make up an answer. \n\n

Please, give the answer in plain english and don't repeat your answer and don't mention that you found the answer form the context and don't mention that the answer can be found in the context. \n
        
Question: "{question}" \n\n
        
Context: "{context}" \n\n
        
Helpful Answer:
"""
        
        QA_PROMPT = PromptTemplate(template=custom_template,input_variables=["question", "context"])

        # main llm chain
        qa = ConversationalRetrievalChain.from_llm(llm,
                                                   #chain_type = "stuff",
                                                   chain_type = "stuff",
                                                   #verbose=True,
                                                   #retriever=docsearch.as_retriever()
                                                   retriever=docsearch.as_retriever(search_kwargs = {"k" : 4, "search_type" : "similarity"}),
                                                   combine_docs_chain_kwargs={"prompt": QA_PROMPT}
                                                   #retriever=docsearch.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.1})
                                                   #memory=memory
                                                   )
        
        # taking question from user
        #st.write("### Enter your query:")
        query = st.chat_input("Ask a question to the chatbot.")
        
        if query:
            st.write("#### Query: "+query)
            with st.spinner("Processing your question..."):
                chat_history = []
                result = qa({"question": query, "chat_history": chat_history})
                #st.write("---")
                #st.write("### Response:")
                #st.write("#### Query: "+query)
                st.write(f"> {result['answer']}")

        os.remove(file_path)

if __name__ == "__main__":
    main()