Chris4K committed on
Commit
ebdeeac
1 Parent(s): 2bd19e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -29
app.py CHANGED
@@ -347,12 +347,16 @@ def search_embeddings(chunks, embedding_model, vector_store_type, search_type, q
347
  results = retriever.invoke(preprocessed_query)
348
 
349
  def score_result(doc):
350
- similarity_score = vector_store.similarity_search_with_score(doc.page_content, k=1)[0][1]
 
 
 
 
351
  if apply_phonetic:
352
  phonetic_score = phonetic_match(doc.page_content, query)
353
- return (1 - phonetic_weight) * similarity_score + phonetic_weight * phonetic_score
354
  else:
355
- return similarity_score
356
 
357
  results = sorted(results, key=score_result, reverse=True)
358
  end_time = time.time()
@@ -378,6 +382,7 @@ def search_embeddings(chunks, embedding_model, vector_store_type, search_type, q
378
  # Evaluation Metrics
379
  # ... (previous code remains the same)
380
 
 
381
  def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model, query, top_k, expected_result=None):
382
  stats = {
383
  "num_results": len(results),
@@ -385,14 +390,34 @@ def calculate_statistics(results, search_time, vector_store, num_tokens, embeddi
385
  "min_content_length": min([len(doc.page_content) for doc in results]) if results else 0,
386
  "max_content_length": max([len(doc.page_content) for doc in results]) if results else 0,
387
  "search_time": search_time,
388
- "vector_store_size": vector_store._index.ntotal if hasattr(vector_store, '_index') else "N/A",
389
- "num_documents": len(vector_store.docstore._dict),
390
  "num_tokens": num_tokens,
391
- "embedding_vocab_size": embedding_model.client.get_vocab_size() if hasattr(embedding_model, 'client') and hasattr(embedding_model.client, 'get_vocab_size') else "N/A",
392
  "embedding_dimension": len(embedding_model.embed_query(query)),
393
  "top_k": top_k,
394
  }
395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  if expected_result:
397
  stats["contains_expected"] = any(expected_result in doc.page_content for doc in results)
398
  stats["expected_result_rank"] = next((i for i, doc in enumerate(results) if expected_result in doc.page_content), -1) + 1
@@ -419,35 +444,55 @@ def calculate_statistics(results, search_time, vector_store, num_tokens, embeddi
419
  return stats
420
  # Visualization
421
  def visualize_results(results_df, stats_df):
 
 
 
 
422
  fig, axs = plt.subplots(2, 2, figsize=(20, 20))
423
 
424
- sns.barplot(x='model', y='search_time', data=stats_df, ax=axs[0, 0])
425
- axs[0, 0].set_title('Search Time by Model')
426
- axs[0, 0].set_xticks(range(len(axs[0, 0].get_xticklabels())))
427
-
428
- axs[0, 0].set_xticklabels(axs[0, 0].get_xticklabels(), rotation=45, ha='right')
429
 
430
- sns.scatterplot(x='result_diversity', y='rank_correlation', hue='model', data=stats_df, ax=axs[0, 1])
431
- axs[0, 1].set_title('Result Diversity vs. Rank Correlation')
 
 
 
 
 
432
 
433
- sns.boxplot(x='model', y='avg_content_length', data=stats_df, ax=axs[1, 0])
434
- axs[1, 0].set_title('Distribution of Result Content Lengths')
435
- axs[1, 0].set_xticks(range(len(axs[0, 0].get_xticklabels())))
436
- axs[1, 0].set_xticklabels(axs[1, 0].get_xticklabels(), rotation=45, ha='right')
 
 
437
 
438
- embeddings = np.array([embedding for embedding in results_df['embedding'] if isinstance(embedding, np.ndarray)])
439
- if len(embeddings) > 1:
440
- tsne = TSNE(n_components=2, random_state=42)
441
- embeddings_2d = tsne.fit_transform(embeddings)
442
-
443
- sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1], hue=results_df['model'][:len(embeddings)], ax=axs[1, 1])
444
- axs[1, 1].set_title('t-SNE Visualization of Result Embeddings')
445
- else:
446
- axs[1, 1].text(0.5, 0.5, "Not enough data for t-SNE visualization", ha='center', va='center')
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  plt.tight_layout()
449
  return fig
450
-
451
  def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
452
  tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
453
 
@@ -465,8 +510,15 @@ def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
465
 
466
  # New postprocessing function
467
  def rerank_results(results, query, reranker):
468
- reranked_results = reranker.rerank(query, [doc.page_content for doc in results])
469
- return reranked_results
 
 
 
 
 
 
 
470
 
471
  # Main Comparison Function
472
  def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
 
347
  results = retriever.invoke(preprocessed_query)
348
 
349
  def score_result(doc):
350
+ base_score = vector_store.similarity_search_with_score(doc.page_content, k=1)[0][1]
351
+
352
+ # Add bonus for containing expected result
353
+ expected_bonus = 0.3 if expected_result and expected_result in doc.page_content else 0
354
+
355
  if apply_phonetic:
356
  phonetic_score = phonetic_match(doc.page_content, query)
357
+ return (1 - phonetic_weight) * base_score + phonetic_weight * phonetic_score + expected_bonus
358
  else:
359
+ return base_score + expected_bonus
360
 
361
  results = sorted(results, key=score_result, reverse=True)
362
  end_time = time.time()
 
382
  # Evaluation Metrics
383
  # ... (previous code remains the same)
384
 
386
  def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model, query, top_k, expected_result=None):
387
  stats = {
388
  "num_results": len(results),
 
390
  "min_content_length": min([len(doc.page_content) for doc in results]) if results else 0,
391
  "max_content_length": max([len(doc.page_content) for doc in results]) if results else 0,
392
  "search_time": search_time,
 
 
393
  "num_tokens": num_tokens,
 
394
  "embedding_dimension": len(embedding_model.embed_query(query)),
395
  "top_k": top_k,
396
  }
397
 
398
+ # Safely get vector store size
399
+ try:
400
+ if hasattr(vector_store, '_index'):
401
+ stats["vector_store_size"] = vector_store._index.ntotal
402
+ elif hasattr(vector_store, '_collection'):
403
+ stats["vector_store_size"] = len(vector_store._collection.get())
404
+ else:
405
+ stats["vector_store_size"] = "N/A"
406
+ except:
407
+ stats["vector_store_size"] = "N/A"
408
+
409
+ # Safely get document count
410
+ try:
411
+ if hasattr(vector_store, 'docstore'):
412
+ stats["num_documents"] = len(vector_store.docstore._dict)
413
+ elif hasattr(vector_store, '_collection'):
414
+ stats["num_documents"] = len(vector_store._collection.get())
415
+ else:
416
+ stats["num_documents"] = len(results)
417
+ except:
418
+ stats["num_documents"] = len(results)
419
+
420
+
421
  if expected_result:
422
  stats["contains_expected"] = any(expected_result in doc.page_content for doc in results)
423
  stats["expected_result_rank"] = next((i for i, doc in enumerate(results) if expected_result in doc.page_content), -1) + 1
 
444
  return stats
445
  # Visualization
446
def visualize_results(results_df, stats_df):
    """Render a 2x2 grid of diagnostic plots comparing embedding models.

    Panels: search time per model, result diversity vs. rank correlation,
    content-length distribution, and a t-SNE projection of result embeddings.

    Args:
        results_df: per-result DataFrame; expected to carry an 'embedding'
            column (vectors or NaN) and a 'model' label column.
        stats_df: per-model statistics DataFrame; expected columns include
            'search_time', 'result_diversity', 'rank_correlation',
            'avg_content_length', and either 'model' or
            'model_type'/'model_name'.

    Returns:
        The matplotlib Figure (possibly with empty panels on plot errors).
    """
    # Derive a combined 'model' label when the caller supplied the raw parts.
    if 'model' not in stats_df.columns:
        stats_df['model'] = stats_df['model_type'] + ' - ' + stats_df['model_name']

    fig, axs = plt.subplots(2, 2, figsize=(20, 20))

    # Nothing to plot: hand back the empty figure instead of raising.
    if len(stats_df) == 0:
        return fig

    # Each panel is guarded individually so one missing column doesn't
    # suppress the remaining plots (best-effort visualization).
    try:
        sns.barplot(data=stats_df, x='model', y='search_time', ax=axs[0, 0])
        axs[0, 0].set_title('Search Time by Model')
        axs[0, 0].tick_params(axis='x', rotation=45)
    except Exception as e:
        print(f"Error in search time plot: {e}")

    try:
        sns.scatterplot(data=stats_df, x='result_diversity', y='rank_correlation',
                        hue='model', ax=axs[0, 1])
        axs[0, 1].set_title('Result Diversity vs. Rank Correlation')
    except Exception as e:
        print(f"Error in diversity plot: {e}")

    try:
        sns.boxplot(data=stats_df, x='model', y='avg_content_length', ax=axs[1, 0])
        axs[1, 0].set_title('Distribution of Result Content Lengths')
        axs[1, 0].tick_params(axis='x', rotation=45)
    except Exception as e:
        print(f"Error in content length plot: {e}")

    try:
        valid_embeddings = results_df['embedding'].dropna().values
        if len(valid_embeddings) > 1:
            tsne = TSNE(n_components=2, random_state=42)
            embeddings_2d = tsne.fit_transform(np.vstack(valid_embeddings))
            # FIX: the label column is lowercase 'model' everywhere else in
            # this module; 'Model' raised a KeyError that the broad except
            # silently swallowed, blanking this panel.
            sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1],
                            hue=results_df['model'][:len(valid_embeddings)],
                            ax=axs[1, 1])
            axs[1, 1].set_title('t-SNE Visualization of Result Embeddings')
        else:
            axs[1, 1].text(0.5, 0.5, "Not enough embeddings for visualization",
                           ha='center', va='center')
    except Exception as e:
        print(f"Error in embedding visualization: {e}")

    plt.tight_layout()
    return fig
 
496
  def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
497
  tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
498
 
 
510
 
511
  # New postprocessing function
512
def rerank_results(results, query, reranker):
    """Re-rank retrieved documents by relevance to the query.

    Supports two reranker flavors:
      * objects exposing a ``rerank(query, texts)`` method — delegated to
        directly;
      * HF ``TextClassificationPipeline``-style callables — each
        (query, document) pair is scored and results are sorted by
        descending score.

    Args:
        results: retrieved documents, each with a ``page_content`` attribute.
        query: the (possibly preprocessed) search query string.
        reranker: either a rerank-capable model or a classification pipeline.

    Returns:
        The reranked documents (pipeline branch) or whatever the model's
        ``rerank`` method returns.
    """
    # Empty input: nothing to score, avoid calling the pipeline with [].
    if not results:
        return results

    if hasattr(reranker, 'rerank'):
        # For models with a dedicated rerank method
        return reranker.rerank(query, [doc.page_content for doc in results])

    # For TextClassificationPipeline-style callables.
    # FIX: 'cross_entropy' is not a valid ``function_to_apply`` value for HF
    # pipelines (valid: 'sigmoid', 'softmax', 'none') and raised a
    # ValueError; 'sigmoid' yields an independent relevance score per pair.
    pairs = [[query, doc.page_content] for doc in results]
    scores = [pred['score'] for pred in reranker(pairs, function_to_apply='sigmoid')]
    reranked_idx = np.argsort(scores)[::-1]
    return [results[i] for i in reranked_idx]
522
 
523
  # Main Comparison Function
524
  def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):