Updated code
Browse files
app.py
CHANGED
@@ -30,7 +30,7 @@ from langchain.agents.agent_toolkits import create_conversational_retrieval_agen
|
|
30 |
from langchain.utilities import SerpAPIWrapper
|
31 |
|
32 |
from utils import build_embedding_model, build_llm
|
33 |
-
from utils import load_ensemble_retriver,
|
34 |
|
35 |
load_dotenv()
|
36 |
# Getting current timestamp to keep track of historical conversations
|
@@ -51,11 +51,11 @@ if "embeddings" not in st.session_state:
|
|
51 |
if "vector_db" not in st.session_state:
|
52 |
st.session_state["vector_db"] = load_vectorstore(persist_directory=persist_directory, embeddings=st.session_state["embeddings"])
|
53 |
|
54 |
-
if "text_chunks" not in st.session_state:
|
55 |
-
|
56 |
|
57 |
if "ensemble_retriver" not in st.session_state:
|
58 |
-
st.session_state["ensemble_retriver"] = load_ensemble_retriver(
|
59 |
|
60 |
if "conversation_chain" not in st.session_state:
|
61 |
st.session_state["conversation_chain"] = load_conversational_retrievel_chain(retriever=st.session_state["ensemble_retriver"], llm=st.session_state["llm"])
|
|
|
30 |
from langchain.utilities import SerpAPIWrapper
|
31 |
|
32 |
from utils import build_embedding_model, build_llm
|
33 |
+
from utils import load_ensemble_retriver,load_vectorstore, load_conversational_retrievel_chain
|
34 |
|
35 |
load_dotenv()
|
36 |
# Getting current timestamp to keep track of historical conversations
|
|
|
51 |
if "vector_db" not in st.session_state:
|
52 |
st.session_state["vector_db"] = load_vectorstore(persist_directory=persist_directory, embeddings=st.session_state["embeddings"])
|
53 |
|
54 |
+
# if "text_chunks" not in st.session_state:
|
55 |
+
# st.session_state["text_chunks"] = load_text_chunks(text_chunks_pkl_dir=all_docs_pkl_directory)
|
56 |
|
57 |
if "ensemble_retriver" not in st.session_state:
|
58 |
+
st.session_state["ensemble_retriver"] = load_ensemble_retriver(embeddings=st.session_state["embeddings"], chroma_vectorstore=st.session_state["vector_db"] )
|
59 |
|
60 |
if "conversation_chain" not in st.session_state:
|
61 |
st.session_state["conversation_chain"] = load_conversational_retrievel_chain(retriever=st.session_state["ensemble_retriver"], llm=st.session_state["llm"])
|
utils.py
CHANGED
@@ -48,7 +48,7 @@ def build_llm():
|
|
48 |
Loading OpenAI model
|
49 |
'''
|
50 |
# llm= OpenAI(temperature=0.2)
|
51 |
-
llm= ChatOpenAI(temperature = 0
|
52 |
return llm
|
53 |
|
54 |
def build_embedding_model():
|
@@ -253,15 +253,15 @@ def load_text_chunks(text_chunks_pkl_dir):
|
|
253 |
pickle.dump(all_texts, file)
|
254 |
print("Text chunks are created and cached")
|
255 |
|
256 |
-
def load_ensemble_retriver(
|
257 |
"""Load ensemble retiriever with BM25 and Chroma as individual retrievers"""
|
258 |
-
bm25_retriever = BM25Retriever.from_documents(text_chunks)
|
259 |
-
bm25_retriever.k = 2
|
260 |
-
chroma_retriever = chroma_vectorstore.as_retriever(search_kwargs={"k":
|
261 |
-
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.3, 0.7])
|
262 |
logging.basicConfig()
|
263 |
logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
|
264 |
-
retriever_from_llm = MultiQueryRetriever.from_llm(retriever=
|
265 |
llm=ChatOpenAI(temperature=0))
|
266 |
return retriever_from_llm
|
267 |
|
@@ -322,6 +322,6 @@ def load_conversational_retrievel_chain(retriever, llm):
|
|
322 |
chain_type="stuff",
|
323 |
retriever=retriever,
|
324 |
return_source_documents=True,
|
325 |
-
chain_type_kwargs={"
|
326 |
)
|
327 |
return qa
|
|
|
48 |
Loading OpenAI model
|
49 |
'''
|
50 |
# llm= OpenAI(temperature=0.2)
|
51 |
+
llm= ChatOpenAI(temperature = 0)
|
52 |
return llm
|
53 |
|
54 |
def build_embedding_model():
|
|
|
253 |
pickle.dump(all_texts, file)
|
254 |
print("Text chunks are created and cached")
|
255 |
|
256 |
+
def load_ensemble_retriver(embeddings, chroma_vectorstore):
|
257 |
"""Load ensemble retiriever with BM25 and Chroma as individual retrievers"""
|
258 |
+
# bm25_retriever = BM25Retriever.from_documents(text_chunks)
|
259 |
+
# bm25_retriever.k = 2
|
260 |
+
chroma_retriever = chroma_vectorstore.as_retriever(search_kwargs={"k": 10})
|
261 |
+
# ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.3, 0.7])
|
262 |
logging.basicConfig()
|
263 |
logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
|
264 |
+
retriever_from_llm = MultiQueryRetriever.from_llm(retriever=chroma_retriever,
|
265 |
llm=ChatOpenAI(temperature=0))
|
266 |
return retriever_from_llm
|
267 |
|
|
|
322 |
chain_type="stuff",
|
323 |
retriever=retriever,
|
324 |
return_source_documents=True,
|
325 |
+
chain_type_kwargs={"memory": memory},
|
326 |
)
|
327 |
return qa
|