Chris4K commited on
Commit
4b5f1bf
·
verified ·
1 Parent(s): a9006e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -510,6 +510,10 @@ import numpy as np
510
  from transformers import TextClassificationPipeline
511
  from typing import List, Union, Any
512
 
 
 
 
 
513
  def rerank_results(
514
  results: List[Any],
515
  query: str,
@@ -522,23 +526,24 @@ def rerank_results(
522
  results: List of documents/results to rerank
523
  query: Search query string
524
  reranker: Either a HuggingFace TextClassificationPipeline or a custom reranker
525
- with a rerank() method
526
 
527
  Returns:
528
  List of reranked results
529
  """
530
  if not results:
531
  return results
532
-
533
  if not hasattr(reranker, 'rerank'):
534
  # For TextClassificationPipeline
535
  try:
 
536
  pairs = [[query, doc.page_content] for doc in results]
537
 
538
- # Standard classification without specific function
539
  predictions = reranker(pairs)
540
 
541
- # Extract scores, defaulting to 'score' key but falling back to other common keys
542
  scores = []
543
  for pred in predictions:
544
  if isinstance(pred, dict):
@@ -549,21 +554,23 @@ def rerank_results(
549
  score = float(pred)
550
  scores.append(score)
551
 
552
- # Sort in descending order (higher scores = better matches)
553
  reranked_idx = np.argsort(scores)[::-1]
 
 
554
  return [results[i] for i in reranked_idx]
555
 
556
  except Exception as e:
557
  print(f"Warning: Reranking failed with error: {str(e)}")
558
  return results
559
  else:
560
- # For models with dedicated rerank method
561
  try:
562
  return reranker.rerank(query, [doc.page_content for doc in results])
563
  except Exception as e:
564
  print(f"Warning: Custom reranking failed with error: {str(e)}")
565
  return results
566
-
567
  # Main Comparison Function
568
  def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
569
  all_results = []
 
510
  from transformers import TextClassificationPipeline
511
  from typing import List, Union, Any
512
 
513
+ import numpy as np
514
+ from transformers import pipeline, TextClassificationPipeline
515
+ from typing import List, Any, Union
516
+
517
  def rerank_results(
518
  results: List[Any],
519
  query: str,
 
526
  results: List of documents/results to rerank
527
  query: Search query string
528
  reranker: Either a HuggingFace TextClassificationPipeline or a custom reranker
529
+ with a rerank() method.
530
 
531
  Returns:
532
  List of reranked results
533
  """
534
  if not results:
535
  return results
536
+
537
  if not hasattr(reranker, 'rerank'):
538
  # For TextClassificationPipeline
539
  try:
540
+ # Create pairs of query and document content
541
  pairs = [[query, doc.page_content] for doc in results]
542
 
543
+ # Get predictions from the reranker pipeline
544
  predictions = reranker(pairs)
545
 
546
+ # Extract scores with proper fallback options
547
  scores = []
548
  for pred in predictions:
549
  if isinstance(pred, dict):
 
554
  score = float(pred)
555
  scores.append(score)
556
 
557
+ # Sort the results based on scores in descending order
558
  reranked_idx = np.argsort(scores)[::-1]
559
+
560
+ # Return reranked results based on the sorted indices
561
  return [results[i] for i in reranked_idx]
562
 
563
  except Exception as e:
564
  print(f"Warning: Reranking failed with error: {str(e)}")
565
  return results
566
  else:
567
+ # For custom rerankers with a dedicated rerank method
568
  try:
569
  return reranker.rerank(query, [doc.page_content for doc in results])
570
  except Exception as e:
571
  print(f"Warning: Custom reranking failed with error: {str(e)}")
572
  return results
573
+
574
  # Main Comparison Function
575
  def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
576
  all_results = []