ryanrwatkins committed on
Commit
1158d9c
1 Parent(s): 73cef93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -2
app.py CHANGED
@@ -167,7 +167,7 @@ def select_embeddings_model(LLM_service="HuggingFace"):
167
  return embeddings
168
 
169
  #embeddings_OpenAI = select_embeddings_model(LLM_service="OpenAI")
170
- #embeddings_google = select_embeddings_model(LLM_service="Google")
171
  embeddings_HuggingFace = select_embeddings_model(LLM_service="HuggingFace")
172
 
173
 
@@ -219,7 +219,7 @@ vector_store_OpenAI = Chroma(
219
  print("vector_store_OpenAI:",vector_store_OpenAI._collection.count(),"chunks.")
220
  """
221
  vector_store_google = Chroma(
222
- persist_directory = LOCAL_VECTOR_STORE_DIR.as_posix() + "/Vit_All_Google_Embeddings",
223
  embedding_function=embeddings_google)
224
  print("vector_store_google:",vector_store_google._collection.count(),"chunks.")
225
 
@@ -327,6 +327,117 @@ def CohereRerank_retriever(
327
 
328
 
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  def instantiate_LLM(LLM_provider,api_key,temperature=0.5,top_p=0.95,model_name=None):
331
  """Instantiate LLM in Langchain.
332
  Parameters:
 
167
  return embeddings
168
 
169
  #embeddings_OpenAI = select_embeddings_model(LLM_service="OpenAI")
170
+ embeddings_google = select_embeddings_model(LLM_service="Google")
171
  embeddings_HuggingFace = select_embeddings_model(LLM_service="HuggingFace")
172
 
173
 
 
219
  print("vector_store_OpenAI:",vector_store_OpenAI._collection.count(),"chunks.")
220
  """
221
  vector_store_google = Chroma(
222
+ persist_directory = current_dir + "/Vit_All_Google_Embeddings",
223
  embedding_function=embeddings_google)
224
  print("vector_store_google:",vector_store_google._collection.count(),"chunks.")
225
 
 
327
 
328
 
329
 
330
+
331
def retrieval_blocks(
    create_vectorstore=True,  # if True a new Chroma vectorstore is created, else an existing one is loaded
    LLM_service="HuggingFace",
    vectorstore_name="Vit_All_HF_Embeddings",
    chunk_size=1600, chunk_overlap=200,  # parameters of the RecursiveCharacterTextSplitter
    retriever_type="Vectorstore_backed_retriever",
    base_retriever_search_type="similarity", base_retriever_k=10, base_retriever_score_threshold=None,
    compression_retriever_k=16,
    cohere_api_key="***", cohere_model="rerank-multilingual-v2.0", cohere_top_n=8,
):
    """
    Build the retrieval pipeline: document loader, text splitter,
    vectorstore and retriever.

    Parameters:
        create_vectorstore (bool): If True, a new Chroma vectorstore is
            created from the documents in `current_dir`. Otherwise, an
            existing vectorstore is loaded from disk.
        LLM_service: "OpenAI", "Google" or "HuggingFace".
        vectorstore_name (str): name of the vectorstore directory.
        chunk_size, chunk_overlap: parameters of the
            RecursiveCharacterTextSplitter, default (1600, 200).
        retriever_type (str): one of ["Vectorstore_backed_retriever",
            "Contextual_compression", "Cohere_reranker"].
        base_retriever_search_type: search_type in ["similarity", "mmr",
            "similarity_score_threshold"], default "similarity".
        base_retriever_k: number of most similar vectors to retrieve
            (default k = 10).
        base_retriever_score_threshold: score_threshold used by the base
            retriever, default None.
        compression_retriever_k: top-k documents returned by the compression
            retriever, default 16.
        cohere_api_key: Cohere API key.
        cohere_model (str): either 'rerank-english-v2.0' or
            'rerank-multilingual-v2.0' (default).
        cohere_top_n: top-n results returned by Cohere rerank, default 8.

    Returns:
        The configured retriever, or None if any step failed (the error is
        printed; best-effort behavior preserved from the original code).
    """
    try:
        # Create a new vectorstore (Chroma index).
        if create_vectorstore:
            # 1. Load documents from the working directory.
            documents = langchain_document_loader(current_dir)

            # 2. Text splitter: split documents into overlapping chunks.
            text_splitter = RecursiveCharacterTextSplitter(
                separators=["\n\n", "\n", " ", ""],
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            chunks = text_splitter.split_documents(documents=documents)

            # 3. Embeddings for the chosen LLM service.
            embeddings = select_embeddings_model(LLM_service=LLM_service)

            # 4. Vectorstore: create the Chroma index.
            # BUG FIX: the boolean parameter `create_vectorstore` shadows the
            # module-level helper function of the same name, so calling it
            # directly here would invoke `True(...)` and raise TypeError
            # (which the except below silently printed). Look the real helper
            # up in the module globals instead.
            vector_store = globals()["create_vectorstore"](
                embeddings=embeddings,
                documents=chunks,
                vectorstore_name=vectorstore_name,
            )

        # 5. Load an existing vectorstore (Chroma index) from disk.
        else:
            embeddings = select_embeddings_model(LLM_service=LLM_service)
            vector_store = Chroma(
                persist_directory=current_dir + "/" + vectorstore_name,
                embedding_function=embeddings,
            )

        # 6. Base retriever: vector-store-backed retriever.
        base_retriever = Vectorstore_backed_retriever(
            vector_store,
            search_type=base_retriever_search_type,
            k=base_retriever_k,
            score_threshold=base_retriever_score_threshold,
        )
        retriever = None
        if retriever_type == "Vectorstore_backed_retriever":
            retriever = base_retriever

        # 7. Contextual compression retriever.
        if retriever_type == "Contextual_compression":
            retriever = create_compression_retriever(
                embeddings=embeddings,
                base_retriever=base_retriever,
                k=compression_retriever_k,
            )

        # 8. Cohere rerank retriever.
        if retriever_type == "Cohere_reranker":
            retriever = CohereRerank_retriever(
                base_retriever=base_retriever,
                cohere_api_key=cohere_api_key,
                cohere_model=cohere_model,
                top_n=cohere_top_n,
            )

        print(f"\n{retriever_type} is created successfully!")
        print(f"Relevant documents will be retrieved from vectorstore ({vectorstore_name}) which uses {LLM_service} embeddings \
and has {vector_store._collection.count()} chunks.")

        return retriever
    except Exception as e:
        # Best-effort: report the error and implicitly return None,
        # matching the original error-handling behavior.
        print(e)
432
+
433
+
434
+
435
+
436
+
437
+
438
+
439
+
440
+
441
  def instantiate_LLM(LLM_provider,api_key,temperature=0.5,top_p=0.95,model_name=None):
442
  """Instantiate LLM in Langchain.
443
  Parameters: