smkerr committed
Commit
6fa8438
1 Parent(s): 0a8e1e5

Upload rag-chat.py

Files changed (1)
  1. rag-chat.py +159 -0
rag-chat.py ADDED
@@ -0,0 +1,159 @@
+import os
+
+from langchain.document_loaders import DirectoryLoader
+from langchain.document_loaders import BSHTMLLoader
+from bs4 import SoupStrainer
+import re
+
+from langchain import HuggingFaceHub, PromptTemplate, LLMChain
+from langchain.embeddings import SentenceTransformerEmbeddings
+from langchain.vectorstores import Chroma
+
+from langchain.chains import ConversationalRetrievalChain
+from langchain.memory import ChatMessageHistory, ConversationBufferMemory
+
+import chainlit as cl
+
+# llm
+model_id = "tiiuae/falcon-7b-instruct"
+conv_model = HuggingFaceHub(
+    huggingfacehub_api_token=os.environ['HF_API_TOKEN'],
+    repo_id=model_id,
+    model_kwargs={"temperature": 0.8, "max_length": 1000}
+)
+
+# chroma
+data_path = "data/html"
+embed_model = "all-MiniLM-L6-v2"  # Chroma defaults to "sentence-transformers/all-MiniLM-L6-v2"
+
+# load documents
+def load_documents(directory):
+
+    # define Beautiful Soup keyword args
+    bs_kwargs = {
+        "features": "html.parser",
+        "parse_only": SoupStrainer("p")  # only include relevant text
+    }
+
+    # define Loader keyword args
+    loader_kwargs = {
+        "open_encoding": "utf-8",
+        "bs_kwargs": bs_kwargs
+    }
+
+    # define Loader
+    loader = DirectoryLoader(
+        path=directory,
+        glob="*.html",
+        loader_cls=BSHTMLLoader,
+        loader_kwargs=loader_kwargs
+    )
+
+    documents = loader.load()
+    return documents
+
+
+# prepare documents
+def prepare_documents(documents):
+    for doc in documents:
+        doc.page_content = doc.page_content.replace("\n", " ").replace("\t", " ")
+        doc.page_content = re.sub("\\s+", " ", doc.page_content)
+
+    # define Beautiful Soup keyword args
+    bs_kwargs = {
+        "features": "html.parser",
+        "parse_only": SoupStrainer("title")  # only include relevant text
+    }
+
+    # define Loader keyword args
+    loader_kwargs = {
+        "open_encoding": "utf-8",
+        "bs_kwargs": bs_kwargs
+    }
+
+    loader = DirectoryLoader(
+        path=data_path,
+        glob="*.html",
+        loader_cls=BSHTMLLoader,
+        loader_kwargs=loader_kwargs
+    )
+
+    document_sources = loader.load()
+
+    # convert source metadata into a list
+    source_list = [doc.metadata["title"] for doc in document_sources]
+
+    # update source metadata
+    i = 0
+    for doc in documents:
+        doc.metadata["source"] = " ".join(["FAR", source_list[i]])
+        i += 1
+    return documents
+
+@cl.on_chat_start
+async def on_chat_start():
+    # Instantiate the chain for that user session
+    embedding_func = SentenceTransformerEmbeddings(model_name=embed_model)
+
+    msg = cl.Message(
+        content="Loading and processing documents. This may take a while...",
+        disable_human_feedback=True)
+    await msg.send()
+
+    documents = load_documents(data_path)
+    documents = prepare_documents(documents)
+
+    docsearch = await cl.make_async(Chroma.from_documents)(
+        documents,
+        embedding_func
+    )
+
+    message_history = ChatMessageHistory()
+
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        output_key="answer",
+        chat_memory=message_history,
+        return_messages=True,
+    )
+
+    chain = ConversationalRetrievalChain.from_llm(
+        conv_model,
+        chain_type="stuff",
+        retriever=docsearch.as_retriever(),
+        memory=memory,
+        return_source_documents=True,
+    )
+    msg.content = "Ready. You can now ask questions!"
+
+    await msg.update()
+    cl.user_session.set("chain", chain)
+
+
+@cl.on_message
+async def main(message):
+    chain = cl.user_session.get("chain")  # type: ConversationalRetrievalChain
+    cb = cl.AsyncLangchainCallbackHandler()
+
+    res = await chain.acall(message.content, callbacks=[cb])
+
+    answer = res["answer"]
+    source_documents = res["source_documents"]
+
+    text_elements = []
+
+    source_names = set()  # use a set to store unique source names
+
+    for idx, source_doc in enumerate(source_documents):
+        source_name = source_doc.metadata["source"]
+        text_elements.append(
+            cl.Text(content=source_doc.page_content,
+                    name=source_name))
+        source_names.add(source_name)  # add the source name to the set
+
+    if source_names:
+        answer += f"\nSources: {', '.join(source_names)}"
+    else:
+        answer += "\nNo sources found"
+
+    await cl.Message(content=answer, elements=text_elements).send()
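To try the app, the script expects an HF_API_TOKEN environment variable and a data/html directory of HTML files, and is launched with the Chainlit CLI (chainlit run rag-chat.py). The snippet below is an optional, minimal sanity check of the retrieval step outside the chat UI, not part of this commit: it mirrors the loader and embedding settings used above, the query string is purely illustrative, and it assumes langchain, chromadb, sentence-transformers and beautifulsoup4 are installed.

# Optional retrieval sanity check (not part of the committed file). It mirrors
# the loader and embedding settings in rag-chat.py; the query is illustrative.
from bs4 import SoupStrainer
from langchain.document_loaders import DirectoryLoader, BSHTMLLoader
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma

loader = DirectoryLoader(
    path="data/html",  # same directory the app reads
    glob="*.html",
    loader_cls=BSHTMLLoader,
    loader_kwargs={
        "open_encoding": "utf-8",
        "bs_kwargs": {"features": "html.parser", "parse_only": SoupStrainer("p")},
    },
)
docs = loader.load()

# embed and index with the same sentence-transformers model the app uses
db = Chroma.from_documents(docs, SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2"))

# print the top matches for a test query
for hit in db.similarity_search("contract requirements", k=3):
    print(hit.metadata.get("source"), "->", hit.page_content[:100])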