isayahc committed on
Commit
96f0f38
·
unverified ·
1 Parent(s): a2ede9f

Changed constant PERSIST_DIRECTORY to VECTOR_DATABASE_LOCATION

Browse files
config.py CHANGED
@@ -6,7 +6,7 @@ from langchain_huggingface import HuggingFaceEndpoint
6
  load_dotenv()
7
 
8
  SQLITE_FILE_NAME = os.getenv('SOURCES_CACHE')
9
- PERSIST_DIRECTORY = os.getenv('VECTOR_DATABASE_LOCATION')
10
  EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
11
  SEVEN_B_LLM_MODEL = os.getenv("SEVEN_B_LLM_MODEL")
12
  BERT_MODEL = os.getenv("BERT_MODEL")
 
6
  load_dotenv()
7
 
8
  SQLITE_FILE_NAME = os.getenv('SOURCES_CACHE')
9
+ VECTOR_DATABASE_LOCATION = os.getenv('VECTOR_DATABASE_LOCATION')
10
  EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
11
  SEVEN_B_LLM_MODEL = os.getenv("SEVEN_B_LLM_MODEL")
12
  BERT_MODEL = os.getenv("BERT_MODEL")
rag_app/get_db_retriever.py CHANGED
@@ -10,19 +10,52 @@ from langchain.chains import RetrievalQA
10
  # prompt template
11
  from langchain.prompts import PromptTemplate
12
  from langchain.memory import ConversationBufferMemory
13
- from config import EMBEDDING_MODEL
14
 
15
 
16
- def get_db_retriever(vector_db:str=None):
17
- embeddings = HuggingFaceHubEmbeddings(repo_id=EMBEDDING_MODEL)
 
18
 
19
- if not vector_db:
20
- FAISS_INDEX_PATH='./vectorstore/py-faiss-multi-mpnet-500'
21
- else:
22
- FAISS_INDEX_PATH=vector_db
23
- db = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
24
 
25
- retriever = db.as_retriever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  return retriever
28
 
 
10
  # prompt template
11
  from langchain.prompts import PromptTemplate
12
  from langchain.memory import ConversationBufferMemory
13
+ from config import EMBEDDING_MODEL, VECTOR_DATABASE_LOCATION
14
 
15
 
16
+ def get_db_retriever():
17
+ """
18
+ Creates and returns a retriever object based on a FAISS vector database.
19
 
20
+ This function initializes an embedding model and loads a pre-existing FAISS
21
+ vector database from a local location. It then creates a retriever from this
22
+ database.
 
 
23
 
24
+ Returns:
25
+ --------
26
+ retriever : langchain.vectorstores.FAISS.VectorStoreRetriever
27
+ A retriever object that can be used to fetch relevant documents from the
28
+ vector database.
29
+
30
+ Global Variables Used:
31
+ ----------------------
32
+ EMBEDDING_MODEL : str
33
+ The identifier for the Hugging Face Hub embedding model to be used.
34
+ VECTOR_DATABASE_LOCATION : str
35
+ The local path where the FAISS vector database is stored.
36
+
37
+ Dependencies:
38
+ -------------
39
+ - langchain_huggingface.HuggingFaceHubEmbeddings
40
+ - langchain_community.vectorstores.FAISS
41
 
42
+ Note:
43
+ -----
44
+ This function assumes that a FAISS vector database has already been created
45
+ and saved at the location specified by VECTOR_DATABASE_LOCATION.
46
+ """
47
+
48
+ # Initialize the embedding model
49
+ embeddings = HuggingFaceHubEmbeddings(repo_id=EMBEDDING_MODEL)
50
+
51
+ # Load the FAISS vector database from the local storage
52
+ db = FAISS.load_local(
53
+ VECTOR_DATABASE_LOCATION,
54
+ embeddings,
55
+ )
56
+
57
+ # Create and return a retriever from the loaded database
58
+ retriever = db.as_retriever()
59
+
60
  return retriever
61
 
rag_app/structured_tools/structured_tools.py CHANGED
@@ -13,9 +13,9 @@ from rag_app.utils.utils import (
13
  )
14
  import chromadb
15
  import os
16
- from config import db, PERSIST_DIRECTORY, EMBEDDING_MODEL
17
 
18
- if not os.path.exists(PERSIST_DIRECTORY):
19
  get_chroma_vs()
20
 
21
  @tool
@@ -24,7 +24,7 @@ def memory_search(query:str) -> str:
24
  This is your primary source to start your search with checking what you already have learned from the past, before going online."""
25
  # Since we have more than one collections we should change the name of this tool
26
  client = chromadb.PersistentClient(
27
- path=PERSIST_DIRECTORY,
28
  )
29
 
30
  collection_name = os.getenv('CONVERSATION_COLLECTION_NAME')
@@ -71,7 +71,7 @@ def knowledgeBase_search(query:str) -> str:
71
  # #collection_name=collection_name,
72
  # embedding_function=embedding_function,
73
  # )
74
- vector_db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embedding_function)
75
  retriever = vector_db.as_retriever(search_type="mmr", search_kwargs={'k':5, 'fetch_k':10})
76
  # This is deprecated, changed to invoke
77
  # LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.
 
13
  )
14
  import chromadb
15
  import os
16
+ from config import db, VECTOR_DATABASE_LOCATION, EMBEDDING_MODEL
17
 
18
+ if not os.path.exists(VECTOR_DATABASE_LOCATION):
19
  get_chroma_vs()
20
 
21
  @tool
 
24
  This is your primary source to start your search with checking what you already have learned from the past, before going online."""
25
  # Since we have more than one collections we should change the name of this tool
26
  client = chromadb.PersistentClient(
27
+ path=VECTOR_DATABASE_LOCATION,
28
  )
29
 
30
  collection_name = os.getenv('CONVERSATION_COLLECTION_NAME')
 
71
  # #collection_name=collection_name,
72
  # embedding_function=embedding_function,
73
  # )
74
+ vector_db = Chroma(persist_directory=VECTOR_DATABASE_LOCATION, embedding_function=embedding_function)
75
  retriever = vector_db.as_retriever(search_type="mmr", search_kwargs={'k':5, 'fetch_k':10})
76
  # This is deprecated, changed to invoke
77
  # LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.