bstraehle commited on
Commit
96012de
·
1 Parent(s): 9c86fb0

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +6 -15
rag.py CHANGED
@@ -79,11 +79,11 @@ def rag_batch(config):
79
  document_storage_chroma(chunks)
80
  document_storage_mongodb(chunks)
81
 
82
- def document_retrieval_chroma():
83
  return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
84
  persist_directory = CHROMA_DIR)
85
 
86
- def document_retrieval_mongodb():
87
  return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
88
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
89
  OpenAIEmbeddings(disallowed_special = ()),
@@ -107,26 +107,17 @@ def rag_chain(config, openai_api_key, rag_option, prompt):
107
  llm = get_llm(config, openai_api_key)
108
 
109
  if (rag_option == RAG_CHROMA):
110
- db = document_retrieval_chroma()
111
  elif (rag_option == RAG_MONGODB):
112
- db = document_retrieval_mongodb()
113
-
114
- ###
115
- retriever = db.as_retriever(search_kwargs = {"k": config["k"]})
116
- retrieved_docs = retriever.invoke(prompt)
117
- print(retrieved_docs[0].page_content)
118
- print(retrieved_docs[1].page_content)
119
- print(retrieved_docs[2].page_content)
120
- ###
121
-
122
  rag_chain = RetrievalQA.from_chain_type(llm,
123
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
124
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
125
  return_source_documents = True,
126
  verbose = False)
127
 
128
- completion = rag_chain({"query": prompt}, include_run_info = True)
129
  print("###" + str(completion))
130
- print("###" + str(completion["__run"]))
131
 
132
  return completion, rag_chain
 
79
  document_storage_chroma(chunks)
80
  document_storage_mongodb(chunks)
81
 
82
+ def document_retrieval_chroma(llm):
83
  return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
84
  persist_directory = CHROMA_DIR)
85
 
86
+ def document_retrieval_mongodb(llm):
87
  return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
88
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
89
  OpenAIEmbeddings(disallowed_special = ()),
 
107
  llm = get_llm(config, openai_api_key)
108
 
109
  if (rag_option == RAG_CHROMA):
110
+ db = document_retrieval_chroma(llm)
111
  elif (rag_option == RAG_MONGODB):
112
+ db = document_retrieval_mongodb(llm)
113
+
 
 
 
 
 
 
 
 
114
  rag_chain = RetrievalQA.from_chain_type(llm,
115
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
116
  retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
117
  return_source_documents = True,
118
  verbose = False)
119
 
120
+ completion = rag_chain({"query": prompt})
121
  print("###" + str(completion))
 
122
 
123
  return completion, rag_chain