Update app.py
app.py CHANGED
@@ -510,6 +510,10 @@ import numpy as np
 from transformers import TextClassificationPipeline
 from typing import List, Union, Any
 
+import numpy as np
+from transformers import pipeline, TextClassificationPipeline
+from typing import List, Any, Union
+
 def rerank_results(
     results: List[Any],
     query: str,
@@ -522,23 +526,24 @@ def rerank_results(
         results: List of documents/results to rerank
         query: Search query string
         reranker: Either a HuggingFace TextClassificationPipeline or a custom reranker
-                  with a rerank() method
+                  with a rerank() method.
 
     Returns:
         List of reranked results
     """
     if not results:
         return results
-
+
     if not hasattr(reranker, 'rerank'):
         # For TextClassificationPipeline
         try:
+            # Create pairs of query and document content
             pairs = [[query, doc.page_content] for doc in results]
 
-            #
+            # Get predictions from the reranker pipeline
             predictions = reranker(pairs)
 
-            # Extract scores
+            # Extract scores with proper fallback options
             scores = []
             for pred in predictions:
                 if isinstance(pred, dict):
@@ -549,21 +554,23 @@ def rerank_results(
                     score = float(pred)
                 scores.append(score)
 
-            # Sort
+            # Sort the results based on scores in descending order
             reranked_idx = np.argsort(scores)[::-1]
+
+            # Return reranked results based on the sorted indices
             return [results[i] for i in reranked_idx]
 
         except Exception as e:
             print(f"Warning: Reranking failed with error: {str(e)}")
             return results
     else:
-        # For
+        # For custom rerankers with a dedicated rerank method
         try:
             return reranker.rerank(query, [doc.page_content for doc in results])
         except Exception as e:
             print(f"Warning: Custom reranking failed with error: {str(e)}")
             return results
-
+
 # Main Comparison Function
 def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
     all_results = []
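Note: the pipeline branch above expects predictions shaped like the usual text-classification output, i.e. one dict with 'label' and 'score' per input, which is what the isinstance(pred, dict) check handles. A minimal sketch of how such a reranker could be exercised; the cross-encoder checkpoint name and the text/text_pair input format are assumptions for illustration, not taken from app.py:

from transformers import pipeline

# Assumed model name; any sequence-classification cross-encoder should behave similarly.
reranker = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "Hauptstadt von Deutschland"
texts = ["Berlin ist die Hauptstadt von Deutschland", "Ein Rezept für Apfelkuchen"]

# Score each (query, document) pair; each prediction is a dict with 'label' and 'score'.
preds = reranker([{"text": query, "text_pair": t} for t in texts])
scores = [p["score"] for p in preds]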
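The else branch only requires an object exposing a rerank(query, texts) method. A self-contained sketch using rerank_results as defined above; SimpleDoc and KeywordOverlapReranker are hypothetical stand-ins (the documents are only assumed to carry a page_content attribute, LangChain-style) and do not exist in app.py:

from dataclasses import dataclass
from typing import List

@dataclass
class SimpleDoc:
    # Hypothetical stand-in for a document object with a page_content attribute
    page_content: str

class KeywordOverlapReranker:
    # Hypothetical custom reranker exposing the rerank() method the function checks for
    def rerank(self, query: str, texts: List[str]) -> List[str]:
        terms = set(query.lower().split())
        # Order texts by how many query terms they share, highest overlap first
        return sorted(texts, key=lambda t: len(terms & set(t.lower().split())), reverse=True)

docs = [SimpleDoc("Ein Rezept für Apfelkuchen"), SimpleDoc("Berlin ist die Hauptstadt von Deutschland")]
reranked = rerank_results(docs, "Hauptstadt von Deutschland", KeywordOverlapReranker())
# reranked holds the raw texts ordered by overlap, with the Berlin sentence first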
|