Update rag-chat.py
rag-chat.py  CHANGED  +51 -10
@@ -1,3 +1,4 @@
+# import all necessary packages
 import os
 
 from langchain.document_loaders import DirectoryLoader
@@ -14,7 +15,32 @@ from langchain.memory import ChatMessageHistory, ConversationBufferMemory
 
 import chainlit as cl
 
-
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+
+# define prompt template
+system_template = """Use the following pieces of context to answer the users question.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ALWAYS return a "SOURCES" part in your answer.
+The "SOURCES" part should be a reference to the source of the document from which you got your answer.
+And if the user greets with greetings like Hi, hello, How are you, etc reply accordingly as well.
+Example of your response should be:
+The answer is foo
+SOURCES: xyz
+Begin!
+----------------
+{summaries}"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+prompt = ChatPromptTemplate.from_messages(messages)
+chain_type_kwargs = {"prompt": prompt}
+
+# define the llm
 model_id = "tiiuae/falcon-7b-instruct"
 conv_model = HuggingFaceHub(
     huggingfacehub_api_token=os.environ['HF_API_TOKEN'],
@@ -22,7 +48,7 @@ conv_model = HuggingFaceHub(
     model_kwargs={"temperature":0.8,"max_length": 1000}
 )
 
-# chroma
+# set up vector db with chroma
 data_path = "data/html"
 embed_model = "all-MiniLM-L6-v2" # Chroma defaults to "sentence-transformers/all-MiniLM-L6-v2"
 
@@ -90,26 +116,32 @@ def prepare_documents(documents):
         i += 1
     return documents
 
+# define a function to execute when a chat starts
 @cl.on_chat_start
 async def on_chat_start():
-    #
+    # instantiate the chain for that user session
     embedding_func = SentenceTransformerEmbeddings(model_name=embed_model)
-
+
+    # display a message indicating document loading
     msg = cl.Message(
         content="Loading and processing documents. This may take a while...",
         disable_human_feedback=True)
     await msg.send()
 
+    # load and prepare documents for processing
    documents = load_documents(data_path)
     documents = prepare_documents(documents)
 
+    # create a document search object asynchronously
     docsearch = await cl.make_async(Chroma.from_documents)(
         documents,
         embedding_func
     )
 
+    # initialize ChatMessageHistory object to store message history
     message_history = ChatMessageHistory()
 
+    # initialize ConversationBufferMemory object to store conversation history
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key="answer",
@@ -117,6 +149,7 @@ async def on_chat_start():
         return_messages=True,
     )
 
+    # create a ConversationalRetrievalChain object
     chain = ConversationalRetrievalChain.from_llm(
         conv_model,
         chain_type="stuff",
@@ -124,36 +157,44 @@ async def on_chat_start():
         memory=memory,
         return_source_documents=True,
     )
+
+    # indicate readiness for questions
     msg.content = "Ready. You can now ask questions!"
-
     await msg.update()
+
+    # store the chain in the user's session
     cl.user_session.set("chain", chain)
 
-
+# define a function to handle messages
 @cl.on_message
 async def main(message):
+    # retrieve the chain object from the user's session
     chain = cl.user_session.get("chain")  # type: ConversationalRetrievalChain
     cb = cl.AsyncLangchainCallbackHandler()
 
+    # call the chain to process the incoming message
     res = await chain.acall(message.content, callbacks=[cb])
 
+    # retrieve the answer and source documents from the chain's response
     answer = res["answer"]
     source_documents = res["source_documents"]
 
-    text_elements = []
-
-    source_names = set()  # Use a set to store unique source names
+    text_elements = []  # list to store text elements
+    source_names = set()  # set to store unique source names
 
+    # iterate through source documents and extract relevant information
     for idx, source_doc in enumerate(source_documents):
         source_name = source_doc.metadata["source"]
         text_elements.append(
             cl.Text(content=source_doc.page_content,
                     name=source_name))
-        source_names.add(source_name)  #
+        source_names.add(source_name)  # add the source name to the set
 
+    # append sources information to the answer if available
     if source_names:
         answer += f"\nSources: {', '.join(source_names)}"
     else:
         answer += "\nNo sources found"
 
+    # send the answer along with any extracted text elements
     await cl.Message(content=answer, elements=text_elements).send()
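
Note on the new prompt: this commit builds prompt and stores it in chain_type_kwargs, but none of the visible hunks pass it into the chain, and the template's {summaries} placeholder differs from the context variable that LangChain's "stuff" question-answering chain expects by default. Below is a minimal sketch of how the prompt could be attached, assuming the classic ConversationalRetrievalChain.from_llm signature and that the hidden context line of the call supplies the Chroma store as the retriever; the keyword choices are assumptions, not taken from this diff.

# Sketch only -- assumed wiring, not part of this commit.
chain = ConversationalRetrievalChain.from_llm(
    conv_model,
    chain_type="stuff",
    retriever=docsearch.as_retriever(),  # assumption: the hidden hunk line passes the vector store this way
    memory=memory,
    return_source_documents=True,
    combine_docs_chain_kwargs={
        "prompt": prompt,                       # the ChatPromptTemplate defined in this commit
        "document_variable_name": "summaries",  # needed because the template uses {summaries} rather than {context}
    },
)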
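Usage note (assumed, not shown in this diff): the script reads the Hugging Face token from os.environ['HF_API_TOKEN'], so that variable must be exported before launch; since this is a Chainlit app, it would normally be started with chainlit run rag-chat.py -w, where -w enables auto-reload during development.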