Update app.py
app.py
CHANGED
@@ -505,18 +505,64 @@ def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
 
     return tokenizer, optimized_texts
 
-
-
+import numpy as np
+from transformers import TextClassificationPipeline
+from typing import List, Union, Any
+
+def rerank_results(
+    results: List[Any],
+    query: str,
+    reranker: Union[TextClassificationPipeline, Any]
+) -> List[Any]:
+    """
+    Rerank search results using either a TextClassificationPipeline or a custom reranker.
+
+    Args:
+        results: List of documents/results to rerank
+        query: Search query string
+        reranker: Either a HuggingFace TextClassificationPipeline or a custom reranker
+            with a rerank() method
+
+    Returns:
+        List of reranked results
+    """
+    if not results:
+        return results
+
     if not hasattr(reranker, 'rerank'):
         # For TextClassificationPipeline
-
-
-
-
+        try:
+            pairs = [[query, doc.page_content] for doc in results]
+
+            # Standard classification without specific function
+            predictions = reranker(pairs)
+
+            # Extract scores, defaulting to 'score' key but falling back to other common keys
+            scores = []
+            for pred in predictions:
+                if isinstance(pred, dict):
+                    score = pred.get('score',
+                                     pred.get('probability',
+                                              pred.get('confidence', 0.0)))
+                else:
+                    score = float(pred)
+                scores.append(score)
+
+            # Sort in descending order (higher scores = better matches)
+            reranked_idx = np.argsort(scores)[::-1]
+            return [results[i] for i in reranked_idx]
+
+        except Exception as e:
+            print(f"Warning: Reranking failed with error: {str(e)}")
+            return results
     else:
-        # For models with rerank method
-
-
+        # For models with dedicated rerank method
+        try:
+            return reranker.rerank(query, [doc.page_content for doc in results])
+        except Exception as e:
+            print(f"Warning: Custom reranking failed with error: {str(e)}")
+            return results
+
 # Main Comparison Function
 def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
     all_results = []
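The new rerank_results function duck-types on the reranker: objects exposing a rerank() method take the second branch, and everything else is treated as a classification pipeline. A minimal sketch of the rerank()-method branch, using a hypothetical KeywordReranker and a Doc stand-in for whatever document type carries page_content (both names are assumptions, not part of this commit):

from dataclasses import dataclass

@dataclass
class Doc:
    page_content: str

class KeywordReranker:
    # Hypothetical reranker: orders texts by naive query-term overlap.
    def rerank(self, query, texts):
        terms = set(query.lower().split())
        return sorted(texts, key=lambda t: -len(terms & set(t.lower().split())))

docs = [Doc("stores vectors in FAISS"), Doc("reranks results for a query")]
print(rerank_results(docs, "rerank query results", KeywordReranker()))
# -> ['reranks results for a query', 'stores vectors in FAISS']

Note that this branch returns the raw page_content strings, while the pipeline branch returns the original result objects; callers that expect a uniform return type may want to normalize this.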
@@ -660,7 +706,9 @@ def automated_testing(file, query, test_params, expected_result=None):
         chunks = optimized_chunks
 
     if params['use_query_optimization']:
-        optimized_queries = optimize_query(query, params['query_optimization_model'])
+        optimized_queries = optimize_query(query, params['query_optimization_model'], params['chunks'], params['embedding_model'], params['vector_store_type'], params['search_type'], params['top_k'])
+
+        #optimized_queries = optimize_query(query, )
        query = " ".join(optimized_queries)
 
     results, search_time, vector_store, results_raw = search_embeddings(
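The rewritten call site implies that each params dict in automated_testing now carries the full retrieval configuration, not just the optimization model. A sketch of the shape this call assumes: the keys come from the diff, the example values are assumptions (except the flan-t5 default, which appears in compare_embeddings above), and optimize_query itself is defined elsewhere in app.py:

chunks = ["first text chunk", "second text chunk"]  # stand-in for the chunks built earlier
params = {
    "use_query_optimization": True,
    "query_optimization_model": "google/flan-t5-base",  # default used by compare_embeddings
    "chunks": chunks,
    "embedding_model": "paraphrase-multilingual-MiniLM-L12-v2",  # illustrative value
    "vector_store_type": "FAISS",                                # illustrative value
    "search_type": "similarity",                                 # illustrative value
    "top_k": 5,
}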