Update app.py
Browse files
app.py
CHANGED
@@ -434,14 +434,26 @@ def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
|
|
434 |
return tokenizer, optimized_texts
|
435 |
|
436 |
# New preprocessing function
|
437 |
-
def optimize_query(query,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
multi_query_retriever = MultiQueryRetriever.from_llm(
|
439 |
-
retriever=
|
440 |
llm=llm
|
441 |
)
|
442 |
optimized_queries = multi_query_retriever.generate_queries(query)
|
443 |
return optimized_queries
|
444 |
-
|
445 |
# New postprocessing function
|
446 |
def rerank_results(results, query, reranker):
|
447 |
reranked_results = reranker.rerank(query, [doc.page_content for doc in results])
|
@@ -495,7 +507,7 @@ def compare_embeddings(file, query, embedding_models, custom_embedding_model, sp
|
|
495 |
chunks = optimized_chunks
|
496 |
|
497 |
if use_query_optimization:
|
498 |
-
optimized_queries = optimize_query(query, query_optimization_model)
|
499 |
query = " ".join(optimized_queries)
|
500 |
|
501 |
results, search_time, vector_store, results_raw = search_embeddings(
|
|
|
434 |
return tokenizer, optimized_texts
|
435 |
|
436 |
# New preprocessing function
|
437 |
+
def optimize_query(query, llm_model, chunks, embedding_model, vector_store_type, search_type, top_k):
    """Expand a user query into multiple LLM-generated reformulations.

    Builds a local text2text-generation pipeline from ``llm_model``, creates a
    temporary vector store/retriever over ``chunks`` so that
    ``MultiQueryRetriever`` has something to retrieve against, and returns the
    alternative query phrasings the LLM produces.

    Args:
        query: The original user query string to expand.
        llm_model: Hugging Face model id used to generate query variants.
        chunks: Document chunks used to build the temporary vector store.
        embedding_model: Embedding model for the temporary vector store.
        vector_store_type: Backend identifier passed to ``get_vector_store``.
        search_type: Retrieval strategy passed to ``get_retriever``.
        top_k: Number of results the temporary retriever returns per query.

    Returns:
        The list of reformulated queries generated by the retriever's LLM.
    """
    llm = HuggingFacePipeline.from_model_id(
        model_id=llm_model,
        task="text2text-generation",
        # BUG FIX: the original passed do_sample=True together with
        # temperature=0, which is contradictory (sampling with zero
        # temperature) and is rejected by recent transformers releases.
        # temperature=0 signals the intent of deterministic output, so use
        # greedy decoding explicitly.
        model_kwargs={"do_sample": False, "max_new_tokens": 64},
    )

    # Create a temporary vector store for query optimization
    temp_vector_store = get_vector_store(vector_store_type, chunks, embedding_model)

    # Create a retriever with the temporary vector store
    temp_retriever = get_retriever(temp_vector_store, search_type, {"k": top_k})

    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=temp_retriever,
        llm=llm
    )
    # NOTE(review): generate_queries is not part of MultiQueryRetriever's
    # public API in current LangChain releases (the public entry point is
    # invoke/get_relevant_documents) — confirm against the pinned
    # langchain version before relying on this call.
    optimized_queries = multi_query_retriever.generate_queries(query)
    return optimized_queries
|
456 |
+
|
457 |
# New postprocessing function
|
458 |
def rerank_results(results, query, reranker):
|
459 |
reranked_results = reranker.rerank(query, [doc.page_content for doc in results])
|
|
|
507 |
chunks = optimized_chunks
|
508 |
|
509 |
if use_query_optimization:
|
510 |
+
optimized_queries = optimize_query(query, query_optimization_model, chunks, embedding_model, vector_store_type, search_type, top_k)
|
511 |
query = " ".join(optimized_queries)
|
512 |
|
513 |
results, search_time, vector_store, results_raw = search_embeddings(
|