Asaad Almutareb committed on
Commit: 93457a9
Parent: d4bbccc

added caching with sqlite


added different retrievers

Files changed (2):

1. .gitignore (+2, -1)
2. app.py (+84, -45)
.gitignore CHANGED

```diff
@@ -165,4 +165,5 @@ cython_debug/
 *.pickle
 chroma_db/*
 bin
-obj
+obj
+.langchain.sqlite
```
app.py CHANGED

```diff
@@ -1,7 +1,7 @@
 # gradio
 import gradio as gr
 #import random
-#import time
+import time
 #boto3 for S3 access
 import boto3
 from botocore import UNSIGNED
```
```diff
@@ -15,25 +15,36 @@ from langchain.llms import HuggingFaceHub
 from langchain.embeddings import HuggingFaceHubEmbeddings
 # vectorstore
 from langchain.vectorstores import Chroma
-from langchain.vectorstores import FAISS
+#from langchain.vectorstores import FAISS
 # retrieval chain
-from langchain.chains import RetrievalQA
+#from langchain.chains import RetrievalQA
 from langchain.chains import RetrievalQAWithSourcesChain
 # prompt template
 from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferMemory
 # logging
 import logging
-import zipfile
-#contextual retriever
-from langchain.retrievers import ContextualCompressionRetriever
-from langchain.retrievers.document_compressors import LLMChainExtractor
-from langchain.retrievers.document_compressors import EmbeddingsFilter
-from langchain.retrievers.multi_query import MultiQueryRetriever
-# streaming
-#from threading import Thread
-#from transformers import TextIteratorStreamer
-
+#import zipfile
+# improve results with retriever
+# from langchain.retrievers import ContextualCompressionRetriever
+# from langchain.retrievers.document_compressors import LLMChainExtractor
+# from langchain.retrievers.document_compressors import EmbeddingsFilter
+# from langchain.retrievers.multi_query import MultiQueryRetriever
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
+# reorder retrieved documents
+#from langchain.document_transformers import LongContextReorder
+# github issues
+from langchain.document_loaders import GitHubIssuesLoader
+# debugging
+from langchain.globals import set_verbose
+# caching
+from langchain.globals import set_llm_cache
+#from langchain.cache import InMemoryCache
+# We can do the same thing with a SQLite cache
+from langchain.cache import SQLiteCache
+#set_llm_cache(InMemoryCache())
+
+set_verbose(True)
 
 # load .env variables
 config = load_dotenv(".env")
```
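Among the new imports, `GitHubIssuesLoader` is not referenced anywhere else in this commit. For orientation, a minimal sketch of typical usage; the repo name and token variable are illustrative assumptions, not values from this commit:

```python
import os
from langchain.document_loaders import GitHubIssuesLoader

# Hypothetical usage: "ray-project/ray" and GITHUB_TOKEN are placeholders,
# not values taken from this commit.
loader = GitHubIssuesLoader(
    repo="ray-project/ray",
    access_token=os.environ["GITHUB_TOKEN"],
    include_prs=False,  # load issues only, skip pull requests
)
issue_docs = loader.load()  # returns a list of Document objects
```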
```diff
@@ -43,22 +54,34 @@ AWS_S3_FILE=os.getenv('AWS_S3_FILE')
 VS_DESTINATION=os.getenv('VS_DESTINATION')
 
 # initialize Model config
-model_id = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta", model_kwargs={
-    "temperature":0.1,
+model_id = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1", model_kwargs={
+    # "temperature":0.1,
     "max_new_tokens":1024,
     "repetition_penalty":1.2,
-    "streaming": True,
-    "return_full_text":True
+    # "streaming": True,
+    # "return_full_text":True
 })
 
-model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
+#model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
+model_name = "sentence-transformers/all-mpnet-base-v2"
 embeddings = HuggingFaceHubEmbeddings(repo_id=model_name)
 
+# remove old vectorstore
+if os.path.exists(VS_DESTINATION):
+    os.remove(VS_DESTINATION)
+
+# remove old sqlite cache
+if os.path.exists('.langchain.sqlite'):
+    os.remove('.langchain.sqlite')
+
+set_llm_cache(SQLiteCache(database_path=".langchain.sqlite"))
+
 # retrieve vectorstore
 s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
 
 ## Chroma DB
 s3.download_file(AWS_S3_LOCATION, AWS_S3_FILE, VS_DESTINATION)
+# use the cached embeddings instead of embeddings to speed up re-retrieval
 db = Chroma(persist_directory="./vectorstore", embedding_function=embeddings)
 db.get()
 
```
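The commit wires LangChain's global LLM cache to SQLite and deletes any stale cache file on startup. A minimal sketch of what the cache buys, assuming a Hugging Face Hub token is configured in the environment:

```python
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache
from langchain.llms import HuggingFaceHub

# Every completion is memoised by (prompt, llm settings) in this file.
set_llm_cache(SQLiteCache(database_path=".langchain.sqlite"))

llm = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1")

answer_1 = llm("What is Ray?")  # network call to the Inference API
answer_2 = llm("What is Ray?")  # identical prompt: served from the cache
```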
```diff
@@ -70,12 +93,18 @@ db.get()
 # FAISS_INDEX_PATH='./chroma_db/faiss_db_ray'
 # db = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
 
-retriever = db.as_retriever(search_type = "mmr")#, search_kwargs={'k': 5, 'fetch_k': 25})
+# initialize the bm25 retriever and chroma/faiss retriever
+bm25_retriever = BM25Retriever.
+bm25_retriever.k = 2
 
-compressor = LLMChainExtractor.from_llm(model_id)
-compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
-# embeddings_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
-# compression_retriever = ContextualCompressionRetriever(base_compressor=embeddings_filter, base_retriever=retriever)
+retriever = db.as_retriever(search_type="mmr")#, search_kwargs={'k': 3, 'lambda_mult': 0.25})
+
+# asks LLM to create 3 alternatives based on user query
+# multi_retriever = MultiQueryRetriever.from_llm(retriever=retriever, llm=model_id)
+
+# asks LLM to extract relevant parts from retrieved documents
+# compressor = LLMChainExtractor.from_llm(model_id)
+# compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=multi_retriever)
 
 global qa
 template = """
```
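The `BM25Retriever` construction above is cut off in this rendering. A sketch of how the imported `BM25Retriever` and `EnsembleRetriever` are typically combined, assuming the raw texts can be read back out of the Chroma collection; the 50/50 weights are illustrative:

```python
from langchain.retrievers import BM25Retriever, EnsembleRetriever

# Assumption: rebuild the plain texts from the Chroma collection so the
# sparse retriever indexes the same corpus as the dense one.
texts = db.get()["documents"]
bm25_retriever = BM25Retriever.from_texts(texts)  # keyword (sparse) retrieval
bm25_retriever.k = 2

# Dense retriever over the Chroma store, as in the commit.
dense_retriever = db.as_retriever(search_type="mmr")

# Hybrid retrieval: merge sparse and dense hits with assumed equal weights.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, dense_retriever],
    weights=[0.5, 0.5],
)
```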
```diff
@@ -101,16 +130,18 @@ memory = ConversationBufferMemory(memory_key="history", input_key="question")
 
 # logging for the chain
 logging.basicConfig()
-logging.getLogger("langchain.chains").setLevel(logging.INFO)
+logging.getLogger("langchain.retrievers").setLevel(logging.INFO)
+logging.getLogger("langchain.chains.qa_with_sources").setLevel(logging.INFO)
+
 
 
-# qa = RetrievalQA.from_chain_type(llm=model_id, chain_type="stuff", retriever=compression_retriever, verbose=True, return_source_documents=True, chain_type_kwargs={
+# qa = RetrievalQA.from_chain_type(llm=model_id, retriever=retriever, return_source_documents=True, verbose=True, chain_type_kwargs={
 # "verbose": True,
 # "memory": memory,
 # "prompt": prompt
 # }
 # )
-qa = RetrievalQAWithSourcesChain.from_chain_type(llm=model_id, retriever=compression_retriever, verbose=True, chain_type_kwargs={
+qa = RetrievalQAWithSourcesChain.from_chain_type(llm=model_id, retriever=retriever, return_source_documents=True, verbose=True, chain_type_kwargs={
     "verbose": True,
     "memory": memory,
     "prompt": prompt,
```
```diff
@@ -118,8 +149,11 @@ qa = RetrievalQAWithSourcesChain.from_chain_type(llm=model_id, retriever=compres
     }
 )
 
-def pretty_print_docs(docs):
-    print(f"\n{'-' * 100}\n".join([f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]))
+
+#####
+#
+# Gradio fns
+####
 
 def add_text(history, text):
     history = history + [(text, None)]
```
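Because the chain is now built with `return_source_documents=True`, each call returns the retrieved `Document` objects alongside the answer, which the reworked `bot()` below depends on. A minimal sketch of one call; the question text is illustrative:

```python
# RetrievalQAWithSourcesChain takes a "question" key; with
# return_source_documents=True the result dict carries "answer",
# "sources" (a string produced by the LLM) and "source_documents".
result = qa({"question": "How do I start a Ray cluster?"})
print(result["answer"])
for doc in result["source_documents"]:
    print(doc.metadata.get("source"))
```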
```diff
@@ -127,20 +161,18 @@ def add_text(history, text):
 
 def bot(history):
     response = infer(history[-1][0], history)
-    print(*response)
-    print(*memory)
-    sources = [doc.metadata.get("source") for doc in response['sources']]
+    sources = [doc.metadata.get("source") for doc in response['source_documents']]
     src_list = '\n'.join(sources)
     print_this = response['answer'] + "\n\n\n Sources: \n\n\n" + src_list
-    #sources = f"`Sources:`\n\n' + response['sources']"
 
-    #history[-1][1] = ""
-    #for character in response['result']: #print_this:
-    #    history[-1][1] += character
-    #    time.sleep(0.05)
-    #    yield history
-    history[-1][1] = response['answer']
-    return history #, sources
+    # history[-1][1] = ""
+    # for character in response['answer']:
+    # #print_this:
+    #    history[-1][1] += character
+    #    time.sleep(0.01)
+    #    yield history
+    history[-1][1] = print_this #response['answer']
+    return history
 
 def infer(question, history):
     query = question
```
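The character-streaming variant stays commented out. If it were re-enabled, `bot()` would become a generator and Gradio would render the reply incrementally; a sketch under that assumption:

```python
import time

def bot(history):
    response = infer(history[-1][0], history)
    sources = [doc.metadata.get("source") for doc in response['source_documents']]
    print_this = response['answer'] + "\n\n\n Sources: \n\n\n" + '\n'.join(sources)

    # Yielding partial histories produces a typing effect; gr.Chatbot
    # accepts generator callbacks wired via .then(bot, chatbot, chatbot).
    history[-1][1] = ""
    for character in print_this:
        history[-1][1] += character
        time.sleep(0.01)
        yield history
```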
```diff
@@ -152,23 +184,30 @@ css="""
 """
 
 title = """
-<div style="text-align: center;max-width: 700px;">
+<div style="text-align: center;max-width: 1920px;">
     <h1>Chat with your Documentation</h1>
-    <p style="text-align: center;">Chat with Documentation, <br />
-    when everything is ready, you can start asking questions about the docu ;)</p>
+    <p style="text-align: center;">This is a privately hosted Docs AI Buddy, <br />
+    It will help you with any question regarding the documentation of Ray ;)</p>
 </div>
 """
 
+
+
 with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
+    with gr.Column(min_width=900, elem_id="col-container"):
         gr.HTML(title)
         chatbot = gr.Chatbot([], elem_id="chatbot")
-        clear = gr.Button("Clear")
+        #with gr.Row():
+        #    clear = gr.Button("Clear")
+
         with gr.Row():
             question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
+        with gr.Row():
+            clear = gr.ClearButton([chatbot, question])
+
         question.submit(add_text, [chatbot, question], [chatbot, question], queue=False).then(
             bot, chatbot, chatbot
         )
-        clear.click(lambda: None, None, chatbot, queue=False)
 
 demo.queue().launch()
```
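On the UI side, the manual `clear.click(lambda: None, None, chatbot, queue=False)` wiring is replaced by `gr.ClearButton`, which resets its target components without a callback. A standalone sketch:

```python
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot([])
    question = gr.Textbox(label="Question")
    # ClearButton resets the listed components itself; no .click handler needed.
    clear = gr.ClearButton([chatbot, question])

demo.queue().launch()
```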