bstraehle committed on
Commit
c8a9d42
1 Parent(s): ea50aaa

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +15 -15
rag.py CHANGED
@@ -37,7 +37,7 @@ RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"], tem
37
  client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
38
  collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
39
 
40
- def document_loading():
41
  docs = []
42
 
43
  # PDF
@@ -55,36 +55,36 @@ def document_loading():
55
 
56
  return docs
57
 
58
- def document_splitting(config, docs):
59
  text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
60
  chunk_size = config["chunk_size"])
61
 
62
  return text_splitter.split_documents(docs)
63
 
64
- def document_storage_chroma(chunks):
65
  Chroma.from_documents(documents = chunks,
66
  embedding = OpenAIEmbeddings(disallowed_special = ()),
67
  persist_directory = CHROMA_DIR)
68
 
69
- def document_storage_mongodb(chunks):
70
  MongoDBAtlasVectorSearch.from_documents(documents = chunks,
71
  embedding = OpenAIEmbeddings(disallowed_special = ()),
72
  collection = collection,
73
  index_name = MONGODB_INDEX_NAME)
74
 
75
- def rag_batch(config):
76
- docs = document_loading()
77
 
78
- chunks = document_splitting(config, docs)
79
 
80
- document_storage_chroma(chunks)
81
- document_storage_mongodb(chunks)
82
 
83
- def document_retrieval_chroma():
84
  return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
85
  persist_directory = CHROMA_DIR)
86
 
87
- def document_retrieval_mongodb():
88
  return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
89
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
90
  OpenAIEmbeddings(disallowed_special = ()),
@@ -95,7 +95,7 @@ def get_llm(config, openai_api_key):
95
  openai_api_key = openai_api_key,
96
  temperature = config["temperature"])
97
 
98
- def llm_chain(config, openai_api_key, prompt):
99
  llm_chain = LLMChain(llm = get_llm(config, openai_api_key),
100
  prompt = LLM_CHAIN_PROMPT)
101
 
@@ -104,13 +104,13 @@ def llm_chain(config, openai_api_key, prompt):
104
 
105
  return completion, llm_chain, cb
106
 
107
- def rag_chain(config, openai_api_key, rag_option, prompt):
108
  llm = get_llm(config, openai_api_key)
109
 
110
  if (rag_option == RAG_CHROMA):
111
- db = document_retrieval_chroma()
112
  elif (rag_option == RAG_MONGODB):
113
- db = document_retrieval_mongodb()
114
 
115
  rag_chain = RetrievalQA.from_chain_type(llm,
116
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
 
37
  client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
38
  collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
39
 
40
+ def load_documents():
41
  docs = []
42
 
43
  # PDF
 
55
 
56
  return docs
57
 
58
+ def split_documents(config, docs):
59
  text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = config["chunk_overlap"],
60
  chunk_size = config["chunk_size"])
61
 
62
  return text_splitter.split_documents(docs)
63
 
64
+ def embed_store_documents_chroma(chunks):
65
  Chroma.from_documents(documents = chunks,
66
  embedding = OpenAIEmbeddings(disallowed_special = ()),
67
  persist_directory = CHROMA_DIR)
68
 
69
+ def embed_store_documents_mongodb(chunks):
70
  MongoDBAtlasVectorSearch.from_documents(documents = chunks,
71
  embedding = OpenAIEmbeddings(disallowed_special = ()),
72
  collection = collection,
73
  index_name = MONGODB_INDEX_NAME)
74
 
75
+ def run_rag_batch(config):
76
+ docs = load_documents()
77
 
78
+ chunks = split_documents(config, docs)
79
 
80
+ embed_store_documents_chroma(chunks)
81
+ embed_store_documents_mongodb(chunks)
82
 
83
+ def retrieve_documents_chroma():
84
  return Chroma(embedding_function = OpenAIEmbeddings(disallowed_special = ()),
85
  persist_directory = CHROMA_DIR)
86
 
87
+ def retrieve_documents_mongodb():
88
  return MongoDBAtlasVectorSearch.from_connection_string(MONGODB_ATLAS_CLUSTER_URI,
89
  MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
90
  OpenAIEmbeddings(disallowed_special = ()),
 
95
  openai_api_key = openai_api_key,
96
  temperature = config["temperature"])
97
 
98
+ def run_llm_chain(config, openai_api_key, prompt):
99
  llm_chain = LLMChain(llm = get_llm(config, openai_api_key),
100
  prompt = LLM_CHAIN_PROMPT)
101
 
 
104
 
105
  return completion, llm_chain, cb
106
 
107
+ def run_rag_chain(config, openai_api_key, rag_option, prompt):
108
  llm = get_llm(config, openai_api_key)
109
 
110
  if (rag_option == RAG_CHROMA):
111
+ db = retrieve_documents_chroma()
112
  elif (rag_option == RAG_MONGODB):
113
+ db = retrieve_documents_mongodb()
114
 
115
  rag_chain = RetrievalQA.from_chain_type(llm,
116
  chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},