clementsan committed
Commit 1ef8d7c
1 Parent(s): 2239106

Include ephemeral client and collection_name for chromadb
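For context on what this commit enables, here is a minimal sketch of ChromaDB's ephemeral client on its own, assuming a chromadb 0.4+ release: the client keeps every collection in memory only, and collections are addressed by name. The collection name, ids, and documents below are illustrative placeholders, not taken from app.py.

import chromadb

# Ephemeral client: collections live in RAM only, nothing is persisted to disk
client = chromadb.EphemeralClient()

# Collections are addressed by name, so separate documents can get separate collections
collection = client.get_or_create_collection(name="sample_pdf")
collection.add(ids=["chunk-1"], documents=["An illustrative chunk of text."])

# Query against the in-memory collection
print(collection.query(query_texts=["illustrative"], n_results=1))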

Files changed (1)
  1. app.py +14 -5
app.py CHANGED

@@ -11,6 +11,9 @@ from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain.llms import HuggingFaceHub
 
+from pathlib import Path
+import chromadb
+
 from transformers import AutoTokenizer
 import transformers
 import torch
@@ -50,11 +53,14 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
 
 
 # Create vector database
-def create_db(splits):
+def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
+    new_client = chromadb.EphemeralClient()
     vectordb = Chroma.from_documents(
         documents=splits,
         embedding=embedding,
+        client=new_client,
+        collection_name=collection_name,
         # persist_directory=default_persist_directory
     )
     return vectordb
@@ -147,16 +153,18 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Pr
     # Create list of documents (when valid)
     #file_path = file_obj.name
     list_file_path = [x.name for x in list_file_obj if x is not None]
-    # print('list_file_path', list_file_path)
+    collection_name = Path(list_file_path[0]).stem
+    # print('list_file_path: ', list_file_path)
+    # print('Collection name: ', collection_name)
     progress(0.25, desc="Loading document...")
     # Load document and create splits
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
     # Create or load Vector database
     progress(0.5, desc="Generating vector database...")
     # global vector_db
-    vector_db = create_db(doc_splits)
+    vector_db = create_db(doc_splits, collection_name)
     progress(0.9, desc="Done!")
-    return vector_db, "Complete!"
+    return vector_db, collection_name, "Complete!"
 
 
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
@@ -211,6 +219,7 @@ def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
+        collection_name = gr.State()
 
         gr.Markdown(
         """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>
@@ -270,7 +279,7 @@ def demo():
         #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
         db_btn.click(initialize_database, \
             inputs=[document, slider_chunk_size, slider_chunk_overlap], \
-            outputs=[vector_db, db_progress])
+            outputs=[vector_db, collection_name, db_progress])
         qachain_btn.click(initialize_LLM, \
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0], \
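Putting the pieces of the diff together, the sketch below shows how the updated create_db is intended to be called: the collection name is derived from the uploaded file's stem (as initialize_database now does) and the vector store is built against an in-memory ephemeral client. The PDF path, query string, chunk sizes, and the PyPDFLoader/RecursiveCharacterTextSplitter loading step are assumptions standing in for load_doc, not details taken from the diff.

from pathlib import Path

import chromadb
from langchain.document_loaders import PyPDFLoader   # illustrative stand-in for load_doc()
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma


def create_db(splits, collection_name):
    # Same shape as the patched function: embeddings plus an in-memory Chroma client
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()   # no persist_directory, nothing written to disk
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb


if __name__ == "__main__":
    pdf_path = "example.pdf"                  # placeholder path
    # Load and split the PDF (stand-in for load_doc in app.py)
    pages = PyPDFLoader(pdf_path).load()
    splits = RecursiveCharacterTextSplitter(
        chunk_size=600, chunk_overlap=40      # arbitrary values for the sketch
    ).split_documents(pages)
    # Collection name from the file stem, mirroring initialize_database
    vector_db = create_db(splits, Path(pdf_path).stem)
    print(vector_db.similarity_search("What is this document about?", k=3))

Compared with the earlier persist_directory approach (still visible as a comment in the diff), the ephemeral client means each upload builds a fresh in-memory collection named after the uploaded file and discarded when the process ends, so per-document vector stores do not accumulate on disk.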