bstraehle commited on
Commit
4e80daf
1 Parent(s): 5b6d867

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +5 -2
rag.py CHANGED
@@ -95,10 +95,13 @@ def llm_chain(config, openai_api_key, prompt):
95
 
96
  return completion, llm_chain
97
 
98
- def rag_chain(config, openai_api_key, prompt):
99
  llm = get_llm(config, openai_api_key)
100
 
101
- db = document_retrieval_chroma(llm, prompt)
 
 
 
102
 
103
  rag_chain = RetrievalQA.from_chain_type(llm,
104
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
 
95
 
96
  return completion, llm_chain
97
 
98
+ def rag_chain(config, openai_api_key, rag_option, prompt):
99
  llm = get_llm(config, openai_api_key)
100
 
101
+ if (rag_option == RAG_CHROMA):
102
+ db = document_retrieval_chroma(llm, prompt)
103
+ else:
104
+ db = document_retrieval_mongodb(llm, prompt)
105
 
106
  rag_chain = RetrievalQA.from_chain_type(llm,
107
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},