Chris4K committed
Commit 488e992
Parent: fd47dd6

Update app.py

Files changed (1)
app.py +172 -158
app.py CHANGED
@@ -5,38 +5,35 @@ import docx
 import nltk
 import gradio as gr
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.embeddings import (
-    OpenAIEmbeddings,
-    CohereEmbeddings,
-)
+from langchain_community.embeddings import CohereEmbeddings
 from langchain_openai import OpenAIEmbeddings
 from langchain_community.vectorstores import FAISS, Chroma
-from langchain_text_splitters import (
-    RecursiveCharacterTextSplitter,
-    TokenTextSplitter,
-)
+from langchain_text_splitters import RecursiveCharacterTextSplitter, TokenTextSplitter
 from typing import List, Dict, Any
 import pandas as pd
+import numpy as np
 import re
 from nltk.corpus import stopwords
 from nltk.tokenize import word_tokenize
 from nltk.stem import SnowballStemmer
-import jellyfish  # For Kölner Phonetik
+import jellyfish
 from gensim.models import Word2Vec
 from gensim.models.fasttext import FastText
 from collections import Counter
 from tokenizers import Tokenizer
-from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
-
-
+from tokenizers.models import WordLevel
+from tokenizers.trainers import WordLevelTrainer
+from tokenizers.pre_tokenizers import Whitespace
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.manifold import TSNE
+from sklearn.metrics import silhouette_score
+from scipy.stats import spearmanr
+from functools import lru_cache
 
+# NLTK Resource Download
 def download_nltk_resources():
-    resources = [
-        'punkt',
-        'stopwords',
-        'snowball_data',
-    ]
+    resources = ['punkt', 'stopwords', 'snowball_data']
     for resource in resources:
         try:
             nltk.download(resource, quiet=True)
@@ -45,47 +42,87 @@ def download_nltk_resources():
 
 download_nltk_resources()
 
-nltk.download('stopwords', quiet=True)
-nltk.download('punkt', quiet=True)
-
-FILES_DIR = './files'
-
-MODELS = {
-    'HuggingFace': {
-        'e5-base-de': "danielheinz/e5-base-sts-en-de",
-        'paraphrase-miniLM': "paraphrase-multilingual-MiniLM-L12-v2",
-        'paraphrase-mpnet': "paraphrase-multilingual-mpnet-base-v2",
-        'gte-large': "gte-large",
-        'gbert-base': "gbert-base"
-    },
-    'OpenAI': {
-        'text-embedding-ada-002': "text-embedding-ada-002"
-    },
-    'Cohere': {
-        'embed-multilingual-v2.0': "embed-multilingual-v2.0"
-    }
-}
+FILES_DIR = './files'
+
+# Model Management
+class ModelManager:
+    def __init__(self):
+        self.models = {
+            'HuggingFace': {
+                'e5-base-de': "danielheinz/e5-base-sts-en-de",
+                'paraphrase-miniLM': "paraphrase-multilingual-MiniLM-L12-v2",
+                'paraphrase-mpnet': "paraphrase-multilingual-mpnet-base-v2",
+                'gte-large': "gte-large",
+                'gbert-base': "gbert-base"
+            },
+            'OpenAI': {
+                'text-embedding-ada-002': "text-embedding-ada-002"
+            },
+            'Cohere': {
+                'embed-multilingual-v2.0': "embed-multilingual-v2.0"
+            }
+        }
+
+    def add_model(self, provider, name, model_path):
+        if provider not in self.models:
+            self.models[provider] = {}
+        self.models[provider][name] = model_path
+
+    def remove_model(self, provider, name):
+        if provider in self.models and name in self.models[provider]:
+            del self.models[provider][name]
+
+    def get_model(self, provider, name):
+        return self.models.get(provider, {}).get(name)
+
+    def list_models(self):
+        return {provider: list(models.keys()) for provider, models in self.models.items()}
+
+model_manager = ModelManager()
+
+# File Handling
+class FileHandler:
+    @staticmethod
+    def extract_text(file_path):
+        ext = os.path.splitext(file_path)[-1].lower()
+        if ext == '.pdf':
+            return FileHandler._extract_from_pdf(file_path)
+        elif ext == '.docx':
+            return FileHandler._extract_from_docx(file_path)
+        elif ext == '.txt':
+            return FileHandler._extract_from_txt(file_path)
+        else:
+            raise ValueError(f"Unsupported file type: {ext}")
+
+    @staticmethod
+    def _extract_from_pdf(file_path):
+        with pdfplumber.open(file_path) as pdf:
+            return ' '.join([page.extract_text() for page in pdf.pages])
+
+    @staticmethod
+    def _extract_from_docx(file_path):
+        doc = docx.Document(file_path)
+        return ' '.join([para.text for para in doc.paragraphs])
+
+    @staticmethod
+    def _extract_from_txt(file_path):
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read()
+
+# Text Processing
 def simple_tokenize(text):
-    """Simple tokenization fallback method."""
     return text.split()
 
 def preprocess_text(text, lang='german'):
-    # Convert to lowercase
     text = text.lower()
-
-    # Remove special characters and digits
     text = re.sub(r'[^a-zA-Z\s]', '', text)
 
-    # Tokenize
     try:
         tokens = word_tokenize(text, language=lang)
     except LookupError:
         print(f"Warning: NLTK punkt tokenizer for {lang} not found. Using simple tokenization.")
         tokens = simple_tokenize(text)
 
-    # Remove stopwords
     try:
         stop_words = set(stopwords.words(lang))
     except LookupError:
@@ -93,7 +130,6 @@ def preprocess_text(text, lang='german'):
         stop_words = set()
     tokens = [token for token in tokens if token not in stop_words]
 
-    # Stemming
     try:
         stemmer = SnowballStemmer(lang)
         tokens = [stemmer.stem(token) for token in tokens]
@@ -107,44 +143,34 @@ def phonetic_match(text, query, method='koelner_phonetik'):
         text_phonetic = jellyfish.cologne_phonetic(text)
         query_phonetic = jellyfish.cologne_phonetic(query)
         return jellyfish.jaro_winkler(text_phonetic, query_phonetic)
-    # Add other phonetic methods as needed
     return 0
 
-class FileHandler:
-    @staticmethod
-    def extract_text(file_path):
-        ext = os.path.splitext(file_path)[-1].lower()
-        if ext == '.pdf':
-            return FileHandler._extract_from_pdf(file_path)
-        elif ext == '.docx':
-            return FileHandler._extract_from_docx(file_path)
-        elif ext == '.txt':
-            return FileHandler._extract_from_txt(file_path)
-        else:
-            raise ValueError(f"Unsupported file type: {ext}")
-
-    @staticmethod
-    def _extract_from_pdf(file_path):
-        with pdfplumber.open(file_path) as pdf:
-            return ' '.join([page.extract_text() for page in pdf.pages])
-
-    @staticmethod
-    def _extract_from_docx(file_path):
-        doc = docx.Document(file_path)
-        return ' '.join([para.text for para in doc.paragraphs])
-
-    @staticmethod
-    def _extract_from_txt(file_path):
-        with open(file_path, 'r', encoding='utf-8') as f:
-            return f.read()
-
+# Custom Tokenizer
+def create_custom_tokenizer(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        text = f.read()
+
+    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
+    tokenizer.pre_tokenizer = Whitespace()
+
+    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
+    tokenizer.train_from_iterator([text], trainer)
+
+    return tokenizer
+
+def custom_tokenize(text, tokenizer):
+    return tokenizer.encode(text).tokens
+
+# Embedding and Vector Store
+@lru_cache(maxsize=None)
 def get_embedding_model(model_type, model_name):
+    model_path = model_manager.get_model(model_type, model_name)
     if model_type == 'HuggingFace':
-        return HuggingFaceEmbeddings(model_name=MODELS[model_type][model_name])
+        return HuggingFaceEmbeddings(model_name=model_path)
     elif model_type == 'OpenAI':
-        return OpenAIEmbeddings(model=MODELS[model_type][model_name])
+        return OpenAIEmbeddings(model=model_path)
     elif model_type == 'Cohere':
-        return CohereEmbeddings(model=MODELS[model_type][model_name])
+        return CohereEmbeddings(model=model_path)
     else:
         raise ValueError(f"Unsupported model type: {model_type}")
 
@@ -160,6 +186,7 @@ def get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separator
     else:
         raise ValueError(f"Unsupported split strategy: {split_strategy}")
 
+@lru_cache(maxsize=None)
 def get_vector_store(vector_store_type, chunks, embedding_model):
     if vector_store_type == 'FAISS':
         return FAISS.from_texts(chunks, embedding_model)
@@ -179,7 +206,8 @@ def get_retriever(vector_store, search_type, search_kwargs):
     else:
         raise ValueError(f"Unsupported search type: {search_type}")
 
-def process_files(file_path, model_type, model_name, split_strategy, chunk_size, overlap_size, custom_separators, lang='german'):
+# Main Processing Functions
+def process_files(file_path, model_type, model_name, split_strategy, chunk_size, overlap_size, custom_separators, lang='german', custom_tokenizer_file=None):
    if file_path:
         text = FileHandler.extract_text(file_path)
     else:
@@ -188,8 +216,11 @@ def process_files(file_path, model_type, model_name, split_strategy, chunk_size,
             file_path = os.path.join(FILES_DIR, file)
             text += FileHandler.extract_text(file_path)
 
-    # Preprocess the text
-    text = preprocess_text(text, lang)
+    if custom_tokenizer_file:
+        tokenizer = create_custom_tokenizer(custom_tokenizer_file)
+        text = ' '.join(custom_tokenize(text, tokenizer))
+    else:
+        text = preprocess_text(text, lang)
 
     text_splitter = get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators)
     chunks = text_splitter.split_text(text)
@@ -199,7 +230,6 @@ def process_files(file_path, model_type, model_name, split_strategy, chunk_size,
     return chunks, embedding_model, len(text.split())
 
 def search_embeddings(chunks, embedding_model, vector_store_type, search_type, query, top_k, lang='german', phonetic_weight=0.3):
-    # Preprocess the query
     preprocessed_query = preprocess_text(query, lang)
 
     vector_store = get_vector_store(vector_store_type, chunks, embedding_model)
@@ -208,15 +238,18 @@ def search_embeddings(chunks, embedding_model, vector_store_type, search_type, q
     start_time = time.time()
     results = retriever.invoke(preprocessed_query)
 
-    # Apply phonetic matching
-    results = sorted(results, key=lambda x: (1 - phonetic_weight) * vector_store.similarity_search(x.page_content, k=1)[0][1] +
-                                            phonetic_weight * phonetic_match(x.page_content, query),
-                     reverse=True)
+    def score_result(doc):
+        similarity_score = vector_store.similarity_search_with_score(doc.page_content, k=1)[0][1]
+        phonetic_score = phonetic_match(doc.page_content, query)
+        return (1 - phonetic_weight) * similarity_score + phonetic_weight * phonetic_score
+
+    results = sorted(results, key=score_result, reverse=True)
 
     end_time = time.time()
 
     return results[:top_k], end_time - start_time, vector_store
 
+# Evaluation Metrics
 def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model, query, top_k):
     stats = {
         "num_results": len(results),
@@ -230,64 +263,58 @@ def calculate_statistics(results, search_time, vector_store, num_tokens, embeddi
         "top_k": top_k,
     }
 
-    # Calculate diversity of results
     if len(results) > 1:
         embeddings = [embedding_model.embed_query(doc.page_content) for doc in results]
-        pairwise_similarities = cosine_similarity(embeddings)
+        pairwise_similarities = np.inner(embeddings, embeddings)
         stats["result_diversity"] = 1 - np.mean(pairwise_similarities[np.triu_indices(len(embeddings), k=1)])
+
+        # Silhouette Score
+        if len(embeddings) > 2:
+            stats["silhouette_score"] = silhouette_score(embeddings, range(len(embeddings)))
+        else:
+            stats["silhouette_score"] = "N/A"
     else:
         stats["result_diversity"] = "N/A"
+        stats["silhouette_score"] = "N/A"
 
-    # Calculate rank correlation between embedding similarity and result order
     query_embedding = embedding_model.embed_query(query)
     result_embeddings = [embedding_model.embed_query(doc.page_content) for doc in results]
-    similarities = [cosine_similarity([query_embedding], [emb])[0][0] for emb in result_embeddings]
+    similarities = [np.inner(query_embedding, emb)[0] for emb in result_embeddings]
     rank_correlation, _ = spearmanr(similarities, range(len(similarities)))
     stats["rank_correlation"] = rank_correlation
 
     return stats
 
-def create_custom_embedding(texts, model_type='word2vec', vector_size=100, window=5, min_count=1):
-    # Tokenize the texts
-    tokenized_texts = [text.split() for text in texts]
-
-    if model_type == 'word2vec':
-        model = Word2Vec(sentences=tokenized_texts, vector_size=vector_size, window=window, min_count=min_count, workers=4)
-    elif model_type == 'fasttext':
-        model = FastText(sentences=tokenized_texts, vector_size=vector_size, window=window, min_count=min_count, workers=4)
-    else:
-        raise ValueError("Unsupported model type")
-
-    return model
-
-class CustomEmbeddings(HuggingFaceEmbeddings):
-    def __init__(self, model_path):
-        self.model = Word2Vec.load(model_path)  # or FastText.load() for FastText models
-
-    def embed_documents(self, texts):
-        return [self.model.wv[text.split()] for text in texts]
-
-    def embed_query(self, text):
-        return self.model.wv[text.split()]
-
-def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
-    # Count word frequencies
-    word_freq = Counter(word for text in texts for word in text.split())
-
-    # Remove rare words
-    optimized_texts = [
-        ' '.join(word for word in text.split() if word_freq[word] >= min_frequency)
-        for text in texts
-    ]
-
-    # Train BPE tokenizer
-    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
-    trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
-    tokenizer.train_from_iterator(optimized_texts, trainer)
-
-    return tokenizer, optimized_texts
-
-def compare_embeddings(file, query, model_types, model_names, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, lang='german', use_custom_embedding=False, optimize_vocab=False, phonetic_weight=0.3):
+# Visualization
+def visualize_results(results_df, stats_df):
+    fig, axs = plt.subplots(2, 2, figsize=(20, 20))
+
+    sns.barplot(x='model', y='search_time', data=stats_df, ax=axs[0, 0])
+    axs[0, 0].set_title('Search Time by Model')
+    axs[0, 0].set_xticklabels(axs[0, 0].get_xticklabels(), rotation=45, ha='right')
+
+    sns.scatterplot(x='result_diversity', y='rank_correlation', hue='model', data=stats_df, ax=axs[0, 1])
+    axs[0, 1].set_title('Result Diversity vs. Rank Correlation')
+
+    sns.boxplot(x='model', y='avg_content_length', data=stats_df, ax=axs[1, 0])
+    axs[1, 0].set_title('Distribution of Result Content Lengths')
+    axs[1, 0].set_xticklabels(axs[1, 0].get_xticklabels(), rotation=45, ha='right')
+
+    embeddings = np.array([embedding for embedding in results_df['embedding'] if isinstance(embedding, np.ndarray)])
+    if len(embeddings) > 1:
+        tsne = TSNE(n_components=2, random_state=42)
+        embeddings_2d = tsne.fit_transform(embeddings)
+
+        sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1], hue=results_df['model'][:len(embeddings)], ax=axs[1, 1])
+        axs[1, 1].set_title('t-SNE Visualization of Result Embeddings')
+    else:
+        axs[1, 1].text(0.5, 0.5, "Not enough data for t-SNE visualization", ha='center', va='center')
+
+    plt.tight_layout()
+    return fig
+
+# Main Comparison Function
+def compare_embeddings(file, query, model_types, model_names, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, lang='german', use_custom_embedding=False, optimize_vocab=False, phonetic_weight=0.3, custom_tokenizer_file=None):
     all_results = []
     all_stats = []
     settings = {
@@ -313,7 +340,8 @@ def compare_embeddings(file, query, model_types, model_names, split_strategy, ch
             chunk_size,
             overlap_size,
             custom_separators.split(',') if custom_separators else None,
-            lang
+            lang,
+            custom_tokenizer_file
         )
 
         if use_custom_embedding:
@@ -324,7 +352,6 @@ def compare_embeddings(file, query, model_types, model_names, split_strategy, ch
             tokenizer, optimized_chunks = optimize_vocabulary(chunks)
             chunks = optimized_chunks
 
-
         results, search_time, vector_store = search_embeddings(
             chunks,
             embedding_model,
@@ -347,7 +374,10 @@ def compare_embeddings(file, query, model_types, model_names, split_strategy, ch
     results_df = pd.DataFrame(all_results)
     stats_df = pd.DataFrame(all_stats)
 
-    return results_df, stats_df
+    # Generate visualizations
+    fig = visualize_results(results_df, stats_df)
+
+    return results_df, stats_df, fig
 
 def format_results(results, stats):
     formatted_results = []
@@ -355,53 +385,22 @@ def format_results(results, stats):
         result = {
             "Model": stats["model"],
             "Content": doc.page_content,
+            "Embedding": doc.embedding if hasattr(doc, 'embedding') else None,
             **doc.metadata,
             **{k: v for k, v in stats.items() if k not in ["model"]}
         }
         formatted_results.append(result)
     return formatted_results
 
-import matplotlib.pyplot as plt
-import seaborn as sns
-from sklearn.manifold import TSNE
-
-def visualize_results(results_df, stats_df):
-    # Create a figure with subplots
-    fig, axs = plt.subplots(2, 2, figsize=(20, 20))
-
-    # 1. Bar plot of search times
-    sns.barplot(x='model', y='search_time', data=stats_df, ax=axs[0, 0])
-    axs[0, 0].set_title('Search Time by Model')
-    axs[0, 0].set_xticklabels(axs[0, 0].get_xticklabels(), rotation=45, ha='right')
-
-    # 2. Scatter plot of result diversity vs. rank correlation
-    sns.scatterplot(x='result_diversity', y='rank_correlation', hue='model', data=stats_df, ax=axs[0, 1])
-    axs[0, 1].set_title('Result Diversity vs. Rank Correlation')
-
-    # 3. Box plot of content lengths
-    sns.boxplot(x='model', y='content_length', data=results_df, ax=axs[1, 0])
-    axs[1, 0].set_title('Distribution of Result Content Lengths')
-    axs[1, 0].set_xticklabels(axs[1, 0].get_xticklabels(), rotation=45, ha='right')
-
-    # 4. t-SNE visualization of embeddings
-    embeddings = np.array(results_df['embedding'].tolist())
-    tsne = TSNE(n_components=2, random_state=42)
-    embeddings_2d = tsne.fit_transform(embeddings)
-
-    sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1], hue=results_df['model'], ax=axs[1, 1])
-    axs[1, 1].set_title('t-SNE Visualization of Result Embeddings')
-
-    plt.tight_layout()
-    return fig
-
+# Gradio Interface
 def launch_interface(share=True):
     iface = gr.Interface(
         fn=compare_embeddings,
         inputs=[
             gr.File(label="Upload File (Optional)"),
             gr.Textbox(label="Search Query"),
-            gr.CheckboxGroup(choices=list(MODELS.keys()) + ["Custom"], label="Embedding Model Types"),
-            gr.CheckboxGroup(choices=[model for models in MODELS.values() for model in models] + ["custom_model"], label="Embedding Models"),
+            gr.CheckboxGroup(choices=list(model_manager.list_models().keys()) + ["Custom"], label="Embedding Model Types"),
+            gr.CheckboxGroup(choices=[model for models in model_manager.list_models().values() for model in models] + ["custom_model"], label="Embedding Models"),
             gr.Radio(choices=["token", "recursive"], label="Split Strategy", value="recursive"),
             gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
             gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
@@ -412,7 +411,8 @@ def launch_interface(share=True):
             gr.Dropdown(choices=["german", "english", "french"], label="Language", value="german"),
             gr.Checkbox(label="Use Custom Embedding", value=False),
             gr.Checkbox(label="Optimize Vocabulary", value=False),
-            gr.Slider(0, 1, step=0.1, value=0.3, label="Phonetic Matching Weight")
+            gr.Slider(0, 1, step=0.1, value=0.3, label="Phonetic Matching Weight"),
+            gr.File(label="Custom Tokenizer File (Optional)")
         ],
         outputs=[
             gr.Dataframe(label="Results", interactive=False),
@@ -426,7 +426,21 @@ def launch_interface(share=True):
     tutorial_md = """
     # Advanced Embedding Comparison Tool Tutorial
 
-    ... (update the tutorial to include information about the new features) ...
+    This tool allows you to compare different embedding models and retrieval strategies for document search and similarity matching.
+
+    ## How to use:
+
+    1. Upload a file (optional) or use the default files in the system.
+    2. Enter a search query.
+    3. Select one or more embedding model types and specific models.
+    4. Choose a text splitting strategy and set chunk size and overlap.
+    5. Select a vector store type and search type.
+    6. Set the number of top results to retrieve.
+    7. Choose the language of your documents.
+    8. Optionally, use custom embeddings, optimize vocabulary, or adjust phonetic matching weight.
+    9. If you have a custom tokenizer, upload the file.
+
+    The tool will process your query and display results, statistics, and visualizations to help you compare the performance of different models and strategies.
     """
 
     iface = gr.TabbedInterface(
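
The new `ModelManager` replaces the module-level `MODELS` dict with a mutable registry keyed by provider. A minimal usage sketch against the class as committed (the extra model name and path below are illustrative, not from the commit):

```python
manager = ModelManager()

# Register a hypothetical extra HuggingFace model (illustrative name and path).
manager.add_model('HuggingFace', 'minilm-en', 'sentence-transformers/all-MiniLM-L6-v2')

print(manager.list_models())                          # per-provider lists, now including 'minilm-en'
print(manager.get_model('HuggingFace', 'minilm-en'))  # 'sentence-transformers/all-MiniLM-L6-v2'
manager.remove_model('HuggingFace', 'minilm-en')
```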
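
`create_custom_tokenizer` switches the vocabulary model from BPE to a whitespace-pre-tokenized `WordLevel` model, which keeps whole words and maps anything unseen to `[UNK]`. A self-contained sketch of the same construction, trained from an in-memory corpus instead of a file (the sample strings are illustrative):

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# Same pattern as create_custom_tokenizer in this commit.
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
tokenizer.train_from_iterator(["ein kleiner beispieltext", "noch ein beispiel"], trainer)

print(tokenizer.encode("ein neuer text").tokens)  # unseen words come back as [UNK]
```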
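
`phonetic_match` still relies on `jellyfish.cologne_phonetic` and `jellyfish.jaro_winkler`; both are worth verifying against the installed jellyfish release, since recent versions expose `jaro_winkler_similarity` rather than `jaro_winkler`, and Kölner Phonetik does not appear in jellyfish's documented API. A defensive sketch (hypothetical helper name) that degrades to 0.0 instead of raising `AttributeError` at query time:

```python
import jellyfish

def safe_phonetic_similarity(text, query):
    # Resolve the functions dynamically so a missing one degrades gracefully.
    encode = getattr(jellyfish, 'cologne_phonetic', None)
    jaro = getattr(jellyfish, 'jaro_winkler_similarity',
                   getattr(jellyfish, 'jaro_winkler', None))
    if encode is None or jaro is None:
        return 0.0
    return jaro(encode(text), encode(query))
```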
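
One caveat on the new caching: `functools.lru_cache` hashes its arguments, so it is fine on `get_embedding_model` (two strings), but `get_vector_store` receives `chunks` as a list and will raise `TypeError: unhashable type: 'list'` on the first call. A sketch of the usual workaround, passing the chunks as a tuple to a cached helper (`cached_build` is a hypothetical stand-in, not from the commit):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def cached_build(store_type: str, chunks: tuple):
    # Tuples are hashable, so this call signature is cache-friendly.
    print(f"building {store_type} store for {len(chunks)} chunks")
    return object()  # stand-in for FAISS.from_texts(list(chunks), embedding_model)

store = cached_build('FAISS', tuple(['chunk one', 'chunk two']))
again = cached_build('FAISS', tuple(['chunk one', 'chunk two']))  # cache hit, no rebuild
assert store is again
```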
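
In `score_result`, `similarity_search_with_score` returns a distance for the default FAISS index (lower is better), while `phonetic_match` returns a similarity (higher is better), so blending them and sorting with `reverse=True` can invert the vector-side ranking. A hedged sketch that maps the distance onto a higher-is-better scale before blending (`blended_score` is illustrative):

```python
def blended_score(distance, phonetic_score, phonetic_weight=0.3):
    # Map an L2 distance into (0, 1] so both terms point the same direction.
    vector_similarity = 1.0 / (1.0 + distance)
    return (1 - phonetic_weight) * vector_similarity + phonetic_weight * phonetic_score

print(blended_score(distance=0.8, phonetic_score=0.9))  # higher means a better match
```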
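
`silhouette_score(embeddings, range(len(embeddings)))` gives every result its own cluster label, which scikit-learn rejects (valid label counts run from 2 to n_samples - 1), so the stat would raise rather than compute. If the intent is to measure how well results group, the labels need to come from an actual clustering; a sketch with k-means (the array sizes are illustrative):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

embeddings = np.random.rand(8, 384)  # stand-in for result embeddings
labels = KMeans(n_clusters=2, n_init=10, random_state=42).fit_predict(embeddings)
print(silhouette_score(embeddings, labels))  # defined: 2 <= n_labels <= n_samples - 1
```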
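
The switch from `cosine_similarity` to `np.inner` changes the metric unless the embeddings are unit-normalized, and `np.inner` on two 1-D vectors returns a scalar, so the trailing `[0]` in the rank-correlation block will fail. A sketch of an explicit cosine similarity that avoids both issues (`cosine_sim` and the random vectors are illustrative):

```python
import numpy as np

def cosine_sim(a, b):
    a, b = np.asarray(a), np.asarray(b)
    # Normalize explicitly so the dot product is a true cosine similarity.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

query_embedding = np.random.rand(384)                      # stand-in query vector
result_embeddings = [np.random.rand(384) for _ in range(5)]
similarities = [cosine_sim(query_embedding, e) for e in result_embeddings]
```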