Update rag.py
Browse files
rag.py
CHANGED
@@ -79,11 +79,11 @@ def rag_batch(config):
|
|
79 |
document_storage_chroma(chunks)
|
80 |
document_storage_mongodb(chunks)
|
81 |
|
82 |
-
def document_retrieval_chroma():
|
83 |
return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
|
84 |
persist_directory = CHROMA_DIR)
|
85 |
|
86 |
-
def document_retrieval_mongodb():
|
87 |
return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
|
88 |
MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
|
89 |
OpenAIEmbeddings(disallowed_special = ()),
|
@@ -107,26 +107,17 @@ def rag_chain(config, openai_api_key, rag_option, prompt):
|
|
107 |
llm = get_llm(config, openai_api_key)
|
108 |
|
109 |
if (rag_option == RAG_CHROMA):
|
110 |
-
db = document_retrieval_chroma()
|
111 |
elif (rag_option == RAG_MONGODB):
|
112 |
-
db = document_retrieval_mongodb()
|
113 |
-
|
114 |
-
###
|
115 |
-
retriever = db.as_retriever(search_kwargs = {"k": config["k"]})
|
116 |
-
retrieved_docs = retriever.invoke(prompt)
|
117 |
-
print(retrieved_docs[0].page_content)
|
118 |
-
print(retrieved_docs[1].page_content)
|
119 |
-
print(retrieved_docs[2].page_content)
|
120 |
-
###
|
121 |
-
|
122 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
123 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
124 |
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
|
125 |
return_source_documents = True,
|
126 |
verbose = False)
|
127 |
|
128 |
-
completion = rag_chain({"query": prompt}
|
129 |
print("###" + str(completion))
|
130 |
-
print("###" + str(completion["__run"]))
|
131 |
|
132 |
return completion, rag_chain
|
|
|
79 |
document_storage_chroma(chunks)
|
80 |
document_storage_mongodb(chunks)
|
81 |
|
82 |
+
def document_retrieval_chroma(llm):
|
83 |
return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
|
84 |
persist_directory = CHROMA_DIR)
|
85 |
|
86 |
+
def document_retrieval_mongodb(llm):
|
87 |
return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
|
88 |
MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
|
89 |
OpenAIEmbeddings(disallowed_special = ()),
|
|
|
107 |
llm = get_llm(config, openai_api_key)
|
108 |
|
109 |
if (rag_option == RAG_CHROMA):
|
110 |
+
db = document_retrieval_chroma(llm)
|
111 |
elif (rag_option == RAG_MONGODB):
|
112 |
+
db = document_retrieval_mongodb(llm)
|
113 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
rag_chain = RetrievalQA.from_chain_type(llm,
|
115 |
chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
|
116 |
retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
|
117 |
return_source_documents = True,
|
118 |
verbose = False)
|
119 |
|
120 |
+
completion = rag_chain({"query": prompt})
|
121 |
print("###" + str(completion))
|
|
|
122 |
|
123 |
return completion, rag_chain
|