Chris4K committed on
Commit
6fd2acf
1 Parent(s): 60de941

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -0
app.py CHANGED
@@ -531,6 +531,111 @@ def format_results(results, stats):
531
  formatted_results.append(result)
532
  return formatted_results
533
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
  # Gradio Interface
535
  def launch_interface(share=True):
536
  with gr.Blocks() as iface:
@@ -592,6 +697,51 @@ def launch_interface(share=True):
592
  outputs=[results_output, stats_output, plot_output]
593
  )
594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
595
 
596
  tutorial_md = """
597
  # Advanced Embedding Comparison Tool Tutorial
@@ -618,5 +768,33 @@ def launch_interface(share=True):
618
 
619
  iface.launch(share=share)
620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  if __name__ == "__main__":
622
  launch_interface()
 
531
  formatted_results.append(result)
532
  return formatted_results
533
 
534
+
535
+ #####
536
+ from sklearn.model_selection import ParameterGrid
537
+ from tqdm import tqdm
538
+
539
+ # ... (previous code remains the same)
540
+
541
+ # New function for automated testing
542
def automated_testing(file, query, test_params):
    """Run the retrieval pipeline once for every parameter combination.

    Args:
        file: Uploaded file object exposing a ``name`` attribute, or a falsy
            value when no file was supplied (``None`` is then passed on to
            ``process_files``).
        query: Search query text. Every combination starts from this exact
            string; it is never modified between runs.
        test_params: Mapping of parameter name -> list of candidate values,
            expanded into all combinations with ``ParameterGrid``.

    Returns:
        A tuple ``(results_df, stats_df)`` of pandas DataFrames: the formatted
        results of all runs concatenated, and one statistics row per
        parameter combination (statistics merged with the parameters used).
    """
    all_results = []
    all_stats = []

    # The path does not depend on the parameter combination, so resolve it once.
    file_path = file.name if file else None

    # The reranker model is identical for every combination; build it lazily
    # on first use and reuse it instead of reloading it per iteration.
    reranker = None

    for params in tqdm(ParameterGrid(test_params), desc="Running tests"):
        chunks, embedding_model, num_tokens = process_files(
            file_path,
            params['model_type'],
            params['model_name'],
            params['split_strategy'],
            params['chunk_size'],
            params['overlap_size'],
            params.get('custom_separators', None),
            params['lang'],
            params['apply_preprocessing'],
            params.get('custom_tokenizer_file', None),
            params.get('custom_tokenizer_model', None),
            params.get('custom_tokenizer_vocab_size', 10000),
            params.get('custom_tokenizer_special_tokens', None)
        )

        if params['optimize_vocab']:
            # Only the rewritten chunks are used; the tokenizer is discarded.
            _tokenizer, chunks = optimize_vocabulary(chunks)

        # Use a per-run copy of the query. Rebinding ``query`` itself would
        # feed the optimized text of one run into the optimizer of the next,
        # so later combinations would be tested against a different query.
        run_query = query
        if params['use_query_optimization']:
            optimized_queries = optimize_query(query, params['query_optimization_model'])
            run_query = " ".join(optimized_queries)

        results, search_time, vector_store, results_raw = search_embeddings(
            chunks,
            embedding_model,
            params['vector_store_type'],
            params['search_type'],
            run_query,
            params['top_k'],
            params['lang'],
            params['apply_phonetic'],
            params['phonetic_weight']
        )

        if params['use_reranking']:
            if reranker is None:
                reranker = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-12-v2")
            results_raw = rerank_results(results_raw, run_query, reranker)

        stats = calculate_statistics(
            results_raw, search_time, vector_store, num_tokens,
            embedding_model, run_query, params['top_k']
        )
        stats["model"] = f"{params['model_type']} - {params['model_name']}"
        # Record the exact settings next to the metrics they produced.
        stats.update(params)

        all_results.extend(format_results(results_raw, stats))
        all_stats.append(stats)

    return pd.DataFrame(all_results), pd.DataFrame(all_stats)
597
+
598
+ # Function to analyze results and propose best model and settings
599
def analyze_results(stats_df):
    """Score every tested configuration and recommend the best one.

    Each row of ``stats_df`` gets a ``weighted_score`` equal to the weighted
    sum of the metrics listed in ``metric_weights``. The row with the highest
    score is reported.

    Args:
        stats_df: pandas DataFrame with one row per tested configuration.
            A ``weighted_score`` column is added to (or overwritten in) this
            frame as a side effect.

    Returns:
        dict with keys ``best_model`` (str), ``best_settings`` (dict) and
        ``performance_summary`` (dict). Values are plain Python scalars so the
        result can be serialised with the standard ``json`` module; fields
        that are absent from ``stats_df`` are reported as ``None``.

    Raises:
        ValueError: if ``stats_df`` has no rows.
    """
    # Define weights for different metrics (adjust as needed)
    metric_weights = {
        'search_time': -0.3,  # Lower is better
        'result_diversity': 0.2,
        'rank_correlation': 0.3,
        'silhouette_score': 0.2
    }

    if stats_df.empty:
        raise ValueError("No test statistics available to analyze.")

    # Accumulate the score metric by metric. A metric column that is missing
    # is skipped, and a non-numeric/NaN entry counts as 0, so one undefined
    # metric does not knock an otherwise valid configuration out of ranking.
    weighted_score = pd.Series(0.0, index=stats_df.index)
    for metric, weight in metric_weights.items():
        if metric not in stats_df.columns:
            continue
        values = pd.to_numeric(stats_df[metric], errors="coerce").fillna(0.0)
        weighted_score = weighted_score + values * weight
    stats_df['weighted_score'] = weighted_score

    # Select by position rather than by label so a frame with repeated index
    # labels still yields exactly one row.
    best_config = stats_df.iloc[int(weighted_score.argmax())]

    def _native(key):
        """Return best_config[key] as a built-in Python scalar (None if absent)."""
        value = best_config.get(key)
        # numpy scalars expose .item(); built-in str/int/float/bool do not.
        return value.item() if hasattr(value, "item") else value

    setting_keys = (
        'split_strategy', 'chunk_size', 'overlap_size', 'vector_store_type',
        'search_type', 'top_k', 'optimize_vocab', 'use_query_optimization',
        'use_reranking',
    )

    recommendations = {
        'best_model': f"{_native('model_type')} - {_native('model_name')}",
        'best_settings': {key: _native(key) for key in setting_keys},
        'performance_summary': {metric: _native(metric) for metric in metric_weights},
    }

    return recommendations
637
+ ####
638
+
639
  # Gradio Interface
640
  def launch_interface(share=True):
641
  with gr.Blocks() as iface:
 
697
  outputs=[results_output, stats_output, plot_output]
698
  )
699
 
700
+ ####
701
+ with gr.Tab("Automated"):
702
+ auto_file_input = gr.File(label="Upload File (Optional)")
703
+ auto_query_input = gr.Textbox(label="Search Query")
704
+ auto_model_types = gr.CheckboxGroup(
705
+ choices=["HuggingFace", "OpenAI", "Cohere"],
706
+ label="Model Types to Test"
707
+ )
708
+ auto_model_names = gr.TextArea(label="Model Names to Test (comma-separated)")
709
+ auto_split_strategies = gr.CheckboxGroup(
710
+ choices=["token", "recursive"],
711
+ label="Split Strategies to Test"
712
+ )
713
+ auto_chunk_sizes = gr.TextArea(label="Chunk Sizes to Test (comma-separated)")
714
+ auto_overlap_sizes = gr.TextArea(label="Overlap Sizes to Test (comma-separated)")
715
+ auto_vector_store_types = gr.CheckboxGroup(
716
+ choices=["FAISS", "Chroma"],
717
+ label="Vector Store Types to Test"
718
+ )
719
+ auto_search_types = gr.CheckboxGroup(
720
+ choices=["similarity", "mmr", "custom"],
721
+ label="Search Types to Test"
722
+ )
723
+ auto_top_k = gr.TextArea(label="Top K Values to Test (comma-separated)")
724
+ auto_optimize_vocab = gr.Checkbox(label="Test Vocabulary Optimization", value=True)
725
+ auto_use_query_optimization = gr.Checkbox(label="Test Query Optimization", value=True)
726
+ auto_use_reranking = gr.Checkbox(label="Test Reranking", value=True)
727
+
728
+ auto_results_output = gr.Dataframe(label="Automated Test Results", interactive=False)
729
+ auto_stats_output = gr.Dataframe(label="Automated Test Statistics", interactive=False)
730
+ recommendations_output = gr.JSON(label="Recommendations")
731
+
732
+ auto_submit_button = gr.Button("Run Automated Tests")
733
+ auto_submit_button.click(
734
+ fn=lambda *args: run_automated_tests_and_analyze(*args),
735
+ inputs=[
736
+ auto_file_input, auto_query_input, auto_model_types, auto_model_names,
737
+ auto_split_strategies, auto_chunk_sizes, auto_overlap_sizes,
738
+ auto_vector_store_types, auto_search_types, auto_top_k,
739
+ auto_optimize_vocab, auto_use_query_optimization, auto_use_reranking
740
+ ],
741
+ outputs=[auto_results_output, auto_stats_output, recommendations_output]
742
+ )
743
+ ###
744
+
745
 
746
  tutorial_md = """
747
  # Advanced Embedding Comparison Tool Tutorial
 
768
 
769
  iface.launch(share=share)
770
 
771
def _split_csv(text):
    """Split a comma-separated string into stripped, non-empty tokens."""
    return [token.strip() for token in (text or "").split(',') if token.strip()]


def _parse_int_list(text, label):
    """Parse a comma-separated string of integers, naming the field on failure."""
    try:
        return [int(token) for token in _split_csv(text)]
    except ValueError as err:
        raise ValueError(
            f"{label} must be a comma-separated list of integers, got {text!r}"
        ) from err


def run_automated_tests_and_analyze(*args):
    """Build the parameter grid from the UI inputs, run all tests and analyze them.

    Args:
        *args: Exactly 13 positional values, in this order: file, query,
            model_types, model_names, split_strategies, chunk_sizes,
            overlap_sizes, vector_store_types, search_types, top_k_values,
            optimize_vocab, use_query_optimization, use_reranking.
            ``model_names``, ``chunk_sizes``, ``overlap_sizes`` and
            ``top_k_values`` are comma-separated strings; blank entries and
            trailing commas are ignored.

    Returns:
        Tuple ``(results_df, stats_df, recommendations)`` as produced by
        ``automated_testing`` and ``analyze_results``.

    Raises:
        ValueError: if a numeric field contains a non-integer token, or if
            any parameter ends up with no values to test.
    """
    file, query, model_types, model_names, split_strategies, chunk_sizes, overlap_sizes, \
        vector_store_types, search_types, top_k_values, optimize_vocab, use_query_optimization, use_reranking = args

    test_params = {
        'model_type': list(model_types or []),
        'model_name': _split_csv(model_names),
        'split_strategy': list(split_strategies or []),
        'chunk_size': _parse_int_list(chunk_sizes, 'chunk_size'),
        'overlap_size': _parse_int_list(overlap_sizes, 'overlap_size'),
        'vector_store_type': list(vector_store_types or []),
        'search_type': list(search_types or []),
        'top_k': _parse_int_list(top_k_values, 'top_k'),
        'lang': ['german'],  # You can expand this if needed
        'apply_preprocessing': [True],
        'optimize_vocab': [optimize_vocab],
        'apply_phonetic': [True],
        'phonetic_weight': [0.3],
        'use_query_optimization': [use_query_optimization],
        'query_optimization_model': ['google/flan-t5-base'],
        'use_reranking': [use_reranking]
    }

    # A parameter with no candidate values makes the grid empty (or invalid),
    # which would otherwise fail later with an unrelated error.
    missing = [name for name, values in test_params.items() if not values]
    if missing:
        raise ValueError("No values provided for: " + ", ".join(missing))

    results_df, stats_df = automated_testing(file, query, test_params)
    recommendations = analyze_results(stats_df)

    return results_df, stats_df, recommendations
798
+
799
  if __name__ == "__main__":
800
  launch_interface()