Chris4K committed on
Commit
d78ad1e
1 Parent(s): 54a0f5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -434,14 +434,26 @@ def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
434
  return tokenizer, optimized_texts
435
 
436
  # New preprocessing function
437
- def optimize_query(query, llm):
 
 
 
 
 
 
 
 
 
 
 
 
438
  multi_query_retriever = MultiQueryRetriever.from_llm(
439
- retriever=get_retriever(vector_store, search_type, search_kwargs),
440
  llm=llm
441
  )
442
  optimized_queries = multi_query_retriever.generate_queries(query)
443
  return optimized_queries
444
-
445
  # New postprocessing function
446
  def rerank_results(results, query, reranker):
447
  reranked_results = reranker.rerank(query, [doc.page_content for doc in results])
@@ -495,7 +507,7 @@ def compare_embeddings(file, query, embedding_models, custom_embedding_model, sp
495
  chunks = optimized_chunks
496
 
497
  if use_query_optimization:
498
- optimized_queries = optimize_query(query, query_optimization_model)
499
  query = " ".join(optimized_queries)
500
 
501
  results, search_time, vector_store, results_raw = search_embeddings(
 
434
  return tokenizer, optimized_texts
435
 
436
  # New preprocessing function
437
def optimize_query(query, llm_model, chunks, embedding_model, vector_store_type, search_type, top_k):
    """Expand *query* into multiple alternative phrasings via an LLM-backed retriever.

    Builds a temporary vector store over *chunks*, wraps it in a retriever, and
    uses LangChain's MultiQueryRetriever with a HuggingFace text2text model to
    generate reformulated queries.

    Args:
        query: The original user query string.
        llm_model: HuggingFace model id used for query reformulation.
        chunks: Pre-split document chunks to index in the temporary store.
        embedding_model: Embedding model used to build the temporary store.
        vector_store_type: Backend type passed through to get_vector_store.
        search_type: Retrieval strategy passed through to get_retriever.
        top_k: Number of documents the temporary retriever returns.

    Returns:
        The queries produced by MultiQueryRetriever.generate_queries.
    """
    # FIX: the original passed do_sample=True together with temperature=0, which
    # is contradictory — sampling at temperature 0 is undefined, and recent
    # transformers versions reject that combination. Greedy decoding
    # (do_sample=False) gives the deterministic rewrites temperature=0 intended.
    llm = HuggingFacePipeline.from_model_id(
        model_id=llm_model,
        task="text2text-generation",
        model_kwargs={"do_sample": False, "max_new_tokens": 64},
    )

    # Temporary vector store + retriever scoped to query optimization only, so
    # the main pipeline's store is untouched.
    temp_vector_store = get_vector_store(vector_store_type, chunks, embedding_model)
    temp_retriever = get_retriever(temp_vector_store, search_type, {"k": top_k})

    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=temp_retriever,
        llm=llm,
    )
    # NOTE(review): generate_queries is a private API in recent LangChain
    # releases (expects a run_manager argument) — confirm against the pinned
    # langchain version; the public path is multi_query_retriever.invoke(query).
    optimized_queries = multi_query_retriever.generate_queries(query)
    return optimized_queries
456
+
457
  # New postprocessing function
458
  def rerank_results(results, query, reranker):
459
  reranked_results = reranker.rerank(query, [doc.page_content for doc in results])
 
507
  chunks = optimized_chunks
508
 
509
  if use_query_optimization:
510
+ optimized_queries = optimize_query(query, query_optimization_model, chunks, embedding_model, vector_store_type, search_type, top_k)
511
  query = " ".join(optimized_queries)
512
 
513
  results, search_time, vector_store, results_raw = search_embeddings(