Chris4K committed (verified)
Commit a717449 · Parent: 7bc6b38

Update app.py

Files changed (1): app.py (+138 −30)
app.py CHANGED
@@ -17,8 +17,19 @@ from langchain_text_splitters import (
 )
 from typing import List, Dict, Any
 import pandas as pd
-
-
+import re
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
+from nltk.stem import SnowballStemmer
+import jellyfish  # For Kölner Phonetik
+from gensim.models import Word2Vec
+from gensim.models.fasttext import FastText
+from collections import Counter
+from tokenizers import Tokenizer
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+
+nltk.download('stopwords', quiet=True)
 nltk.download('punkt', quiet=True)
 
 FILES_DIR = './files'
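Review note: this hunk pulls in three new third-party dependencies (gensim, jellyfish, tokenizers) on top of nltk. If they are not pinned in the Space's requirements.txt, the app dies with an ImportError at startup. A minimal guard, purely a suggestion and not part of the commit, would fail with one clear message:

```python
# Hypothetical startup guard, not part of this commit.
missing = []
for pkg in ("gensim", "jellyfish", "tokenizers"):
    try:
        __import__(pkg)
    except ImportError:
        missing.append(pkg)
if missing:
    raise ImportError(
        f"app.py now needs {', '.join(missing)}; add them to requirements.txt"
    )
```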
@@ -39,6 +50,34 @@ MODELS = {
     }
 }
 
+def preprocess_text(text, lang='german'):
+    # Convert to lowercase
+    text = text.lower()
+
+    # Remove special characters and digits
+    text = re.sub(r'[^a-zA-Z\s]', '', text)
+
+    # Tokenize
+    tokens = word_tokenize(text, language=lang)
+
+    # Remove stopwords
+    stop_words = set(stopwords.words(lang))
+    tokens = [token for token in tokens if token not in stop_words]
+
+    # Stemming
+    stemmer = SnowballStemmer(lang)
+    tokens = [stemmer.stem(token) for token in tokens]
+
+    return ' '.join(tokens)
+
+def phonetic_match(text, query, method='koelner_phonetik'):
+    if method == 'koelner_phonetik':
+        text_phonetic = jellyfish.cologne_phonetic(text)
+        query_phonetic = jellyfish.cologne_phonetic(query)
+        return jellyfish.jaro_winkler(text_phonetic, query_phonetic)
+    # Add other phonetic methods as needed
+    return 0
+
 class FileHandler:
     @staticmethod
     def extract_text(file_path):
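Review note on this hunk: as far as I can tell, jellyfish ships soundex, metaphone, nysiis, and Jaro-Winkler helpers but no `cologne_phonetic` function, so `phonetic_match` would raise an AttributeError on first use; Kölner Phonetik lives in dedicated packages such as `cologne_phonetics`. Recent jellyfish releases also renamed `jaro_winkler` to `jaro_winkler_similarity`. Separately, the regex in `preprocess_text` strips ä/ö/ü/ß, which is unfortunate for a German-first pipeline. A sketch that stays within jellyfish calls I know exist, with Soundex standing in for Kölner Phonetik:

```python
import re
import jellyfish

def phonetic_match(text, query):
    # Soundex as a stand-in; true Koelner Phonetik would come from a
    # dedicated package (e.g. cologne_phonetics), not jellyfish.
    text_codes = " ".join(jellyfish.soundex(w) for w in text.split())
    query_codes = " ".join(jellyfish.soundex(w) for w in query.split())
    # jaro_winkler was renamed jaro_winkler_similarity in newer jellyfish.
    return jellyfish.jaro_winkler_similarity(text_codes, query_codes)

def clean_text_german(text):
    # Keep umlauts and eszett; r'[^a-zA-Z\s]' would delete them.
    return re.sub(r'[^a-zäöüß\s]', '', text.lower())
```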
@@ -89,23 +128,26 @@ def get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators):
     else:
         raise ValueError(f"Unsupported split strategy: {split_strategy}")
 
-def get_vector_store(store_type, texts, embedding_model):
-    if store_type == 'FAISS':
-        return FAISS.from_texts(texts, embedding_model)
-    elif store_type == 'Chroma':
-        return Chroma.from_texts(texts, embedding_model)
+def get_vector_store(vector_store_type, chunks, embedding_model):
+    if vector_store_type == 'FAISS':
+        return FAISS.from_texts(chunks, embedding_model)
+    elif vector_store_type == 'Chroma':
+        return Chroma.from_texts(chunks, embedding_model)
     else:
-        raise ValueError(f"Unsupported vector store type: {store_type}")
+        raise ValueError(f"Unsupported vector store type: {vector_store_type}")
 
-def get_retriever(vector_store, search_type, search_kwargs=None):
+def get_retriever(vector_store, search_type, search_kwargs):
     if search_type == 'similarity':
         return vector_store.as_retriever(search_type="similarity", search_kwargs=search_kwargs)
     elif search_type == 'mmr':
         return vector_store.as_retriever(search_type="mmr", search_kwargs=search_kwargs)
+    elif search_type == 'custom':
+        # Implement custom retriever logic here
+        pass
     else:
         raise ValueError(f"Unsupported search type: {search_type}")
 
-def process_files(file_path, model_type, model_name, split_strategy, chunk_size, overlap_size, custom_separators):
+def process_files(file_path, model_type, model_name, split_strategy, chunk_size, overlap_size, custom_separators, lang='german'):
     if file_path:
         text = FileHandler.extract_text(file_path)
     else:
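Review note: the new 'custom' branch in `get_retriever` ends in `pass`, so the function returns None and the caller's `retriever.get_relevant_documents(...)` later fails with an AttributeError far from the cause. Until the custom logic exists, failing loudly at the source is safer; a sketch:

```python
def get_retriever(vector_store, search_type, search_kwargs):
    # 'similarity' and 'mmr' map directly onto as_retriever(); anything
    # else (including the unimplemented 'custom' mode) fails loudly
    # instead of silently returning None.
    if search_type in ('similarity', 'mmr'):
        return vector_store.as_retriever(search_type=search_type,
                                         search_kwargs=search_kwargs)
    raise NotImplementedError(f"search_type {search_type!r} is not implemented")
```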
@@ -113,6 +155,9 @@ def process_files(file_path, model_type, model_name, split_strategy, chunk_size, overlap_size, custom_separators):
     for file in os.listdir(FILES_DIR):
         file_path = os.path.join(FILES_DIR, file)
         text += FileHandler.extract_text(file_path)
+
+    # Preprocess the text
+    text = preprocess_text(text, lang)
 
     text_splitter = get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators)
     chunks = text_splitter.split_text(text)
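Review note: `preprocess_text` lowercases, strips stopwords, and stems the entire corpus before chunking, so the transformer embedding models now embed stemmed token soup rather than the natural text they were trained on (the e5 family in particular). That may be intended for the phonetic and BPE experiments, but making it opt-in would preserve the previous dense-retrieval behavior. A sketch with a hypothetical `apply_preprocessing` flag, not in the commit:

```python
def process_files(file_path, model_type, model_name, split_strategy, chunk_size,
                  overlap_size, custom_separators, lang='german',
                  apply_preprocessing=False):  # hypothetical opt-in flag
    text = FileHandler.extract_text(file_path) if file_path else ""
    if apply_preprocessing:
        # Opt in for the lexical/phonetic paths; leave natural text for
        # transformer embedding models such as e5.
        text = preprocess_text(text, lang)
    text_splitter = get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators)
    chunks = text_splitter.split_text(text)
    ...  # embedding-model construction continues unchanged
```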
@@ -121,15 +166,24 @@ def process_files(file_path, model_type, model_name, split_strategy, chunk_size, overlap_size, custom_separators):
 
     return chunks, embedding_model, len(text.split())
 
-def search_embeddings(chunks, embedding_model, vector_store_type, search_type, query, top_k):
+def search_embeddings(chunks, embedding_model, vector_store_type, search_type, query, top_k, lang='german', phonetic_weight=0.3):
+    # Preprocess the query
+    preprocessed_query = preprocess_text(query, lang)
+
     vector_store = get_vector_store(vector_store_type, chunks, embedding_model)
     retriever = get_retriever(vector_store, search_type, {"k": top_k})
 
     start_time = time.time()
-    results = retriever.get_relevant_documents(query)
+    results = retriever.get_relevant_documents(preprocessed_query)
+
+    # Apply phonetic matching
+    results = sorted(results, key=lambda x: (1 - phonetic_weight) * vector_store.similarity_search(x.page_content, k=1)[0][1] +
+                                            phonetic_weight * phonetic_match(x.page_content, query),
+                     reverse=True)
+
     end_time = time.time()
 
-    return results, end_time - start_time, vector_store
+    return results[:top_k], end_time - start_time, vector_store
 
 def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model):
     return {
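Review note: the blended re-ranking indexes `similarity_search(...)[0][1]`, but LangChain's `similarity_search` returns a plain list of Documents, not (document, score) pairs, so this either crashes or at best issues one extra vector search per hit. `similarity_search_with_score` returns the pairs this code wants, and scoring the query once avoids the per-result searches; note FAISS reports distances (lower is closer), so the sign flips before blending. A sketch, reusing `phonetic_match` from this file:

```python
def rerank_with_phonetics(vector_store, preprocessed_query, raw_query,
                          top_k, phonetic_weight):
    # One call yields (Document, score) pairs; FAISS scores are L2
    # distances, so negate them before mixing with the 0..1 phonetic score.
    scored = vector_store.similarity_search_with_score(preprocessed_query, k=top_k)
    blended = sorted(
        scored,
        key=lambda pair: (1 - phonetic_weight) * -pair[1]
                         + phonetic_weight * phonetic_match(pair[0].page_content, raw_query),
        reverse=True,
    )
    return [doc for doc, _score in blended]
```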
@@ -142,7 +196,47 @@ def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model):
         "embedding_vocab_size": embedding_model.client.get_vocab_size() if hasattr(embedding_model, 'client') and hasattr(embedding_model.client, 'get_vocab_size') else "N/A"
     }
 
-def compare_embeddings(file, query, model_types, model_names, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k):
+def create_custom_embedding(texts, model_type='word2vec', vector_size=100, window=5, min_count=1):
+    # Tokenize the texts
+    tokenized_texts = [text.split() for text in texts]
+
+    if model_type == 'word2vec':
+        model = Word2Vec(sentences=tokenized_texts, vector_size=vector_size, window=window, min_count=min_count, workers=4)
+    elif model_type == 'fasttext':
+        model = FastText(sentences=tokenized_texts, vector_size=vector_size, window=window, min_count=min_count, workers=4)
+    else:
+        raise ValueError("Unsupported model type")
+
+    return model
+
+class CustomEmbeddings(HuggingFaceEmbeddings):
+    def __init__(self, model_path):
+        self.model = Word2Vec.load(model_path)  # or FastText.load() for FastText models
+
+    def embed_documents(self, texts):
+        return [self.model.wv[text.split()] for text in texts]
+
+    def embed_query(self, text):
+        return self.model.wv[text.split()]
+
+def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
+    # Count word frequencies
+    word_freq = Counter(word for text in texts for word in text.split())
+
+    # Remove rare words
+    optimized_texts = [
+        ' '.join(word for word in text.split() if word_freq[word] >= min_frequency)
+        for text in texts
+    ]
+
+    # Train BPE tokenizer
+    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+    trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
+    tokenizer.train_from_iterator(optimized_texts, trainer)
+
+    return tokenizer, optimized_texts
+
+def compare_embeddings(file, query, model_types, model_names, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, lang, use_custom_embedding, optimize_vocab, phonetic_weight):
     all_results = []
     all_stats = []
     settings = {
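Review note: `CustomEmbeddings` has three issues as committed: it subclasses `HuggingFaceEmbeddings` without calling its initializer (which the pydantic-based class will not tolerate), its `__init__` wants a file path while `compare_embeddings` passes the in-memory model, and `model.wv[text.split()]` returns a matrix of per-word vectors (raising KeyError on any out-of-vocabulary word) where LangChain expects one vector per text. A mean-pooling sketch on the plain `Embeddings` interface; the import path varies across LangChain versions:

```python
import numpy as np
from langchain_core.embeddings import Embeddings  # older: langchain.embeddings.base

class CustomEmbeddings(Embeddings):
    def __init__(self, model):
        # Take the trained gensim model directly, matching how
        # compare_embeddings() actually constructs this class.
        self.model = model

    def _embed(self, text):
        # Mean-pool word vectors and skip OOV tokens instead of raising.
        vectors = [self.model.wv[w] for w in text.split() if w in self.model.wv]
        if not vectors:
            return [0.0] * self.model.vector_size
        return np.mean(vectors, axis=0).tolist()

    def embed_documents(self, texts):
        return [self._embed(t) for t in texts]

    def embed_query(self, text):
        return self._embed(text)
```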
@@ -152,7 +246,11 @@ def compare_embeddings(file, query, model_types, model_names, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k):
         "custom_separators": custom_separators,
         "vector_store_type": vector_store_type,
         "search_type": search_type,
-        "top_k": top_k
+        "top_k": top_k,
+        "lang": lang,
+        "use_custom_embedding": use_custom_embedding,
+        "optimize_vocab": optimize_vocab,
+        "phonetic_weight": phonetic_weight
     }
 
     for model_type, model_name in zip(model_types, model_names):
@@ -163,16 +261,27 @@ def compare_embeddings(file, query, model_types, model_names, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k):
             split_strategy,
             chunk_size,
             overlap_size,
-            custom_separators.split(',') if custom_separators else None
+            custom_separators.split(',') if custom_separators else None,
+            lang
         )
 
+        if use_custom_embedding:
+            custom_model = create_custom_embedding(chunks)
+            embedding_model = CustomEmbeddings(custom_model)
+
+        if optimize_vocab:
+            tokenizer, optimized_chunks = optimize_vocabulary(chunks)
+            chunks = optimized_chunks
+
         results, search_time, vector_store = search_embeddings(
             chunks,
             embedding_model,
             vector_store_type,
             search_type,
             query,
-            top_k
+            top_k,
+            lang,
+            phonetic_weight
         )
 
         stats = calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model)
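Review note on the loop above: `CustomEmbeddings(custom_model)` hands a model object to an `__init__` that calls `Word2Vec.load(model_path)`, and the BPE `tokenizer` from `optimize_vocabulary` is never used afterwards, so only the rare-word filtering actually changes the run. With a constructor that accepts the model directly (see the sketch after the previous hunk), the loop body could read:

```python
        if use_custom_embedding:
            # Train on the current chunks and wrap the in-memory model.
            embedding_model = CustomEmbeddings(create_custom_embedding(chunks))

        if optimize_vocab:
            # Tokenizer is unused downstream; keep only if a later step
            # will consume it.
            _tokenizer, chunks = optimize_vocabulary(chunks)
```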
@@ -200,39 +309,38 @@ def format_results(results, stats):
             formatted_results.append(result)
     return formatted_results
 
-# Gradio interface
 def launch_interface(share=True):
     iface = gr.Interface(
         fn=compare_embeddings,
         inputs=[
             gr.File(label="Upload File (Optional)"),
             gr.Textbox(label="Search Query"),
-            gr.CheckboxGroup(choices=list(MODELS.keys()), label="Embedding Model Types", value=["HuggingFace"]),
-            gr.CheckboxGroup(choices=[model for models in MODELS.values() for model in models], label="Embedding Models", value=["e5-base-de"]),
+            gr.CheckboxGroup(choices=list(MODELS.keys()) + ["Custom"], label="Embedding Model Types"),
+            gr.CheckboxGroup(choices=[model for models in MODELS.values() for model in models] + ["custom_model"], label="Embedding Models"),
             gr.Radio(choices=["token", "recursive"], label="Split Strategy", value="recursive"),
             gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
             gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
             gr.Textbox(label="Custom Split Separators (comma-separated, optional)"),
             gr.Radio(choices=["FAISS", "Chroma"], label="Vector Store Type", value="FAISS"),
-            gr.Radio(choices=["similarity", "mmr"], label="Search Type", value="similarity"),
-            gr.Slider(1, 10, step=1, value=5, label="Top K")
+            gr.Radio(choices=["similarity", "mmr", "custom"], label="Search Type", value="similarity"),
+            gr.Slider(1, 10, step=1, value=5, label="Top K"),
+            gr.Dropdown(choices=["german", "english", "french"], label="Language", value="german"),
+            gr.Checkbox(label="Use Custom Embedding", value=False),
+            gr.Checkbox(label="Optimize Vocabulary", value=False),
+            gr.Slider(0, 1, step=0.1, value=0.3, label="Phonetic Matching Weight")
         ],
         outputs=[
             gr.Dataframe(label="Results", interactive=False),
             gr.Dataframe(label="Statistics", interactive=False)
         ],
-        title="Embedding Comparison Tool",
-        description="Compare different embedding models and retrieval strategies",
-        examples=[
-            ["files/test.txt", "What is machine learning?", ["HuggingFace"], ["e5-base-de"], "recursive", 500, 50, "", "FAISS", "similarity", 5]
-        ],
-        allow_flagging="never"
+        title="Advanced Embedding Comparison Tool",
+        description="Compare different embedding models and retrieval strategies with advanced preprocessing and phonetic matching"
     )
 
     tutorial_md = """
-# Embedding Comparison Tool Tutorial
+# Advanced Embedding Comparison Tool Tutorial
 
-... (tutorial content remains the same) ...
+... (update the tutorial to include information about the new features) ...
 """
 
     iface = gr.TabbedInterface(
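Review note: both CheckboxGroups lost their `value=` defaults in this change, so a fresh session submits empty model_types/model_names lists and the `zip(...)` in `compare_embeddings` silently produces zero runs. The `examples` row and `allow_flagging="never"` were dropped too, which re-enables Gradio's default flagging button; if that is unintentional, restore them. A sketch of the two inputs with defaults back:

```python
gr.CheckboxGroup(choices=list(MODELS.keys()) + ["Custom"],
                 label="Embedding Model Types", value=["HuggingFace"]),
gr.CheckboxGroup(choices=[m for ms in MODELS.values() for m in ms] + ["custom_model"],
                 label="Embedding Models", value=["e5-base-de"]),
```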