Chris4K committed on
Commit
84cb849
1 Parent(s): ac567eb

Update app.py

Files changed (1)
  1. app.py +163 -127
app.py CHANGED
@@ -4,35 +4,46 @@ import pdfplumber
  import docx
  import nltk
  import gradio as gr
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.vectorstores import FAISS
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain_text_splitters import TokenTextSplitter
- from sentence_transformers import SentenceTransformer
- from transformers import AutoTokenizer
- from nltk import sent_tokenize
- from typing import List, Tuple
- from transformers import AutoModel, AutoTokenizer
-
- #import spacy
- #spacy.cli.download("en_core_web_sm")  # Ensure the model is available
- #nlp = spacy.load("en_core_web_sm")  # Load the model
-
-

  # Ensure nltk sentence tokenizer is downloaded
- nltk.download('punkt')

  FILES_DIR = './files'

  # Supported embedding models
  MODELS = {
-     'e5-base': "danielheinz/e5-base-sts-en-de",
-     'multilingual-e5-base': "multilingual-e5-base",
-     'paraphrase-miniLM': "paraphrase-multilingual-MiniLM-L12-v2",
-     'paraphrase-mpnet': "paraphrase-multilingual-mpnet-base-v2",
-     'gte-large': "gte-large",
-     'gbert-base': "gbert-base"
  }

  class FileHandler:
@@ -63,123 +74,148 @@ class FileHandler:
          with open(file_path, 'r', encoding='utf-8') as f:
              return f.read()

- class EmbeddingModel:
-     def __init__(self, model_name, max_tokens=None):
-         self.model = HuggingFaceEmbeddings(model_name=model_name)
-         self.max_tokens = max_tokens
-
-     def embed(self, chunks: List[str]):
-         # Embed the list of chunks
-         return self.model.embed_documents(chunks)
-
- def process_files(model_name, split_strategy, chunk_size, overlap_size, max_tokens):
-     print('-----mmm--------')
-     print(model_name)
-     print(split_strategy)
-     print(overlap_size)
-     print(chunk_size)
-     print(max_tokens)
-     # File processing
-     text = ""
-     for file in os.listdir(FILES_DIR):
-         file_path = os.path.join(FILES_DIR, file)
-         text += FileHandler.extract_text(file_path)

-     # Split text into chunks
      if split_strategy == 'token':
-         splitter = TokenTextSplitter(chunk_size=250, chunk_overlap=20)
      else:
-         splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=20)
-
-     chunks = splitter.split_text(text)
-
-     # Embed chunks, not the full text
-     model = EmbeddingModel(MODELS[model_name], max_tokens=max_tokens)
-     embeddings = model.embed(chunks)
-     print(chunks)
-     return embeddings, chunks
-
- def search_embeddings(query, model_name, top_k):
-     model = HuggingFaceEmbeddings(model_name=MODELS[model_name])
-     #embeddings = model.embed_query(query)
-     embeddings = model.similarity_search(query)
-     print(embeddings[0])
-     #query = "What did the president say about Ketanji Brown Jackson"
-     #docs = db.similarity_search(query)
-     #print(docs[0].page_content)
-     # Perform FAISS or other similarity-based search over embeddings
-     # This part requires you to build and search a FAISS index with embeddings
-
-     return embeddings  # You would likely return the top-k results here
-
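Note: the removed search_embeddings called similarity_search() on a HuggingFaceEmbeddings object, which has no such method; the embeddings first have to be indexed in a vector store, which is what the replacement code introduces. The removed splitter calls above likewise hardcoded chunk_size=250 and chunk_overlap=20, ignoring the function's arguments. A minimal sketch of what this block was reaching for, using the langchain_community imports from the removed header (the model id and helper name are illustrative assumptions, not from the commit):

# Hypothetical corrected flow: index the chunks, then query the index.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

def search_chunks(chunks, query, top_k=5):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    db = FAISS.from_texts(chunks, embeddings)
    return db.similarity_search(query, k=top_k)  # List[Document], best match first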
- def calculate_statistics(embeddings):
-     # Return time taken, token count, etc.
-     return {"tokens": len(embeddings), "time_taken": time.time()}
-
- import shutil
-
-
- import shutil
-
- def upload_file(file, model_name, split_strategy, overlap_size, chunk_size, max_tokens, query, top_k):
-     # Ensure chunk_size and overlap_size are valid integers and provide defaults if needed
-
-     #try:
-     #    if chunk_size is None or chunk_size == "":
-     #        chunk_size = 100  # Default value if not provided
-     #    else:
-     #        chunk_size = int(chunk_size)  # Convert to int if valid#
-
-     #    if overlap_size is None or overlap_size == "":
-     #        overlap_size = 0  # Default value if not provided
-     #    else:
-     #        overlap_size = int(overlap_size)  # Convert to int if valid
-     #except ValueError:
-     #    return {"error": "Chunk size and overlap size must be valid integers."}
-     print('-------------')
-     print(file.name)
-     print(model_name)
-     print(split_strategy)
-     print(overlap_size)
-     print(chunk_size)
-     print(max_tokens)
-     print(query)
-     print(top_k)
-
-     # Handle file upload using Gradio file object
-     file_path = file.name  # Get the file path from Gradio file object
-
-     # Copy the uploaded file content to a local directory
-     destination_path = os.path.join(FILES_DIR, os.path.basename(file_path))
-     shutil.copyfile(file_path, destination_path)  # Use shutil to copy the file
-
-     # Process files and get embeddings
-     embeddings, chunks = process_files(model_name, split_strategy, chunk_size, overlap_size, max_tokens)

-     # Perform search
-     results = search_embeddings(query, model_name, top_k)

-     # Calculate statistics
-     stats = calculate_statistics(embeddings)

-     return {"results": results, "stats": stats}

  # Gradio interface
  iface = gr.Interface(
-     fn=upload_file,
      inputs=[
-         gr.File(label="Upload File"),
          gr.Textbox(label="Search Query"),
-         gr.Dropdown(choices=list(MODELS.keys()), label="Embedding Model"),
-         gr.Radio(choices=["token", "recursive"], label="Split Strategy"),
-         gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),  # Ensure type is int
-         gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),  # Ensure type is int
-         gr.Slider(50, 500, step=50, value=200, label="Max Tokens"),  # Ensure type is int
-         gr.Slider(1, 10, step=1, value=5, label="Top K")  # Ensure type is int
      ],
-     outputs="json"
  )

- iface.launch()
-
 
  import docx
  import nltk
  import gradio as gr
+ from langchain.embeddings import (
+     HuggingFaceEmbeddings,
+     OpenAIEmbeddings,
+     CohereEmbeddings,
+ )
+ from langchain.vectorstores import FAISS, Chroma
+ from langchain.text_splitters import (
+     RecursiveCharacterTextSplitter,
+     TokenTextSplitter,
+ )
+ from langchain.retrievers import (
+     VectorStoreRetriever,
+     ContextualCompressionRetriever,
+ )
+ from langchain.retrievers.document_compressors import LLMChainExtractor
+ from langchain.llms import OpenAI
+ from typing import List, Dict, Any
+ import pandas as pd

  # Ensure nltk sentence tokenizer is downloaded
+ nltk.download('punkt', quiet=True)

  FILES_DIR = './files'

  # Supported embedding models
  MODELS = {
+     'HuggingFace': {
+         'e5-base': "danielheinz/e5-base-sts-en-de",
+         'multilingual-e5-base': "multilingual-e5-base",
+         'paraphrase-miniLM': "paraphrase-multilingual-MiniLM-L12-v2",
+         'paraphrase-mpnet': "paraphrase-multilingual-mpnet-base-v2",
+         'gte-large': "gte-large",
+         'gbert-base': "gbert-base"
+     },
+     'OpenAI': {
+         'text-embedding-ada-002': "text-embedding-ada-002"
+     },
+     'Cohere': {
+         'embed-multilingual-v2.0': "embed-multilingual-v2.0"
+     }
  }
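Two pitfalls in this hunk are worth flagging. First, the released LangChain packages I know expose the splitters as langchain.text_splitter (singular), so the langchain.text_splitters import above is likely an ImportError; the langchain_text_splitters package from the removed header also still works. Second, several of the bare HuggingFace entries lack the organization prefix the Hub needs to resolve them. A hedged correction for those entries (the exact org prefixes are assumptions worth verifying):

# Assumed fully-qualified Hub ids; the bare names above fail at download time.
MODELS['HuggingFace'].update({
    'multilingual-e5-base': "intfloat/multilingual-e5-base",
    'paraphrase-miniLM': "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    'paraphrase-mpnet': "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    'gte-large': "thenlper/gte-large",
    'gbert-base': "deepset/gbert-base",
})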
  class FileHandler:
 
          with open(file_path, 'r', encoding='utf-8') as f:
              return f.read()

+ def get_embedding_model(model_type, model_name):
+     if model_type == 'HuggingFace':
+         return HuggingFaceEmbeddings(model_name=MODELS[model_type][model_name])
+     elif model_type == 'OpenAI':
+         return OpenAIEmbeddings(model=MODELS[model_type][model_name])
+     elif model_type == 'Cohere':
+         return CohereEmbeddings(model=MODELS[model_type][model_name])
+     else:
+         raise ValueError(f"Unsupported model type: {model_type}")
+
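A quick smoke test for the new factory; the HuggingFace branch is the only one that needs no API key, though it does download the registry model on first use (a sketch, not part of the commit):

# Validate a registry entry before wiring it into a vector store.
embedder = get_embedding_model('HuggingFace', 'e5-base')
vectors = embedder.embed_documents(["first chunk", "second chunk"])
print(len(vectors), len(vectors[0]))  # 2 vectors, one per input chunk

embed_documents() is the Embeddings interface method that FAISS.from_texts and Chroma.from_texts call internally, so this is a cheap way to catch a bad model id early.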
+ def get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators=None):
      if split_strategy == 'token':
+         return TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)
+     elif split_strategy == 'recursive':
+         return RecursiveCharacterTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=overlap_size,
+             separators=custom_separators or ["\n\n", "\n", " ", ""]
+         )
      else:
+         raise ValueError(f"Unsupported split strategy: {split_strategy}")
+
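The chunk_size/chunk_overlap semantics shared by both splitters can be pictured as a character window; this standalone sketch (plain Python, not LangChain) shows why overlap must stay below chunk_size:

# Simplified character-window model of chunking with overlap (illustrative only).
def sliding_chunks(text: str, chunk_size: int, overlap: int) -> list:
    assert 0 <= overlap < chunk_size, "overlap must be smaller than chunk_size"
    step = chunk_size - overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

print(sliding_chunks("abcdefghij", chunk_size=4, overlap=2))
# ['abcd', 'cdef', 'efgh', 'ghij', 'ij']: each chunk repeats the previous chunk's last 2 chars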
+ def get_vector_store(store_type, texts, embedding_model):
+     if store_type == 'FAISS':
+         return FAISS.from_texts(texts, embedding_model)
+     elif store_type == 'Chroma':
+         return Chroma.from_texts(texts, embedding_model)
+     else:
+         raise ValueError(f"Unsupported vector store type: {store_type}")
+
+ def get_retriever(vector_store, search_type, search_kwargs=None):
+     if search_type == 'similarity':
+         return vector_store.as_retriever(search_type="similarity", search_kwargs=search_kwargs)
+     elif search_type == 'mmr':
+         return vector_store.as_retriever(search_type="mmr", search_kwargs=search_kwargs)
+     else:
+         raise ValueError(f"Unsupported search type: {search_type}")
+
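Since both branches of get_retriever only forward the string they matched, the function could collapse into a whitelist plus a single as_retriever() call, which is the standard LangChain vector-store API (a suggested simplification, not part of the commit). Of the new imports, VectorStoreRetriever, ContextualCompressionRetriever, LLMChainExtractor, and OpenAI are never used in this file.

# Suggested equivalent: validate, then pass search_type through verbatim.
def get_retriever_compact(vector_store, search_type, search_kwargs=None):
    if search_type not in ("similarity", "mmr"):
        raise ValueError(f"Unsupported search type: {search_type}")
    return vector_store.as_retriever(search_type=search_type, search_kwargs=search_kwargs or {})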
+ def process_files(file_path, model_type, model_name, split_strategy, chunk_size, overlap_size, custom_separators):
+     # File processing
+     if file_path:
+         text = FileHandler.extract_text(file_path)
+     else:
+         text = ""
+         for file in os.listdir(FILES_DIR):
+             file_path = os.path.join(FILES_DIR, file)
+             text += FileHandler.extract_text(file_path)
+
+     # Split text into chunks
+     text_splitter = get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators)
+     chunks = text_splitter.split_text(text)
+
+     # Get embedding model
+     embedding_model = get_embedding_model(model_type, model_name)
+
+     return chunks, embedding_model
+
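process_files drops provenance: FAISS.from_texts and Chroma.from_texts attach no metadata, so format_results below will always report Source as "Unknown". A hedged sketch of threading per-file metadata through (the helper name is an assumption; both from_texts variants accept a parallel metadatas list):

# Hypothetical variant keeping a {"source": filename} record per chunk.
def chunks_with_sources(split_strategy, chunk_size, overlap_size, custom_separators=None):
    splitter = get_text_splitter(split_strategy, chunk_size, overlap_size, custom_separators)
    texts, metadatas = [], []
    for name in os.listdir(FILES_DIR):
        for chunk in splitter.split_text(FileHandler.extract_text(os.path.join(FILES_DIR, name))):
            texts.append(chunk)
            metadatas.append({"source": name})
    return texts, metadatas
# usage: FAISS.from_texts(texts, embedding_model, metadatas=metadatas)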
+ def search_embeddings(chunks, embedding_model, vector_store_type, search_type, query, top_k):
+     # Create vector store
+     vector_store = get_vector_store(vector_store_type, chunks, embedding_model)
+
+     # Get retriever
+     retriever = get_retriever(vector_store, search_type, {"k": top_k})
+
+     # Perform search
+     start_time = time.time()
+     results = retriever.get_relevant_documents(query)
+     end_time = time.time()
+
+     return results, end_time - start_time
+
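Note the timer brackets only the retrieval call; embedding time is spent earlier, inside get_vector_store. To exercise this whole path without downloading a model, LangChain's FakeEmbeddings stand-in can serve as a test double (a sketch assuming the same langchain import layout as this commit, with faiss installed; the vectors are random, so the ranking is meaningless):

# Hypothetical offline smoke test for store construction, retrieval, and timing.
from langchain.embeddings import FakeEmbeddings

docs, seconds = search_embeddings(
    chunks=["alpha", "beta", "gamma"],
    embedding_model=FakeEmbeddings(size=32),
    vector_store_type="FAISS",
    search_type="similarity",
    query="alpha",
    top_k=2,
)
print(len(docs), f"{seconds:.3f}s")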
+ def calculate_statistics(results, search_time):
+     return {
+         "num_results": len(results),
+         "avg_content_length": sum(len(doc.page_content) for doc in results) / len(results),
+         "search_time": search_time
+     }
+
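calculate_statistics divides by len(results), so an empty result set raises ZeroDivisionError. A guarded variant (a suggestion, not part of the commit):

# Hypothetical guard: report an average of 0 when nothing was retrieved.
def calculate_statistics_safe(results, search_time):
    n = len(results)
    return {
        "num_results": n,
        "avg_content_length": sum(len(d.page_content) for d in results) / n if n else 0,
        "search_time": search_time,
    }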
+ def format_results(results, stats):
+     df = pd.DataFrame([
+         {
+             "Content": doc.page_content,
+             "Source": doc.metadata.get("source", "Unknown"),
+             "Relevance Score": doc.metadata.get("score", "N/A")
+         } for doc in results
+     ])
+
+     formatted_stats = pd.DataFrame([stats])
+
+     return gr.DataFrame(df), gr.DataFrame(formatted_stats)
+
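Retrievers return plain Documents with nothing under metadata["score"], so the Relevance Score column will always read "N/A". Getting real scores means querying the store directly; a hedged sketch for the FAISS case (similarity_search_with_score returns (Document, distance) pairs, and for FAISS lower distance is better):

# Hypothetical score-preserving query, bypassing the retriever.
def search_with_scores(vector_store, query, top_k):
    pairs = vector_store.similarity_search_with_score(query, k=top_k)
    for doc, score in pairs:
        doc.metadata["score"] = float(score)
    return [doc for doc, _ in pairs]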
+ def compare_embeddings(file, query, model_types, model_names, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k):
+     all_results = []
+     all_stats = []
+
+     for model_type, model_name in zip(model_types, model_names):
+         chunks, embedding_model = process_files(
+             file.name if file else None,
+             model_type,
+             model_name,
+             split_strategy,
+             chunk_size,
+             overlap_size,
+             custom_separators.split(',') if custom_separators else None
+         )
+
+         results, search_time = search_embeddings(
+             chunks,
+             embedding_model,
+             vector_store_type,
+             search_type,
+             query,
+             top_k
+         )
+
+         stats = calculate_statistics(results, search_time)
+         stats["model"] = f"{model_type} - {model_name}"
+
+         all_results.append(results)
+         all_stats.append(stats)
+
+     return [format_results(results, stats) for results, stats in zip(all_results, all_stats)]
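Caution on the loop above: model_types and model_names come from two independent CheckboxGroups, so zip() pairs them purely by position; selecting two types and one name silently drops work, and 'e5-base' can end up paired with 'OpenAI'. A more robust pairing resolves each selected name through the registry (a sketch, not part of the commit):

# Hypothetical pairing: keep a name only when its owning type was also selected.
def resolve_selection(model_types, model_names):
    pairs = []
    for name in model_names:
        for model_type, models in MODELS.items():
            if name in models and model_type in model_types:
                pairs.append((model_type, name))
    return pairs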
  # Gradio interface
  iface = gr.Interface(
+     fn=compare_embeddings,
      inputs=[
+         gr.File(label="Upload File (Optional)"),
          gr.Textbox(label="Search Query"),
+         gr.CheckboxGroup(choices=list(MODELS.keys()), label="Embedding Model Types", value=["HuggingFace"]),
+         gr.CheckboxGroup(choices=[model for models in MODELS.values() for model in models], label="Embedding Models", value=["e5-base"]),
+         gr.Radio(choices=["token", "recursive"], label="Split Strategy", value="recursive"),
+         gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
+         gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
+         gr.Textbox(label="Custom Split Separators (comma-separated, optional)"),
+         gr.Radio(choices=["FAISS", "Chroma"], label="Vector Store Type", value="FAISS"),
+         gr.Radio(choices=["similarity", "mmr"], label="Search Type", value="similarity"),
+         gr.Slider(1, 10, step=1, value=5, label="Top K")
      ],
+     outputs=[gr.DataFrame(label="Results"), gr.DataFrame(label="Statistics")] * len(MODELS),
+     title="Embedding Comparison Tool",
+     description="Compare different embedding models and retrieval strategies"
  )

+ iface.launch()
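One wiring note: the interface declares 2 * len(MODELS) output components, while compare_embeddings returns one (results, stats) pair per selected model, so the counts only match when every model type is selected. Gradio expects a flat sequence with exactly one value per declared output; a hedged adapter (the function name and empty-frame padding are assumptions):

# Hypothetical adapter: flatten the per-model pairs and pad with empty frames
# so the value count always equals the declared output count.
import pandas as pd

def compare_embeddings_flat(*args):
    pairs = compare_embeddings(*args)              # [(results_df, stats_df), ...]
    flat = [frame for pair in pairs for frame in pair]
    expected = 2 * len(MODELS)
    flat += [gr.DataFrame(pd.DataFrame())] * (expected - len(flat))
    return flat
# then: gr.Interface(fn=compare_embeddings_flat, ...)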