bstraehle commited on
Commit
772eaa5
·
1 Parent(s): 44a256c

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +12 -12
rag.py CHANGED
@@ -61,41 +61,41 @@ def split_documents(config, docs):
61
 
62
  return text_splitter.split_documents(docs)
63
 
64
- def embed_store_documents_chroma(chunks):
65
  Chroma.from_documents(documents = chunks,
66
  embedding = OpenAIEmbeddings(disallowed_special = ()),
67
  persist_directory = CHROMA_DIR)
68
 
69
- def embed_store_documents_mongodb(chunks):
70
  MongoDBAtlasVectorSearch.from_documents(documents = chunks,
71
  embedding = OpenAIEmbeddings(disallowed_special = ()),
72
  collection = collection,
73
  index_name = MONGODB_INDEX_NAME)
74
 
75
- def run_rag_batch(config):
76
  docs = load_documents()
77
 
78
  chunks = split_documents(config, docs)
79
 
80
- embed_store_documents_chroma(chunks)
81
- embed_store_documents_mongodb(chunks)
82
 
83
- def retrieve_documents_chroma():
84
  return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
85
  persist_directory = CHROMA_DIR)
86
 
87
- def retrieve_documents_mongodb():
88
  return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
89
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
90
  OpenAIEmbeddings(disallowed_special = ()),
91
  index_name = MONGODB_INDEX_NAME)
92
-
93
  def get_llm(config, openai_api_key):
94
  return ChatOpenAI(model_name = config["model_name"],
95
  openai_api_key = openai_api_key,
96
  temperature = config["temperature"])
97
 
98
- def run_llm_chain(config, openai_api_key, prompt):
99
  llm_chain = LLMChain(llm = get_llm(config, openai_api_key),
100
  prompt = LLM_CHAIN_PROMPT)
101
 
@@ -104,13 +104,13 @@ def run_llm_chain(config, openai_api_key, prompt):
104
 
105
  return completion, llm_chain, cb
106
 
107
- def run_rag_chain(config, openai_api_key, rag_option, prompt):
108
  llm = get_llm(config, openai_api_key)
109
 
110
  if (rag_option == RAG_CHROMA):
111
- db = retrieve_documents_chroma()
112
  elif (rag_option == RAG_MONGODB):
113
- db = retrieve_documents_mongodb()
114
 
115
  rag_chain = RetrievalQA.from_chain_type(llm,
116
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,
 
61
 
62
  return text_splitter.split_documents(docs)
63
 
64
+ def store_chroma(chunks):
65
  Chroma.from_documents(documents = chunks,
66
  embedding = OpenAIEmbeddings(disallowed_special = ()),
67
  persist_directory = CHROMA_DIR)
68
 
69
+ def store_mongodb(chunks):
70
  MongoDBAtlasVectorSearch.from_documents(documents = chunks,
71
  embedding = OpenAIEmbeddings(disallowed_special = ()),
72
  collection = collection,
73
  index_name = MONGODB_INDEX_NAME)
74
 
75
+ def rag_ingestion(config):
76
  docs = load_documents()
77
 
78
  chunks = split_documents(config, docs)
79
 
80
+ store_chroma(chunks)
81
+ store_mongodb(chunks)
82
 
83
+ def retrieve_chroma():
84
  return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
85
  persist_directory = CHROMA_DIR)
86
 
87
+ def retrieve_mongodb():
88
  return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
89
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
90
  OpenAIEmbeddings(disallowed_special = ()),
91
  index_name = MONGODB_INDEX_NAME)
92
+
93
  def get_llm(config, openai_api_key):
94
  return ChatOpenAI(model_name = config["model_name"],
95
  openai_api_key = openai_api_key,
96
  temperature = config["temperature"])
97
 
98
+ def llm_chain(config, openai_api_key, prompt):
99
  llm_chain = LLMChain(llm = get_llm(config, openai_api_key),
100
  prompt = LLM_CHAIN_PROMPT)
101
 
 
104
 
105
  return completion, llm_chain, cb
106
 
107
+ def rag_chain(config, openai_api_key, rag_option, prompt):
108
  llm = get_llm(config, openai_api_key)
109
 
110
  if (rag_option == RAG_CHROMA):
111
+ db = retrieve_chroma()
112
  elif (rag_option == RAG_MONGODB):
113
+ db = retrieve_mongodb()
114
 
115
  rag_chain = RetrievalQA.from_chain_type(llm,
116
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,