Skier8402 committed on
Commit
ec9e166
1 Parent(s): e8cc327

Create app.py


Add code for Mistral implementation.

Files changed (1)
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
"""
Question Answering with Retrieval QA and LangChain Language Models featuring FAISS Vector Stores

This Streamlit app uses the LangChain API to answer questions about uploaded
PDFs with a conversational Retrieval QA chain, HuggingFace BGE embeddings, a
FAISS vector store, and the Mistral-7B-Instruct model.
"""

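# A usage note (assuming Streamlit and the packages imported below are
# installed): launch the app with
#   streamlit run app.py
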
import os

import streamlit as st
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import HuggingFaceHub

from htmlTemplates import css, bot_template, user_template


def get_pdf_text(pdf_docs):
    """
    Extract text from a list of PDF documents.

    Parameters
    ----------
    pdf_docs : list
        List of uploaded PDF files to extract text from.

    Returns
    -------
    str
        The concatenated text of all pages in all the PDF documents.
    """
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() can return None for pages without extractable text
            text += page.extract_text() or ""
    return text

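# Usage note: PdfReader accepts file-like objects, so the UploadedFile objects
# returned by Streamlit's file_uploader can be passed in directly:
#   raw_text = get_pdf_text(pdf_docs)  # pdf_docs from st.file_uploader(...)
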

def get_text_chunks(text):
    """
    Split the input text into overlapping chunks.

    Parameters
    ----------
    text : str
        The input text to be split.

    Returns
    -------
    list
        List of text chunks.
    """
    # 1500-character chunks with a 300-character overlap keep neighbouring
    # context available to the retriever
    text_splitter = CharacterTextSplitter(
        separator="\n", chunk_size=1500, chunk_overlap=300, length_function=len
    )
    chunks = text_splitter.split_text(text)
    return chunks

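# Rough illustration (hypothetical input): with chunk_size=1500 and
# chunk_overlap=300, a ~3,000-character text yields roughly three chunks, each
# sharing about 300 characters with its neighbour:
#   chunks = get_text_chunks(raw_text)
#   print(len(chunks), [len(c) for c in chunks])
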
def get_vectorstore(text_chunks):
    """
    Generate a FAISS vector store from a list of text chunks using HuggingFace BGE embeddings.

    Parameters
    ----------
    text_chunks : list
        List of text chunks to be embedded.

    Returns
    -------
    FAISS
        A FAISS vector store containing the embeddings of the text chunks.
    """
    # embeddings = OpenAIEmbeddings()
    model = "BAAI/bge-base-en-v1.5"
    # set normalize_embeddings=True so similarity search computes cosine similarity
    encode_kwargs = {"normalize_embeddings": True}
    embeddings = HuggingFaceBgeEmbeddings(
        model_name=model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
    )
    vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
    return vectorstore

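# The returned store also supports standalone similarity search, e.g. (sketch):
#   docs = vectorstore.similarity_search("What is this document about?", k=4)
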
def get_conversation_chain(vectorstore):
    """
    Create a conversational retrieval chain using a vector store and a language model.

    Parameters
    ----------
    vectorstore : FAISS
        A FAISS vector store containing the embeddings of the text chunks.

    Returns
    -------
    ConversationalRetrievalChain
        A conversational retrieval chain for generating responses.
    """
    # Mistral-7B-Instruct served through the HuggingFace Hub inference API
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        model_kwargs={"temperature": 0.5, "max_length": 512},
    )
    # llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")

    # The buffer memory keeps the running conversation under "chat_history"
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vectorstore.as_retriever(), memory=memory
    )
    return conversation_chain

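# Sketch of how the chain is consumed (handle_userinput below does this via
# st.session_state):
#   response = conversation_chain({"question": "Summarise the uploaded PDFs."})
#   print(response["answer"])
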
def handle_userinput(user_question):
    """
    Handle user input and render a response from the conversational retrieval chain.

    Parameters
    ----------
    user_question : str
        The user's question.
    """
    if st.session_state.conversation is None:
        st.warning("Please upload and process your PDFs before asking a question.")
        return

    response = st.session_state.conversation({"question": user_question})
    st.session_state.chat_history = response["chat_history"]

    # The history alternates user/bot turns, so even indices are the user's
    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            st.write(
                user_template.replace("{{MSG}}", message.content),
                unsafe_allow_html=True,
            )
        else:
            st.write(
                bot_template.replace("{{MSG}}", message.content),
                unsafe_allow_html=True,
            )

def main():
    st.set_page_config(
        page_title="Chat with a Bot that tries to answer questions about multiple PDFs",
        page_icon=":books:",
    )

    st.markdown("# Chat with a Bot")
    st.markdown("This bot tries to answer questions about multiple PDFs.")

    st.write(css, unsafe_allow_html=True)

    # Load any variables from a local .env file, then let the user supply a
    # HuggingFace Hub token through a masked text input
    load_dotenv()
    huggingface_token = st.text_input(
        "Enter your HuggingFace Hub token", type="password"
    )
    # openai_api_key = st.text_input("Enter your OpenAI API key", type="password")

    # Expose the token as an environment variable for HuggingFaceHub
    if huggingface_token:
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = huggingface_token
    # os.environ["OPENAI_API_KEY"] = openai_api_key

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header(
        "Chat with a Bot 🤖 that tries to answer questions about multiple PDFs :books:"
    )
    user_question = st.text_input("Ask a question about your documents:")
    if user_question:
        handle_userinput(user_question)

    with st.sidebar:
        st.subheader("Your documents")
        pdf_docs = st.file_uploader(
            "Upload your PDFs here and click on 'Process'", accept_multiple_files=True
        )
        if st.button("Process"):
            with st.spinner("Processing"):
                # get pdf text
                raw_text = get_pdf_text(pdf_docs)

                # get the text chunks
                text_chunks = get_text_chunks(raw_text)

                # create vector store
                vectorstore = get_vectorstore(text_chunks)

                # create conversation chain
                st.session_state.conversation = get_conversation_chain(vectorstore)

if __name__ == "__main__":
    main()