Mishab committed on
Commit
06aadc0
1 Parent(s): 0f6e5cd

Updated code

Browse files
Files changed (2) hide show
  1. app.py +4 -4
  2. utils.py +8 -8
app.py CHANGED
@@ -30,7 +30,7 @@ from langchain.agents.agent_toolkits import create_conversational_retrieval_agen
30
  from langchain.utilities import SerpAPIWrapper
31
 
32
  from utils import build_embedding_model, build_llm
33
- from utils import load_ensemble_retriver, load_text_chunks, load_vectorstore, load_conversational_retrievel_chain
34
 
35
  load_dotenv()
36
  # Getting current timestamp to keep track of historical conversations
@@ -51,11 +51,11 @@ if "embeddings" not in st.session_state:
51
  if "vector_db" not in st.session_state:
52
  st.session_state["vector_db"] = load_vectorstore(persist_directory=persist_directory, embeddings=st.session_state["embeddings"])
53
 
54
- if "text_chunks" not in st.session_state:
55
- st.session_state["text_chunks"] = load_text_chunks(text_chunks_pkl_dir=all_docs_pkl_directory)
56
 
57
  if "ensemble_retriver" not in st.session_state:
58
- st.session_state["ensemble_retriver"] = load_ensemble_retriver(text_chunks=st.session_state["text_chunks"], embeddings=st.session_state["embeddings"], chroma_vectorstore=st.session_state["vector_db"] )
59
 
60
  if "conversation_chain" not in st.session_state:
61
  st.session_state["conversation_chain"] = load_conversational_retrievel_chain(retriever=st.session_state["ensemble_retriver"], llm=st.session_state["llm"])
 
30
  from langchain.utilities import SerpAPIWrapper
31
 
32
  from utils import build_embedding_model, build_llm
33
+ from utils import load_ensemble_retriver,load_vectorstore, load_conversational_retrievel_chain
34
 
35
  load_dotenv()
36
  # Getting current timestamp to keep track of historical conversations
 
51
  if "vector_db" not in st.session_state:
52
  st.session_state["vector_db"] = load_vectorstore(persist_directory=persist_directory, embeddings=st.session_state["embeddings"])
53
 
54
+ # if "text_chunks" not in st.session_state:
55
+ # st.session_state["text_chunks"] = load_text_chunks(text_chunks_pkl_dir=all_docs_pkl_directory)
56
 
57
  if "ensemble_retriver" not in st.session_state:
58
+ st.session_state["ensemble_retriver"] = load_ensemble_retriver(embeddings=st.session_state["embeddings"], chroma_vectorstore=st.session_state["vector_db"] )
59
 
60
  if "conversation_chain" not in st.session_state:
61
  st.session_state["conversation_chain"] = load_conversational_retrievel_chain(retriever=st.session_state["ensemble_retriver"], llm=st.session_state["llm"])
utils.py CHANGED
@@ -48,7 +48,7 @@ def build_llm():
48
  Loading OpenAI model
49
  '''
50
  # llm= OpenAI(temperature=0.2)
51
- llm= ChatOpenAI(temperature = 0, max_tokens=256)
52
  return llm
53
 
54
  def build_embedding_model():
@@ -253,15 +253,15 @@ def load_text_chunks(text_chunks_pkl_dir):
253
  pickle.dump(all_texts, file)
254
  print("Text chunks are created and cached")
255
 
256
- def load_ensemble_retriver(text_chunks, embeddings, chroma_vectorstore):
257
  """Load ensemble retiriever with BM25 and Chroma as individual retrievers"""
258
- bm25_retriever = BM25Retriever.from_documents(text_chunks)
259
- bm25_retriever.k = 2
260
- chroma_retriever = chroma_vectorstore.as_retriever(search_kwargs={"k": 3})
261
- ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.3, 0.7])
262
  logging.basicConfig()
263
  logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
264
- retriever_from_llm = MultiQueryRetriever.from_llm(retriever=ensemble_retriever,
265
  llm=ChatOpenAI(temperature=0))
266
  return retriever_from_llm
267
 
@@ -322,6 +322,6 @@ def load_conversational_retrievel_chain(retriever, llm):
322
  chain_type="stuff",
323
  retriever=retriever,
324
  return_source_documents=True,
325
- chain_type_kwargs={"prompt": prompt, "memory": memory},
326
  )
327
  return qa
 
48
  Loading OpenAI model
49
  '''
50
  # llm= OpenAI(temperature=0.2)
51
+ llm= ChatOpenAI(temperature = 0)
52
  return llm
53
 
54
  def build_embedding_model():
 
253
  pickle.dump(all_texts, file)
254
  print("Text chunks are created and cached")
255
 
256
+ def load_ensemble_retriver(embeddings, chroma_vectorstore):
257
  """Load ensemble retiriever with BM25 and Chroma as individual retrievers"""
258
+ # bm25_retriever = BM25Retriever.from_documents(text_chunks)
259
+ # bm25_retriever.k = 2
260
+ chroma_retriever = chroma_vectorstore.as_retriever(search_kwargs={"k": 10})
261
+ # ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.3, 0.7])
262
  logging.basicConfig()
263
  logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
264
+ retriever_from_llm = MultiQueryRetriever.from_llm(retriever=chroma_retriever,
265
  llm=ChatOpenAI(temperature=0))
266
  return retriever_from_llm
267
 
 
322
  chain_type="stuff",
323
  retriever=retriever,
324
  return_source_documents=True,
325
+ chain_type_kwargs={"memory": memory},
326
  )
327
  return qa