Update app.py
Browse files
app.py
CHANGED
@@ -347,12 +347,16 @@ def search_embeddings(chunks, embedding_model, vector_store_type, search_type, q
|
|
347 |
results = retriever.invoke(preprocessed_query)
|
348 |
|
349 |
def score_result(doc):
|
350 |
-
|
|
|
|
|
|
|
|
|
351 |
if apply_phonetic:
|
352 |
phonetic_score = phonetic_match(doc.page_content, query)
|
353 |
-
return (1 - phonetic_weight) *
|
354 |
else:
|
355 |
-
return
|
356 |
|
357 |
results = sorted(results, key=score_result, reverse=True)
|
358 |
end_time = time.time()
|
@@ -378,6 +382,7 @@ def search_embeddings(chunks, embedding_model, vector_store_type, search_type, q
|
|
378 |
# Evaluation Metrics
|
379 |
# ... (previous code remains the same)
|
380 |
|
|
|
381 |
def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model, query, top_k, expected_result=None):
|
382 |
stats = {
|
383 |
"num_results": len(results),
|
@@ -385,14 +390,34 @@ def calculate_statistics(results, search_time, vector_store, num_tokens, embeddi
|
|
385 |
"min_content_length": min([len(doc.page_content) for doc in results]) if results else 0,
|
386 |
"max_content_length": max([len(doc.page_content) for doc in results]) if results else 0,
|
387 |
"search_time": search_time,
|
388 |
-
"vector_store_size": vector_store._index.ntotal if hasattr(vector_store, '_index') else "N/A",
|
389 |
-
"num_documents": len(vector_store.docstore._dict),
|
390 |
"num_tokens": num_tokens,
|
391 |
-
"embedding_vocab_size": embedding_model.client.get_vocab_size() if hasattr(embedding_model, 'client') and hasattr(embedding_model.client, 'get_vocab_size') else "N/A",
|
392 |
"embedding_dimension": len(embedding_model.embed_query(query)),
|
393 |
"top_k": top_k,
|
394 |
}
|
395 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
if expected_result:
|
397 |
stats["contains_expected"] = any(expected_result in doc.page_content for doc in results)
|
398 |
stats["expected_result_rank"] = next((i for i, doc in enumerate(results) if expected_result in doc.page_content), -1) + 1
|
@@ -419,35 +444,55 @@ def calculate_statistics(results, search_time, vector_store, num_tokens, embeddi
|
|
419 |
return stats
|
420 |
# Visualization
|
421 |
def visualize_results(results_df, stats_df):
|
|
|
|
|
|
|
|
|
422 |
fig, axs = plt.subplots(2, 2, figsize=(20, 20))
|
423 |
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
axs[0, 0].set_xticklabels(axs[0, 0].get_xticklabels(), rotation=45, ha='right')
|
429 |
|
430 |
-
|
431 |
-
|
|
|
|
|
|
|
|
|
|
|
432 |
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
|
|
|
|
437 |
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
447 |
|
448 |
plt.tight_layout()
|
449 |
return fig
|
450 |
-
|
451 |
def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
|
452 |
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
|
453 |
|
@@ -465,8 +510,15 @@ def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
|
|
465 |
|
466 |
# New postprocessing function
|
467 |
def rerank_results(results, query, reranker):
|
468 |
-
|
469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
|
471 |
# Main Comparison Function
|
472 |
def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
|
|
|
347 |
results = retriever.invoke(preprocessed_query)
|
348 |
|
349 |
def score_result(doc):
    # Relevance from the vector store's own scoring of this chunk.
    # NOTE(review): for FAISS, similarity_search_with_score returns a *distance*
    # (lower = closer) — confirm that sorting these reverse=True is intended.
    base_score = vector_store.similarity_search_with_score(doc.page_content, k=1)[0][1]

    # Fixed boost when the chunk literally contains the expected answer text.
    expected_bonus = 0.3 if expected_result and expected_result in doc.page_content else 0

    # Guard clause: without phonetic matching the base score stands alone.
    if not apply_phonetic:
        return base_score + expected_bonus

    # Blend vector score with a phonetic match score, weighted by phonetic_weight.
    match_score = phonetic_match(doc.page_content, query)
    return (1 - phonetic_weight) * base_score + phonetic_weight * match_score + expected_bonus
|
360 |
|
361 |
results = sorted(results, key=score_result, reverse=True)
|
362 |
end_time = time.time()
|
|
|
382 |
# Evaluation Metrics
|
383 |
# ... (previous code remains the same)
|
384 |
|
385 |
+
def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model, query, top_k, expected_result=None):
|
386 |
|
387 |
stats = {
|
388 |
"num_results": len(results),
|
|
|
390 |
"min_content_length": min([len(doc.page_content) for doc in results]) if results else 0,
|
391 |
"max_content_length": max([len(doc.page_content) for doc in results]) if results else 0,
|
392 |
"search_time": search_time,
|
|
|
|
|
393 |
"num_tokens": num_tokens,
|
|
|
394 |
"embedding_dimension": len(embedding_model.embed_query(query)),
|
395 |
"top_k": top_k,
|
396 |
}
|
397 |
|
398 |
+
# Safely get vector store size
|
399 |
+
try:
|
400 |
+
if hasattr(vector_store, '_index'):
|
401 |
+
stats["vector_store_size"] = vector_store._index.ntotal
|
402 |
+
elif hasattr(vector_store, '_collection'):
|
403 |
+
stats["vector_store_size"] = len(vector_store._collection.get())
|
404 |
+
else:
|
405 |
+
stats["vector_store_size"] = "N/A"
|
406 |
+
except:
|
407 |
+
stats["vector_store_size"] = "N/A"
|
408 |
+
|
409 |
+
# Safely get document count
|
410 |
+
try:
|
411 |
+
if hasattr(vector_store, 'docstore'):
|
412 |
+
stats["num_documents"] = len(vector_store.docstore._dict)
|
413 |
+
elif hasattr(vector_store, '_collection'):
|
414 |
+
stats["num_documents"] = len(vector_store._collection.get())
|
415 |
+
else:
|
416 |
+
stats["num_documents"] = len(results)
|
417 |
+
except:
|
418 |
+
stats["num_documents"] = len(results)
|
419 |
+
|
420 |
+
|
421 |
if expected_result:
|
422 |
stats["contains_expected"] = any(expected_result in doc.page_content for doc in results)
|
423 |
stats["expected_result_rank"] = next((i for i, doc in enumerate(results) if expected_result in doc.page_content), -1) + 1
|
|
|
444 |
return stats
|
445 |
# Visualization
|
446 |
def visualize_results(results_df, stats_df):
    """Render a 2x2 dashboard of metrics comparing embedding-model search runs.

    Args:
        results_df: Per-result dataframe; assumed to carry an 'embedding' column
            (vector per row) and a 'Model' label column — TODO confirm schema
            against the caller.
        stats_df: Per-run statistics dataframe with 'search_time',
            'result_diversity', 'rank_correlation', 'avg_content_length', and
            either a 'model' column or 'model_type'/'model_name' to derive it.

    Returns:
        The matplotlib Figure containing the four panels (some panels may be
        left empty if their data is missing — each plot is best-effort).
    """
    # Derive a combined model label when the caller didn't provide one.
    if 'model' not in stats_df.columns:
        stats_df['model'] = stats_df['model_type'] + ' - ' + stats_df['model_name']

    fig, axs = plt.subplots(2, 2, figsize=(20, 20))

    # Nothing to plot for an empty stats frame; return the blank figure.
    if len(stats_df) == 0:
        return fig

    # Each panel is wrapped individually so one bad/missing column doesn't
    # blank out the whole dashboard.
    try:
        sns.barplot(data=stats_df, x='model', y='search_time', ax=axs[0, 0])
        axs[0, 0].set_title('Search Time by Model')
        axs[0, 0].tick_params(axis='x', rotation=45)
    except Exception as e:
        print(f"Error in search time plot: {e}")

    try:
        sns.scatterplot(data=stats_df, x='result_diversity', y='rank_correlation',
                        hue='model', ax=axs[0, 1])
        axs[0, 1].set_title('Result Diversity vs. Rank Correlation')
    except Exception as e:
        print(f"Error in diversity plot: {e}")

    try:
        sns.boxplot(data=stats_df, x='model', y='avg_content_length', ax=axs[1, 0])
        axs[1, 0].set_title('Distribution of Result Content Lengths')
        axs[1, 0].tick_params(axis='x', rotation=45)
    except Exception as e:
        print(f"Error in content length plot: {e}")

    try:
        valid_embeddings = results_df['embedding'].dropna()
        if len(valid_embeddings) > 1:
            # TSNE requires perplexity < n_samples; cap it so small result
            # sets don't raise.
            tsne = TSNE(n_components=2, random_state=42,
                        perplexity=min(30, len(valid_embeddings) - 1))
            embeddings_2d = tsne.fit_transform(np.vstack(valid_embeddings.values))
            # Select hue labels by the surviving index, not a positional
            # slice — after dropna() the first N rows of 'Model' need not
            # correspond to the N kept embeddings.
            sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1],
                            hue=results_df['Model'].loc[valid_embeddings.index],
                            ax=axs[1, 1])
            axs[1, 1].set_title('t-SNE Visualization of Result Embeddings')
        else:
            axs[1, 1].text(0.5, 0.5, "Not enough embeddings for visualization",
                           ha='center', va='center')
    except Exception as e:
        print(f"Error in embedding visualization: {e}")

    plt.tight_layout()
    return fig
|
|
|
496 |
def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
|
497 |
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
|
498 |
|
|
|
510 |
|
511 |
# New postprocessing function
|
512 |
def rerank_results(results, query, reranker):
    """Re-order retrieved documents by cross-encoder relevance to the query.

    Args:
        results: List of retrieved documents (objects with .page_content).
        query: The query string to score documents against.
        reranker: Either an object exposing a ``rerank(query, texts)`` method,
            or a callable HF-style TextClassificationPipeline that scores
            (query, passage) pairs.

    Returns:
        The reranked results. The ``rerank`` branch returns whatever the
        model's native API returns; the pipeline branch returns the input
        documents sorted by descending score.
    """
    # Nothing to rank — avoid calling the model on an empty batch.
    if not results:
        return results

    if hasattr(reranker, 'rerank'):
        # Models with a native rerank API.
        return reranker.rerank(query, [doc.page_content for doc in results])

    # Fallback: score (query, passage) pairs with a TextClassificationPipeline.
    pairs = [[query, doc.page_content] for doc in results]
    # 'sigmoid' maps single-logit cross-encoder outputs into [0, 1].
    # ('cross_entropy' is not a valid function_to_apply value and raises.)
    scores = [pred['score'] for pred in reranker(pairs, function_to_apply='sigmoid')]
    ranked_idx = np.argsort(scores)[::-1]
    return [results[i] for i in ranked_idx]
|
522 |
|
523 |
# Main Comparison Function
|
524 |
def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
|