petrojm commited on
Commit
a9bf317
·
1 Parent(s): 8bb5bf9

changes to app and document_retrieval

Browse files
Files changed (2) hide show
  1. app.py +9 -5
  2. src/document_retrieval.py +4 -4
app.py CHANGED
@@ -27,16 +27,20 @@ def handle_userinput(user_question, conversation_chain, history):
27
  else:
28
  return history, ""
29
 
30
- def process_documents(files, collection_name, document_retrieval, vectorstore, conversation_chain, save_location=None):
31
  try:
32
- document_retrieval = DocumentRetrieval()
 
 
 
 
33
  _, _, text_chunks = parse_doc_universal(doc=files)
34
  print(len(text_chunks))
35
  print(text_chunks[0])
36
  embeddings = document_retrieval.load_embedding_model()
37
  collection_id = str(uuid.uuid4())
38
  collection_name = f"collection_{collection_id}"
39
- vectorstore = document_retrieval.create_vector_store(text_chunks, embeddings, output_db=save_location, collection_name=collection_name)
40
  document_retrieval.init_retriever(vectorstore)
41
  conversation_chain = document_retrieval.get_qa_retrieval_chain()
42
  return conversation_chain, vectorstore, document_retrieval, collection_name, "Complete! You can now ask questions."
@@ -57,7 +61,7 @@ with gr.Blocks() as demo:
57
 
58
  gr.Markdown("Powered by LLama3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")
59
 
60
- #api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")
61
 
62
  # Step 1: Add PDF file
63
  gr.Markdown("## 1️⃣ Upload PDF")
@@ -71,7 +75,7 @@ with gr.Blocks() as demo:
71
  gr.Markdown(caution_text)
72
 
73
  # Preprocessing events
74
- process_btn.click(process_documents, inputs=[docs, collection_name, document_retrieval, vectorstore, conversation_chain], outputs=[conversation_chain, vectorstore, document_retrieval, collection_name, setup_output], concurrency_limit=20)
75
 
76
  # Step 3: Chat with your data
77
  gr.Markdown("## 3️⃣ Chat with your document")
 
27
  else:
28
  return history, ""
29
 
30
+ def process_documents(files, collection_name, document_retrieval, vectorstore, conversation_chain, api_key=None):
31
  try:
32
+ if api_key:
33
+ sambanova_api_key = api_key
34
+ else:
35
+ sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
36
+ document_retrieval = DocumentRetrieval(sambanova_api_key)
37
  _, _, text_chunks = parse_doc_universal(doc=files)
38
  print(len(text_chunks))
39
  print(text_chunks[0])
40
  embeddings = document_retrieval.load_embedding_model()
41
  collection_id = str(uuid.uuid4())
42
  collection_name = f"collection_{collection_id}"
43
+ vectorstore = document_retrieval.create_vector_store(text_chunks, embeddings, output_db=None, collection_name=collection_name)
44
  document_retrieval.init_retriever(vectorstore)
45
  conversation_chain = document_retrieval.get_qa_retrieval_chain()
46
  return conversation_chain, vectorstore, document_retrieval, collection_name, "Complete! You can now ask questions."
 
61
 
62
  gr.Markdown("Powered by LLama3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")
63
 
64
+ api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")
65
 
66
  # Step 1: Add PDF file
67
  gr.Markdown("## 1️⃣ Upload PDF")
 
75
  gr.Markdown(caution_text)
76
 
77
  # Preprocessing events
78
+ process_btn.click(process_documents, inputs=[docs, collection_name, document_retrieval, vectorstore, conversation_chain, api_key], outputs=[conversation_chain, vectorstore, document_retrieval, collection_name, setup_output], concurrency_limit=20)
79
 
80
  # Step 3: Chat with your data
81
  gr.Markdown("## 3️⃣ Chat with your document")
src/document_retrieval.py CHANGED
@@ -124,7 +124,7 @@ class RetrievalQAChain(Chain):
124
 
125
 
126
  class DocumentRetrieval:
127
- def __init__(self):
128
  self.vectordb = VectorDb()
129
  config_info = self.get_config_info()
130
  self.api_info = config_info[0]
@@ -134,7 +134,7 @@ class DocumentRetrieval:
134
  self.prompts = config_info[4]
135
  self.prod_mode = config_info[5]
136
  self.retriever = None
137
- self.llm = self.set_llm()
138
 
139
  def get_config_info(self):
140
  """
@@ -152,7 +152,7 @@ class DocumentRetrieval:
152
 
153
  return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
154
 
155
- def set_llm(self):
156
  #if self.prod_mode:
157
  # sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
158
  #else:
@@ -161,7 +161,7 @@ class DocumentRetrieval:
161
  # else:
162
  # sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
163
 
164
- sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
165
 
166
  llm = APIGateway.load_llm(
167
  type=self.api_info,
 
124
 
125
 
126
  class DocumentRetrieval:
127
+ def __init__(self, sambanova_api_key):
128
  self.vectordb = VectorDb()
129
  config_info = self.get_config_info()
130
  self.api_info = config_info[0]
 
134
  self.prompts = config_info[4]
135
  self.prod_mode = config_info[5]
136
  self.retriever = None
137
+ self.llm = self.set_llm(sambanova_api_key)
138
 
139
  def get_config_info(self):
140
  """
 
152
 
153
  return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
154
 
155
+ def set_llm(self, sambanova_api_key):
156
  #if self.prod_mode:
157
  # sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
158
  #else:
 
161
  # else:
162
  # sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
163
 
164
+ #sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
165
 
166
  llm = APIGateway.load_llm(
167
  type=self.api_info,