Kaung Myat Htet committed
Commit bde0120
Parent: b1ac1a0

add conversation history

Files changed (2):
  1. app.py +190 -57
  2. requirements.txt +4 -1
app.py CHANGED
@@ -2,14 +2,22 @@ import os
 import gradio as gr
 from langchain_community.vectorstores import FAISS
 from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
+import pymongo
 
+from langchain_community.vectorstores import MongoDBAtlasVectorSearch
 from langchain_core.runnables.passthrough import RunnableAssign, RunnablePassthrough
 from langchain.memory import ConversationBufferMemory
 from langchain_core.messages import get_buffer_string
 from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
 
-from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.output_parsers import StrOutputParser
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.messages import AIMessage, HumanMessage
 
 
 embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)
@@ -20,70 +28,195 @@ db = FAISS.load_local("vms_faiss_index", embedder, allow_dangerous_deserializati
 nvidia_api_key = os.environ.get("NVIDIA_API_KEY", "")
 
 
-from operator import itemgetter
-
-
-# available model names
-# mixtral_8x7b
-# llama2_13b
-llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()
-
-initial_msg = (
-    "Hello! I am VMS bot here to help you with your academic issues!"
-    f"\nHow can I help you?"
-)
-
-context_prompt = ChatPromptTemplate.from_messages([
-    ('system',
-     "You are a VMS chatbot, and you are helping students with their academic issues."
-     "Answer the question using only the context provided. Do not include phrases like 'based on the context' or 'based on the documents provided' in your answer."
-     "Please help them with their question. Remember that your job is to represent Vincent Mary School of Science and Technology (VMS) at Assumption University."
-     "Do not hallucinate any details, and make sure the knowledge base is not redundant."
-     "Please say you do not know if you do not know or cannot find the information needed."
-     "\n\nQuestion: {question}\n\nContext: {context}"),
-    ('user', "{question}")
-])
-
-chain = (
-    {
-        'context': db.as_retriever(search_type="similarity"),
-        'question': (lambda x: x)
-    }
-    | context_prompt
-    # | RPrint()
-    | llm
-    | StrOutputParser()
-)
-
-conv_chain = (
-    context_prompt
-    # | RPrint()
-    | llm
-    | StrOutputParser()
-)
-
-def chat_gen(message, history, return_buffer=True):
-    buffer = ""
-
-    doc_retriever = db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.2})
-    retrieved_docs = doc_retriever.invoke(message)
-    print(len(retrieved_docs))
-    print(retrieved_docs)
-
-    if len(retrieved_docs) > 0:
-        state = {
-            'question': message,
-            'context': retrieved_docs
-        }
-        for token in conv_chain.stream(state):
-            buffer += token
-            yield buffer
-    else:
-        passage = "I am sorry. I do not have relevant information to answer on that specific topic. Please try another question."
-        buffer += passage
-        yield buffer if return_buffer else passage
-
-
-chatbot = gr.Chatbot(value=[[None, initial_msg]])
-iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
-iface.launch()
+
+def get_mongo_client(mongo_uri):
+    """Establish a connection to MongoDB."""
+    try:
+        client = pymongo.MongoClient(mongo_uri)
+        print("Connection to MongoDB successful")
+        return client
+    except pymongo.errors.ConnectionFailure as e:
+        print(f"Connection failed: {e}")
+        return None
+
+mongo_uri = os.environ.get('MyCluster_MONGO_URI')
+if not mongo_uri:
+    print("MONGO_URI not set in environment variables")
+
+mongo_client = get_mongo_client(mongo_uri)
+
+DB_NAME = "vms_courses"
+COLLECTION_NAME = "courses"
+
+db = mongo_client[DB_NAME]
+collection = db[COLLECTION_NAME]
+ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector_index"
+
+
+vector_search = MongoDBAtlasVectorSearch.from_connection_string(
+    mongo_uri,
+    DB_NAME + "." + COLLECTION_NAME,
+    embedder,
+    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
+)
+
+
+llm = ChatNVIDIA(model="mixtral_8x7b")
+
+retriever = vector_search.as_retriever(
+    search_type="similarity",
+    search_kwargs={"k": 12},
+)
+
+
+### Contextualize question ###
+contextualize_q_system_prompt = """Given a chat history and the latest user question \
+which might reference context in the chat history, formulate a standalone question \
+which can be understood without the chat history. Do NOT answer the question, \
+just reformulate it if needed and otherwise return it as is."""
+contextualize_q_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", contextualize_q_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}"),
+    ]
+)
+history_aware_retriever = create_history_aware_retriever(
+    llm, retriever, contextualize_q_prompt
+)
+
+
+### Answer question ###
+qa_system_prompt = """You are a VMS assistant helping students with their academic issues. \
+Answer the question using only the context provided. Do not include phrases like "based on the context" or "based on the documents provided" in your answer. \
+Please help them with their question. Remember that your job is to represent Vincent Mary School of Science and Technology (VMS) at Assumption University. \
+Do not hallucinate any details, and make sure the knowledge base is not redundant. \
+If you don't know the answer, just say that you don't know. \
+
+{context}"""
+
+qa_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", qa_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}"),
+    ]
+)
+question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+
+### Statefully manage chat history ###
+store = {}
+
+
+def get_session_history(session_id: str) -> BaseChatMessageHistory:
+    if session_id not in store:
+        store[session_id] = ChatMessageHistory()
+    return store[session_id]
+
+
+conversational_rag_chain = RunnableWithMessageHistory(
+    rag_chain,
+    get_session_history,
+    input_messages_key="input",
+    history_messages_key="chat_history",
+    output_messages_key="answer",
+)
+
+c_history = []
+
+def chat_gen(message, history):
+    buffer = ""
+
+    ai_message = rag_chain.invoke({"input": message, "chat_history": c_history})
+    # Store the reply as an AIMessage so the roles in the history stay correct;
+    # a bare string would be coerced to a human message downstream.
+    c_history.extend([HumanMessage(content=message), AIMessage(content=ai_message["answer"])])
+    print(c_history)
+    yield ai_message["answer"]
+
+    # for doc in ai_message["context"]:
+    #     yield doc
+
+initial_msg = (
+    "Hello! I am VMS bot here to help you with your academic issues!"
+    "\nHow can I help you?"
+)
+
+
+chatbot = gr.Chatbot(value=[[None, initial_msg]], bubble_full_width=False)
+demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
+
+try:
+    demo.launch(debug=True, share=True, show_api=False)
+    demo.close()
+except Exception as e:
+    demo.close()
+    print(e)
+    raise e
+
+# available model names
+# mixtral_8x7b
+# llama2_13b
+# llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()
+
+# initial_msg = (
+#     "Hello! I am VMS bot here to help you with your academic issues!"
+#     "\nHow can I help you?"
+# )
+
+# context_prompt = ChatPromptTemplate.from_messages([
+#     ('system',
+#      "You are a VMS chatbot, and you are helping students with their academic issues."
+#      "Answer the question using only the context provided. Do not include phrases like 'based on the context' or 'based on the documents provided' in your answer."
+#      "Please help them with their question. Remember that your job is to represent Vincent Mary School of Science and Technology (VMS) at Assumption University."
+#      "Do not hallucinate any details, and make sure the knowledge base is not redundant."
+#      "Please say you do not know if you do not know or cannot find the information needed."
+#      "\n\nQuestion: {question}\n\nContext: {context}"),
+#     ('user', "{question}")
+# ])
+
+# chain = (
+#     {
+#         'context': db.as_retriever(search_type="similarity"),
+#         'question': (lambda x: x)
+#     }
+#     | context_prompt
+#     # | RPrint()
+#     | llm
+#     | StrOutputParser()
+# )
+
+# conv_chain = (
+#     context_prompt
+#     # | RPrint()
+#     | llm
+#     | StrOutputParser()
+# )
+
+# def chat_gen(message, history, return_buffer=True):
+#     buffer = ""
+
+#     doc_retriever = db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.2})
+#     retrieved_docs = doc_retriever.invoke(message)
+#     print(len(retrieved_docs))
+#     print(retrieved_docs)
+
+#     if len(retrieved_docs) > 0:
+#         state = {
+#             'question': message,
+#             'context': retrieved_docs
+#         }
+#         for token in conv_chain.stream(state):
+#             buffer += token
+#             yield buffer
+#     else:
+#         passage = "I am sorry. I do not have relevant information to answer on that specific topic. Please try another question."
+#         buffer += passage
+#         yield buffer if return_buffer else passage
+
+
+# chatbot = gr.Chatbot(value=[[None, initial_msg]])
+# iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
+# iface.launch()
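
Note: the commit builds conversational_rag_chain with RunnableWithMessageHistory, but chat_gen still invokes rag_chain directly against the module-level c_history list. A minimal sketch (not part of the commit) of driving the history-aware chain with per-session state instead, assuming the standard RunnableWithMessageHistory config keys; the session id and questions are hypothetical:

# Sketch only: exercising the new conversational_rag_chain added in this commit.
session_cfg = {"configurable": {"session_id": "demo-session"}}  # hypothetical id

first = conversational_rag_chain.invoke(
    {"input": "What courses does VMS offer?"}, config=session_cfg
)
print(first["answer"])

# A follow-up in the same session can use pronouns; the history stored by
# get_session_history() lets the contextualizing retriever resolve them.
followup = conversational_rag_chain.invoke(
    {"input": "Which of those run in the first semester?"}, config=session_cfg
)
print(followup["answer"])
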
requirements.txt CHANGED
@@ -1,4 +1,7 @@
 langchain
 langchain-nvidia-ai-endpoints
 gradio
-faiss-cpu
+faiss-cpu
+pymongo
+llama-index
+llama-index-vector-stores-mongodb