bubuuunel committed on
Commit
97b6d79
1 Parent(s): 4649e3e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -34,8 +34,8 @@ client = OpenAI(
34
  embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
35
  # Load the persisted vectorDB
36
  collection_name = 'Dataset-10k'
37
- persisted_vectordb_location = "./bubuuunel/RAG10K"
38
- reportdb = Chroma(
39
  collection_name=collection_name,
40
  persist_directory='./dataset_db',
41
  embedding_function=embedding_model
@@ -91,10 +91,9 @@ def predict(user_input,company):
91
  }
92
 
93
  filter = "dataset/"+company+"-10-k-2023.pdf"
94
- retreiver = vectorstore_persisted.similarity_search(user_input, k=5, filter={"source":filter})
95
-
96
  # Create context_for_query
97
- relevant_document_chunks = retriever.get_relevant_documents(user_question)
98
  context_list = [d.page_content for d in relevant_document_chunks]
99
  context_for_query = ". ".join(context_list)
100
 
@@ -106,7 +105,7 @@ def predict(user_input,company):
106
  )
107
  }
108
  ]
109
-
110
  # Create messages
111
  try:
112
  response = client.chat.completions.create(
@@ -117,14 +116,14 @@ def predict(user_input,company):
117
 
118
  prediction = response.choices[0].message.content.strip()
119
  except Exception as e:
120
-
121
  prediction = f'Sorry, I encountered the following error: \n {e}'
122
-
123
 
124
 
125
  # Get response from the LLM
126
  prediction = response.choices[0].message.content.strip()
127
-
128
  # While the prediction is made, log both the inputs and outputs to a local log file
129
  # While writing to the log file, ensure that the commit scheduler is locked to avoid parallel
130
  # access
@@ -145,9 +144,9 @@ def predict(user_input,company):
145
  # Set-up the Gradio UI
146
  user_input = gr.Textbox (label = 'Query')
147
  company_input = gr.Radio(
148
- ['aws','google','IBM','Meta','msft'],
149
  label = 'company'
150
- )
151
 
152
  model_output = gr.Textbox (label = 'Response')
153
 
@@ -162,7 +161,7 @@ model_output = gr.Textbox (label = 'Response')
162
  demo = gr.Interface(
163
  fn=predict,
164
  inputs=[user_input,company_input],
165
- outputs=model_output,
166
  title="RAG on 10k-reports",
167
  description="This API allows you to query on annaul reports",
168
  concurrency_limit=16
 
34
  embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
35
  # Load the persisted vectorDB
36
  collection_name = 'Dataset-10k'
37
+
38
+ dataset_db = Chroma(
39
  collection_name=collection_name,
40
  persist_directory='./dataset_db',
41
  embedding_function=embedding_model
 
91
  }
92
 
93
  filter = "dataset/"+company+"-10-k-2023.pdf"
94
+
 
95
  # Create context_for_query
96
+ relevant_document_chunks = dataset_db.similarity_search(user_question, k=5, filter = {"source":"dataset/google-10-k-2023.pdf"})
97
  context_list = [d.page_content for d in relevant_document_chunks]
98
  context_for_query = ". ".join(context_list)
99
 
 
105
  )
106
  }
107
  ]
108
+
109
  # Create messages
110
  try:
111
  response = client.chat.completions.create(
 
116
 
117
  prediction = response.choices[0].message.content.strip()
118
  except Exception as e:
119
+
120
  prediction = f'Sorry, I encountered the following error: \n {e}'
121
+
122
 
123
 
124
  # Get response from the LLM
125
  prediction = response.choices[0].message.content.strip()
126
+
127
  # While the prediction is made, log both the inputs and outputs to a local log file
128
  # While writing to the log file, ensure that the commit scheduler is locked to avoid parallel
129
  # access
 
144
  # Set-up the Gradio UI
145
  user_input = gr.Textbox (label = 'Query')
146
  company_input = gr.Radio(
147
+ ['aws','google','IBM','Meta','msft'],
148
  label = 'company'
149
+ )
150
 
151
  model_output = gr.Textbox (label = 'Response')
152
 
 
161
  demo = gr.Interface(
162
  fn=predict,
163
  inputs=[user_input,company_input],
164
+ outputs=prediction,
165
  title="RAG on 10k-reports",
166
  description="This API allows you to query on annaul reports",
167
  concurrency_limit=16