bohmian commited on
Commit
cc8b84c
1 Parent(s): 88cbf6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -21
app.py CHANGED
@@ -172,7 +172,8 @@ def get_llm():
172
  )
173
  return llm
174
 
175
- @st.cache_data # only going to get this once instead of all the time when page refreshers
 
176
  def get_embeddings():
177
  with st.spinner(f'Getting HuggingFaceEmbeddings'):
178
  # We use HuggingFaceEmbeddings() as it is open source and free to use.
@@ -209,28 +210,33 @@ if not os.path.exists("chromadb/"):
209
  with st.spinner(f'Unzipping chromadb retrievers for all chunk sizes and overlaps, will take some time'):
210
  os.system("unzip chromadb.zip")
211
 
212
- persist_directory = f"chromadb/chromadb_esg_countries_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}"
213
- with st.spinner(f'Setting up pre-built chroma vector store'):
214
- chroma_db = Chroma(persist_directory=persist_directory,embedding_function=hf_embeddings)
215
-
216
- # Initialize BM25 Retriever
217
- # Unlike Chroma (semantic) BM25 is a keyword-based algorithm that performs well on queries containing keywords without capturing the semantic meaning of the query terms,
218
- # hence there is no need to embed the text with HuggingFaceEmbeddings and it is relatively faster to set up.
219
- # The retrievers with different chunking sizes and overlaps and countries were created in advanced and saved as pickle files and pulled using !wget.
220
- # Need to initialize one BM25Retriever for each country so the search results later in the main app can be limited to just a particular country.
221
- # (Chroma DB gives an option to filter metadata for just a particular country during the retrieval processbut BM25 does not because it makes use of external ranking library.)
222
- # A separate retriever was hence pre-built for each unique country and each unique chunk size and overlap.
223
- bm25_retrievers = {} # to store retrievers for different countries
224
- with st.spinner(f'Setting up pre-built bm25 retrievers'):
225
- for country in countries:
226
- bm25_filename = f"bm25/bm25_esg_countries_{country}_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}.pickle"
227
- with open(bm25_filename, 'rb') as handle:
228
- bm25_retriever = pickle.load(handle)
229
- bm25_retrievers[country] = bm25_retriever
230
-
231
- # One retriever above is semantic based and the other is keyword based
232
  # Both retrievers will be used
233
  # Then Langchain's EnsembleRetriever will be used to rerank both their results to give final output to RetrievalQA chain below
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  ################################ Tools for Agent to Use ################################
236
 
@@ -450,6 +456,14 @@ if page == "Chat Config":
450
  # to override existing data on new scraped data or new pdf uploaded
451
  if page == "Document, Retriever, Web Scraping Config":
452
  st.header(page)
 
 
 
 
 
 
 
 
453
 
454
 
455
  ################################ Main Chatbot Page ################################
 
172
  )
173
  return llm
174
 
175
+ @st.cache_data # only going to get this once instead of all the time when page refreshes
176
+ # for chromadb vectore store
177
  def get_embeddings():
178
  with st.spinner(f'Getting HuggingFaceEmbeddings'):
179
  # We use HuggingFaceEmbeddings() as it is open source and free to use.
 
210
  with st.spinner(f'Unzipping chromadb retrievers for all chunk sizes and overlaps, will take some time'):
211
  os.system("unzip chromadb.zip")
212
 
213
+
214
+ # One retriever below is semantic based (chromadb) and the other is keyword based (bm25)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  # Both retrievers will be used
216
  # Then Langchain's EnsembleRetriever will be used to rerank both their results to give final output to RetrievalQA chain below
217
+ def get_retrievers():
218
+ persist_directory = f"chromadb/chromadb_esg_countries_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}"
219
+ with st.spinner(f'Setting up pre-built chroma vector store'):
220
+ chroma_db = Chroma(persist_directory=persist_directory,embedding_function=hf_embeddings)
221
+
222
+ # Initialize BM25 Retriever
223
+ # Unlike Chroma (semantic) BM25 is a keyword-based algorithm that performs well on queries containing keywords without capturing the semantic meaning of the query terms,
224
+ # hence there is no need to embed the text with HuggingFaceEmbeddings and it is relatively faster to set up.
225
+ # The retrievers with different chunking sizes and overlaps and countries were created in advanced and saved as pickle files and pulled using !wget.
226
+ # Need to initialize one BM25Retriever for each country so the search results later in the main app can be limited to just a particular country.
227
+ # (Chroma DB gives an option to filter metadata for just a particular country during the retrieval processbut BM25 does not because it makes use of external ranking library.)
228
+ # A separate retriever was hence pre-built for each unique country and each unique chunk size and overlap.
229
+ bm25_retrievers = {} # to store retrievers for different countries
230
+ with st.spinner(f'Setting up pre-built bm25 retrievers'):
231
+ for country in countries:
232
+ bm25_filename = f"bm25/bm25_esg_countries_{country}_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}.pickle"
233
+ with open(bm25_filename, 'rb') as handle:
234
+ bm25_retriever = pickle.load(handle)
235
+ bm25_retrievers[country] = bm25_retriever
236
+
237
+ return chroma_db, bm25_retrievers
238
+
239
+ chroma_db, bm25_retrievers = get_retrievers()
240
 
241
  ################################ Tools for Agent to Use ################################
242
 
 
456
  # to override existing data on new scraped data or new pdf uploaded
457
  if page == "Document, Retriever, Web Scraping Config":
458
  st.header(page)
459
+ st.session_state['chunk_size'] = st.selectbox(
460
+ "Chunk Size",
461
+ options=[500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000],
462
+ key="chunk_size").on_change(get_retrievers)
463
+ st.session_state['chunk_overlap'] = st.selectbox(
464
+ "Chunk Overlap",
465
+ options=[50, 100, 150, 200],
466
+ key="chunk_overlap").on_change(get_retrievers)
467
 
468
 
469
  ################################ Main Chatbot Page ################################