Update app.py
Browse files
app.py
CHANGED
@@ -531,6 +531,111 @@ def format_results(results, stats):
|
|
531 |
formatted_results.append(result)
|
532 |
return formatted_results
|
533 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
534 |
# Gradio Interface
|
535 |
def launch_interface(share=True):
|
536 |
with gr.Blocks() as iface:
|
@@ -592,6 +697,51 @@ def launch_interface(share=True):
|
|
592 |
outputs=[results_output, stats_output, plot_output]
|
593 |
)
|
594 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
595 |
|
596 |
tutorial_md = """
|
597 |
# Advanced Embedding Comparison Tool Tutorial
|
@@ -618,5 +768,33 @@ def launch_interface(share=True):
|
|
618 |
|
619 |
iface.launch(share=share)
|
620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
621 |
if __name__ == "__main__":
|
622 |
launch_interface()
|
|
|
531 |
formatted_results.append(result)
|
532 |
return formatted_results
|
533 |
|
534 |
+
|
535 |
+
#####
|
536 |
+
from sklearn.model_selection import ParameterGrid
|
537 |
+
from tqdm import tqdm
|
538 |
+
|
539 |
+
# ... (previous code remains the same)
|
540 |
+
|
541 |
+
# New function for automated testing
|
542 |
+
def automated_testing(file, query, test_params):
    """Run the embed/search pipeline once for every parameter combination.

    Args:
        file: Uploaded file object exposing a ``name`` attribute, or a falsy
            value, in which case ``None`` is passed to ``process_files``.
        query: Search query text. Every combination starts from this exact
            text, whether or not query optimization is enabled.
        test_params: Mapping of parameter name to the list of values to try;
            expanded into individual combinations with ``ParameterGrid``.

    Returns:
        A ``(results_df, stats_df)`` tuple of DataFrames: the formatted
        results of all runs concatenated, and one statistics row per
        parameter combination.
    """
    all_results = []
    all_stats = []

    # Created on first use and shared by all runs: the reranking model name is
    # a constant, so rebuilding the pipeline per combination is wasted work.
    reranker = None

    for params in tqdm(ParameterGrid(test_params), desc="Running tests"):
        chunks, embedding_model, num_tokens = process_files(
            file.name if file else None,
            params['model_type'],
            params['model_name'],
            params['split_strategy'],
            params['chunk_size'],
            params['overlap_size'],
            params.get('custom_separators', None),
            params['lang'],
            params['apply_preprocessing'],
            params.get('custom_tokenizer_file', None),
            params.get('custom_tokenizer_model', None),
            params.get('custom_tokenizer_vocab_size', 10000),
            params.get('custom_tokenizer_special_tokens', None)
        )

        if params['optimize_vocab']:
            # Only the rewritten chunks are needed; the tokenizer is discarded.
            _, chunks = optimize_vocabulary(chunks)

        # Keep the caller's query untouched. Rebinding `query` here would feed
        # an already-optimized query into every later combination.
        run_query = query
        if params['use_query_optimization']:
            optimized_queries = optimize_query(query, params['query_optimization_model'])
            run_query = " ".join(optimized_queries)

        results, search_time, vector_store, results_raw = search_embeddings(
            chunks,
            embedding_model,
            params['vector_store_type'],
            params['search_type'],
            run_query,
            params['top_k'],
            params['lang'],
            params['apply_phonetic'],
            params['phonetic_weight']
        )

        if params['use_reranking']:
            if reranker is None:
                reranker = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-12-v2")
            results_raw = rerank_results(results_raw, run_query, reranker)

        stats = calculate_statistics(
            results_raw, search_time, vector_store, num_tokens,
            embedding_model, run_query, params['top_k']
        )
        stats["model"] = f"{params['model_type']} - {params['model_name']}"
        # Record the tested settings next to the metrics so each statistics
        # row identifies the configuration that produced it.
        stats.update(params)

        all_results.extend(format_results(results_raw, stats))
        all_stats.append(stats)

    return pd.DataFrame(all_results), pd.DataFrame(all_stats)
|
597 |
+
|
598 |
+
# Function to analyze results and propose best model and settings
|
599 |
+
def _to_native(value):
    """Convert a NumPy-style scalar to the equivalent built-in Python value."""
    # NumPy scalars expose ``item()`` and have no length; strings, lists and
    # plain Python scalars are returned unchanged.
    if hasattr(value, 'item') and not hasattr(value, '__len__'):
        return value.item()
    return value


def analyze_results(stats_df):
    """Pick the best-scoring configuration and summarise it.

    Each row of ``stats_df`` is scored as a weighted sum of its metrics on
    their raw scales. The score is written back to ``stats_df`` as a
    ``weighted_score`` column (the frame is modified in place). Metric values
    that are missing or not numeric contribute 0 to the score, and a metric
    column that is absent altogether is skipped.

    Args:
        stats_df: DataFrame with one row per tested configuration, holding
            the configuration columns read below and the metric columns
            named in ``metric_weights``.

    Returns:
        A dict with ``best_model``, ``best_settings`` and
        ``performance_summary`` for the highest-scoring row. Values are
        converted to built-in Python scalars where possible so the result can
        be serialised by the standard ``json`` module.

    Raises:
        ValueError: If ``stats_df`` is ``None`` or has no rows.
        KeyError: If a configuration column is missing.
    """
    if stats_df is None or stats_df.empty:
        raise ValueError("analyze_results requires at least one row of statistics")

    # Define weights for different metrics (adjust as needed)
    metric_weights = {
        'search_time': -0.3,  # Lower is better
        'result_diversity': 0.2,
        'rank_correlation': 0.3,
        'silhouette_score': 0.2
    }

    # Calculate weighted score for each configuration
    weighted_score = pd.Series(0.0, index=stats_df.index)
    for metric, weight in metric_weights.items():
        if metric not in stats_df.columns:
            continue
        # Coerce so that None / NaN / non-numeric entries do not turn the
        # whole row's score into NaN.
        values = pd.to_numeric(stats_df[metric], errors='coerce').fillna(0.0)
        weighted_score = weighted_score + values * weight
    stats_df['weighted_score'] = weighted_score

    # Get the best configuration. Positional lookup keeps this a single row
    # even when the index contains duplicate labels; ties go to the first row.
    best_position = int(weighted_score.to_numpy().argmax())
    best_config = stats_df.iloc[best_position]

    setting_keys = (
        'split_strategy',
        'chunk_size',
        'overlap_size',
        'vector_store_type',
        'search_type',
        'top_k',
        'optimize_vocab',
        'use_query_optimization',
        'use_reranking',
    )

    # Generate recommendations
    recommendations = {
        'best_model': f"{best_config['model_type']} - {best_config['model_name']}",
        'best_settings': {key: _to_native(best_config[key]) for key in setting_keys},
        'performance_summary': {
            # ``get`` yields None for a metric column that is not present.
            metric: _to_native(best_config.get(metric)) for metric in metric_weights
        }
    }

    return recommendations
|
637 |
+
####
|
638 |
+
|
639 |
# Gradio Interface
|
640 |
def launch_interface(share=True):
|
641 |
with gr.Blocks() as iface:
|
|
|
697 |
outputs=[results_output, stats_output, plot_output]
|
698 |
)
|
699 |
|
700 |
+
####
|
701 |
+
with gr.Tab("Automated"):
|
702 |
+
auto_file_input = gr.File(label="Upload File (Optional)")
|
703 |
+
auto_query_input = gr.Textbox(label="Search Query")
|
704 |
+
auto_model_types = gr.CheckboxGroup(
|
705 |
+
choices=["HuggingFace", "OpenAI", "Cohere"],
|
706 |
+
label="Model Types to Test"
|
707 |
+
)
|
708 |
+
auto_model_names = gr.TextArea(label="Model Names to Test (comma-separated)")
|
709 |
+
auto_split_strategies = gr.CheckboxGroup(
|
710 |
+
choices=["token", "recursive"],
|
711 |
+
label="Split Strategies to Test"
|
712 |
+
)
|
713 |
+
auto_chunk_sizes = gr.TextArea(label="Chunk Sizes to Test (comma-separated)")
|
714 |
+
auto_overlap_sizes = gr.TextArea(label="Overlap Sizes to Test (comma-separated)")
|
715 |
+
auto_vector_store_types = gr.CheckboxGroup(
|
716 |
+
choices=["FAISS", "Chroma"],
|
717 |
+
label="Vector Store Types to Test"
|
718 |
+
)
|
719 |
+
auto_search_types = gr.CheckboxGroup(
|
720 |
+
choices=["similarity", "mmr", "custom"],
|
721 |
+
label="Search Types to Test"
|
722 |
+
)
|
723 |
+
auto_top_k = gr.TextArea(label="Top K Values to Test (comma-separated)")
|
724 |
+
auto_optimize_vocab = gr.Checkbox(label="Test Vocabulary Optimization", value=True)
|
725 |
+
auto_use_query_optimization = gr.Checkbox(label="Test Query Optimization", value=True)
|
726 |
+
auto_use_reranking = gr.Checkbox(label="Test Reranking", value=True)
|
727 |
+
|
728 |
+
auto_results_output = gr.Dataframe(label="Automated Test Results", interactive=False)
|
729 |
+
auto_stats_output = gr.Dataframe(label="Automated Test Statistics", interactive=False)
|
730 |
+
recommendations_output = gr.JSON(label="Recommendations")
|
731 |
+
|
732 |
+
auto_submit_button = gr.Button("Run Automated Tests")
|
733 |
+
auto_submit_button.click(
|
734 |
+
fn=lambda *args: run_automated_tests_and_analyze(*args),
|
735 |
+
inputs=[
|
736 |
+
auto_file_input, auto_query_input, auto_model_types, auto_model_names,
|
737 |
+
auto_split_strategies, auto_chunk_sizes, auto_overlap_sizes,
|
738 |
+
auto_vector_store_types, auto_search_types, auto_top_k,
|
739 |
+
auto_optimize_vocab, auto_use_query_optimization, auto_use_reranking
|
740 |
+
],
|
741 |
+
outputs=[auto_results_output, auto_stats_output, recommendations_output]
|
742 |
+
)
|
743 |
+
###
|
744 |
+
|
745 |
|
746 |
tutorial_md = """
|
747 |
# Advanced Embedding Comparison Tool Tutorial
|
|
|
768 |
|
769 |
iface.launch(share=share)
|
770 |
|
771 |
+
def _parse_csv_values(raw_text, field_name, cast=str):
    """Split comma-separated text into a non-empty list of converted values."""
    # Blank tokens (empty box, doubled or trailing commas) are dropped rather
    # than passed to ``cast``, where e.g. ``int('')`` would fail.
    tokens = [token.strip() for token in (raw_text or "").split(',')]
    tokens = [token for token in tokens if token]
    if not tokens:
        raise ValueError(f"No values provided for '{field_name}'")
    try:
        return [cast(token) for token in tokens]
    except ValueError as err:
        raise ValueError(f"Invalid value for '{field_name}': {err}") from err


def _require_selection(selected, field_name):
    """Return the selected options as a list, rejecting an empty selection."""
    options = list(selected or [])
    if not options:
        raise ValueError(f"No values selected for '{field_name}'")
    return options


def run_automated_tests_and_analyze(*args):
    """Build the parameter grid from the UI inputs, run all tests and analyse them.

    Args:
        *args: Exactly thirteen positional values, in this order: file, query,
            model types, model names (comma-separated text), split strategies,
            chunk sizes (comma-separated integers), overlap sizes
            (comma-separated integers), vector store types, search types,
            top-k values (comma-separated integers), optimize-vocab flag,
            query-optimization flag, reranking flag.

    Returns:
        A ``(results_df, stats_df, recommendations)`` tuple: the two values
        returned by ``automated_testing`` and the result of
        ``analyze_results`` on the statistics.

    Raises:
        ValueError: If the number of arguments is wrong, a list input is
            empty, or a numeric field contains a non-integer value.
    """
    (file, query, model_types, model_names, split_strategies, chunk_sizes,
     overlap_sizes, vector_store_types, search_types, top_k_values,
     optimize_vocab, use_query_optimization, use_reranking) = args

    # Validate everything up front so bad input fails with a message naming
    # the field, before any test run starts.
    test_params = {
        'model_type': _require_selection(model_types, 'model_type'),
        'model_name': _parse_csv_values(model_names, 'model_name'),
        'split_strategy': _require_selection(split_strategies, 'split_strategy'),
        'chunk_size': _parse_csv_values(chunk_sizes, 'chunk_size', int),
        'overlap_size': _parse_csv_values(overlap_sizes, 'overlap_size', int),
        'vector_store_type': _require_selection(vector_store_types, 'vector_store_type'),
        'search_type': _require_selection(search_types, 'search_type'),
        'top_k': _parse_csv_values(top_k_values, 'top_k', int),
        'lang': ['german'],  # You can expand this if needed
        'apply_preprocessing': [True],
        'optimize_vocab': [optimize_vocab],
        'apply_phonetic': [True],
        'phonetic_weight': [0.3],
        'use_query_optimization': [use_query_optimization],
        'query_optimization_model': ['google/flan-t5-base'],
        'use_reranking': [use_reranking]
    }

    results_df, stats_df = automated_testing(file, query, test_params)
    recommendations = analyze_results(stats_df)

    return results_df, stats_df, recommendations
|
798 |
+
|
799 |
if __name__ == "__main__":
|
800 |
launch_interface()
|