Chris4K commited on
Commit
2c85855
·
verified ·
1 Parent(s): 9a00e93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -10
app.py CHANGED
@@ -505,18 +505,64 @@ def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
505
 
506
  return tokenizer, optimized_texts
507
 
508
- # New postprocessing function
509
- def rerank_results(results, query, reranker):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  if not hasattr(reranker, 'rerank'):
511
  # For TextClassificationPipeline
512
- pairs = [[query, doc.page_content] for doc in results]
513
- scores = [pred['score'] for pred in reranker(pairs, function_to_apply='cross_entropy')]
514
- reranked_idx = np.argsort(scores)[::-1]
515
- return [results[i] for i in reranked_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  else:
517
- # For models with rerank method
518
- return reranker.rerank(query, [doc.page_content for doc in results])
519
-
 
 
 
 
520
  # Main Comparison Function
521
  def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
522
  all_results = []
@@ -660,7 +706,9 @@ def automated_testing(file, query, test_params, expected_result=None):
660
  chunks = optimized_chunks
661
 
662
  if params['use_query_optimization']:
663
- optimized_queries = optimize_query(query, params['query_optimization_model'])
 
 
664
  query = " ".join(optimized_queries)
665
 
666
  results, search_time, vector_store, results_raw = search_embeddings(
 
505
 
506
  return tokenizer, optimized_texts
507
 
508
+ import numpy as np
509
+ from transformers import TextClassificationPipeline
510
+ from typing import List, Union, Any
511
+
512
+ def rerank_results(
513
+ results: List[Any],
514
+ query: str,
515
+ reranker: Union[TextClassificationPipeline, Any]
516
+ ) -> List[Any]:
517
+ """
518
+ Rerank search results using either a TextClassificationPipeline or a custom reranker.
519
+
520
+ Args:
521
+ results: List of documents/results to rerank
522
+ query: Search query string
523
+ reranker: Either a HuggingFace TextClassificationPipeline or a custom reranker
524
+ with a rerank() method
525
+
526
+ Returns:
527
+ List of reranked results
528
+ """
529
+ if not results:
530
+ return results
531
+
532
  if not hasattr(reranker, 'rerank'):
533
  # For TextClassificationPipeline
534
+ try:
535
+ pairs = [[query, doc.page_content] for doc in results]
536
+
537
+ # Standard classification without specific function
538
+ predictions = reranker(pairs)
539
+
540
+ # Extract scores, defaulting to 'score' key but falling back to other common keys
541
+ scores = []
542
+ for pred in predictions:
543
+ if isinstance(pred, dict):
544
+ score = pred.get('score',
545
+ pred.get('probability',
546
+ pred.get('confidence', 0.0)))
547
+ else:
548
+ score = float(pred)
549
+ scores.append(score)
550
+
551
+ # Sort in descending order (higher scores = better matches)
552
+ reranked_idx = np.argsort(scores)[::-1]
553
+ return [results[i] for i in reranked_idx]
554
+
555
+ except Exception as e:
556
+ print(f"Warning: Reranking failed with error: {str(e)}")
557
+ return results
558
  else:
559
+ # For models with dedicated rerank method
560
+ try:
561
+ return reranker.rerank(query, [doc.page_content for doc in results])
562
+ except Exception as e:
563
+ print(f"Warning: Custom reranking failed with error: {str(e)}")
564
+ return results
565
+
566
  # Main Comparison Function
567
  def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
568
  all_results = []
 
706
  chunks = optimized_chunks
707
 
708
  if params['use_query_optimization']:
709
+ optimized_queries = optimize_query(query, params['query_optimization_model'], params['chunks'] , params['embedding_model'] , params['vector_store_type'] , params['search_type'] , params['top_k'] )
710
+
711
+ #optimized_queries = optimize_query(query, )
712
  query = " ".join(optimized_queries)
713
 
714
  results, search_time, vector_store, results_raw = search_embeddings(