smkerr committed
Commit
0e07c97
1 Parent(s): 30ae45e

Update rag-chat.py

Files changed (1)
  rag-chat.py  +51 −10
rag-chat.py CHANGED
@@ -1,3 +1,4 @@
+# import all necessary packages
 import os
 
 from langchain.document_loaders import DirectoryLoader
@@ -14,7 +15,32 @@ from langchain.memory import ChatMessageHistory, ConversationBufferMemory
 
 import chainlit as cl
 
-# llm
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+
+# define prompt template
+system_template = """Use the following pieces of context to answer the users question.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ALWAYS return a "SOURCES" part in your answer.
+The "SOURCES" part should be a reference to the source of the document from which you got your answer.
+And if the user greets with greetings like Hi, hello, How are you, etc reply accordingly as well.
+Example of your response should be:
+The answer is foo
+SOURCES: xyz
+Begin!
+----------------
+{summaries}"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+prompt = ChatPromptTemplate.from_messages(messages)
+chain_type_kwargs = {"prompt": prompt}
+
+# define the llm
 model_id = "tiiuae/falcon-7b-instruct"
 conv_model = HuggingFaceHub(
     huggingfacehub_api_token=os.environ['HF_API_TOKEN'],
@@ -22,7 +48,7 @@ conv_model = HuggingFaceHub(
     model_kwargs={"temperature":0.8,"max_length": 1000}
 )
 
-# chroma
+# set up vector db with chroma
 data_path = "data/html"
 embed_model = "all-MiniLM-L6-v2" # Chroma defaults to "sentence-transformers/all-MiniLM-L6-v2"
 
@@ -90,26 +116,32 @@ def prepare_documents(documents):
         i += 1
     return documents
 
+# define a function to execute when a chat starts
 @cl.on_chat_start
 async def on_chat_start():
-    # Instantiate the chain for that user session
+    # instantiate the chain for that user session
     embedding_func = SentenceTransformerEmbeddings(model_name=embed_model)
-
+
+    # display a message indicating document loading
     msg = cl.Message(
         content="Loading and processing documents. This may take a while...",
         disable_human_feedback=True)
     await msg.send()
 
+    # load and prepare documents for processing
     documents = load_documents(data_path)
     documents = prepare_documents(documents)
 
+    # create a document search object asynchronously
     docsearch = await cl.make_async(Chroma.from_documents)(
         documents,
         embedding_func
     )
 
+    # initialize ChatMessageHistory object to store message history
     message_history = ChatMessageHistory()
 
+    # initialize ConversationBufferMemory object to store conversation history
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key="answer",
@@ -117,6 +149,7 @@ async def on_chat_start():
         return_messages=True,
     )
 
+    # create a ConversationalRetrievalChain object
     chain = ConversationalRetrievalChain.from_llm(
         conv_model,
         chain_type="stuff",
@@ -124,36 +157,44 @@ async def on_chat_start():
         memory=memory,
         return_source_documents=True,
     )
+
+    # indicate readiness for questions
     msg.content = "Ready. You can now ask questions!"
-
     await msg.update()
+
+    # store the chain in the user's session
     cl.user_session.set("chain", chain)
 
-
+# define a function to handle messages
 @cl.on_message
 async def main(message):
+    # retrieve the chain object from the user's session
     chain = cl.user_session.get("chain") # type: ConversationalRetrievalChain
     cb = cl.AsyncLangchainCallbackHandler()
 
+    # call the chain to process the incoming message
     res = await chain.acall(message.content, callbacks=[cb])
 
+    # retrieve the answer and source documents from the chain's response
     answer = res["answer"]
     source_documents = res["source_documents"]
 
-    text_elements = []
-
-    source_names = set() # Use a set to store unique source names
+    text_elements = [] # list to store text elements
+    source_names = set() # set to store unique source names
 
+    # iterate through source documents and extract relevant information
     for idx, source_doc in enumerate(source_documents):
         source_name = source_doc.metadata["source"]
         text_elements.append(
             cl.Text(content=source_doc.page_content,
                     name=source_name))
-        source_names.add(source_name) # Add the source name to the set
+        source_names.add(source_name) # add the source name to the set
 
+    # append sources information to the answer if available
     if source_names:
         answer += f"\nSources: {', '.join(source_names)}"
     else:
         answer += "\nNo sources found"
 
+    # send the answer along with any extracted text elements
     await cl.Message(content=answer, elements=text_elements).send()
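
Usage note: nothing about how the app is launched changes here. Assuming HF_API_TOKEN is exported and the HTML corpus sits in data/html as the script expects, chainlit run rag-chat.py -w starts the chat UI, with -w reloading the app on further edits.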