seanpedrickcase committed
Commit: a95ef9f
1 Parent(s): 2393537

General code improvements and refinements.

Dockerfile CHANGED
@@ -58,7 +58,5 @@ WORKDIR $HOME/app
 
 # Copy the current directory contents into the container at $HOME/app setting the owner to the user
 COPY --chown=user . $HOME/app
-#COPY . $HOME/app
-
 
 CMD ["python", "app.py"]
app.py CHANGED
@@ -7,7 +7,7 @@ PandasDataFrame = Type[pd.DataFrame]
 
 from search_funcs.bm25_functions import prepare_bm25_input_data, prepare_bm25, bm25_search
 from search_funcs.semantic_ingest_functions import csv_excel_text_to_docs
-from search_funcs.semantic_functions import docs_to_bge_embed_np_array, bge_simple_retrieval
+from search_funcs.semantic_functions import docs_to_bge_embed_np_array, bge_semantic_search
 from search_funcs.helper_functions import display_info, initial_data_load, put_columns_in_join_df, get_temp_folder_path, empty_folder, get_connection_params, output_folder
 from search_funcs.spacy_search_funcs import spacy_fuzzy_search
 from search_funcs.aws_functions import load_data_from_aws
@@ -17,39 +17,33 @@ temp_folder_path = get_temp_folder_path()
 empty_folder(temp_folder_path)
 
 ## Gradio app - BM25 search
-block = gr.Blocks(theme = gr.themes.Base()) # , css="theme.css"
+app = gr.Blocks(theme = gr.themes.Base()) # , css="theme.css"
 
-
-with block:
+with app:
     print("Please don't close this window! Open the below link in the web browser of your choice.")
 
-    ingest_text = gr.State()
-    ingest_metadata = gr.State()
-    ingest_docs = gr.State()
-    vectorstore_state = gr.State() # globals()["vectorstore"]
-    embeddings_state = gr.State(np.array([])) # globals()["embeddings"]
-    search_index_state = gr.State()
-    tokenised_state = gr.State()
-    bm25_search_object_state = gr.State()
-
-    k_val = gr.State(9999)
-    out_passages = gr.State(9999)
-    vec_weight = gr.State(1)
-
-    corpus_state = gr.State()
-    keyword_data_list_state = gr.State([])
-    join_data_state = gr.State(pd.DataFrame())
-    output_file_state = gr.State([])
-
-    orig_keyword_data_state = gr.State(pd.DataFrame())
-    keyword_data_state = gr.State(pd.DataFrame())
+    # BM25 state objects
+    orig_keyword_data_state = gr.State(pd.DataFrame()) # Original data that is not changed
+    prepared_keyword_data_state = gr.State(pd.DataFrame()) # Data frame that contains modified data
+    tokenised_prepared_keyword_data_state = gr.State([]) # Data that has been prepared for search (tokenised)
+    bm25_search_index_state = gr.State()
 
-    orig_semantic_data_state = gr.State(pd.DataFrame())
-    semantic_data_state = gr.State(pd.DataFrame())
+    # Semantic search state objects
+    orig_semantic_data_state = gr.State(pd.DataFrame())
+    semantic_data_state = gr.State(pd.DataFrame())
+    semantic_input_document_format = gr.State([])
+    embeddings_state = gr.State(np.array([]))
+    semantic_k_val = gr.Number(9999, visible=False)
 
+    # State objects for the app in general
     session_hash_state = gr.State("")
     s3_output_folder_state = gr.State("")
+    join_data_state = gr.State(pd.DataFrame())
+    output_file_state = gr.Dropdown([], visible=False, allow_custom_value=True)
 
+    # Informational state objects
     in_k1_info = gr.State("""k1: Constant used for influencing the term frequency saturation. After saturation is reached, additional
 presence for the term adds significantly less additional score. According to [1]_, experiments suggest
 that 1.2 < k1 < 2 yields reasonably good results, although the optimal value depends on factors such as
@@ -167,7 +161,7 @@ depends on factors such as the type of documents or queries. Information taken f
     out_aws_data_message = gr.Textbox(label="AWS data load progress")
 
     # Changing search parameters button
-    in_search_param_button.click(fn=prepare_bm25, inputs=[corpus_state, in_bm25_file, in_bm25_column, search_index_state, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message])
+    in_search_param_button.click(fn=prepare_bm25, inputs=[tokenised_prepared_keyword_data_state, in_bm25_file, in_bm25_column, bm25_search_index_state, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message])
 
     # ---
     in_k1_button.click(display_info, inputs=in_k1_info)
@@ -178,43 +172,41 @@ depends on factors such as the type of documents or queries. Information taken f
     ### Loading AWS data ###
    load_aws_keyword_data_button.click(fn=load_data_from_aws, inputs=[in_aws_keyword_file, aws_password_box], outputs=[in_bm25_file, out_aws_data_message])
    load_aws_semantic_data_button.click(fn=load_data_from_aws, inputs=[in_aws_semantic_file, aws_password_box], outputs=[in_semantic_file, out_aws_data_message])
-
 
     ### BM25 SEARCH ###
     # Update dropdowns upon initial file load
-    in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column, keyword_data_state, orig_keyword_data_state, search_index_state, embeddings_state, tokenised_state, load_finished_message, current_source], api_name="initial_load")
+    in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column, prepared_keyword_data_state, orig_keyword_data_state, bm25_search_index_state, embeddings_state, tokenised_prepared_keyword_data_state, load_finished_message, current_source], api_name="initial_load")
     in_join_file.change(put_columns_in_join_df, inputs=[in_join_file], outputs=[in_join_column, join_data_state, in_join_message])
 
     # Load in BM25 data
-    load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column, keyword_data_state, tokenised_state, in_clean_data, return_intermediate_files], outputs=[corpus_state, load_finished_message, keyword_data_state, output_file, output_file, keyword_data_list_state, in_bm25_column], api_name="load_keyword").\
-    then(fn=prepare_bm25, inputs=[corpus_state, in_bm25_file, in_bm25_column, search_index_state, in_clean_data, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message, output_file, bm25_search_object_state], api_name="prepare_keyword")
-
+    load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column, prepared_keyword_data_state, tokenised_prepared_keyword_data_state, in_clean_data, return_intermediate_files], outputs=[tokenised_prepared_keyword_data_state, load_finished_message, prepared_keyword_data_state, output_file, output_file, in_bm25_column], api_name="load_keyword").\
+    then(fn=prepare_bm25, inputs=[tokenised_prepared_keyword_data_state, in_bm25_file, in_bm25_column, bm25_search_index_state, in_clean_data, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message, output_file, bm25_search_index_state, tokenised_prepared_keyword_data_state], api_name="prepare_keyword")
 
     # BM25 search functions on click or enter
-    keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_object_state, in_join_column, search_df_join_column], outputs=[output_single_text, output_file], api_name="keyword_search")
-    keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_object_state, in_join_column, search_df_join_column], outputs=[output_single_text, output_file])
+    keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_index_state, tokenised_prepared_keyword_data_state, in_join_column, search_df_join_column, in_k1, in_b, in_alpha], outputs=[output_single_text, output_file], api_name="keyword_search")
+    keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_index_state, tokenised_prepared_keyword_data_state, in_join_column, search_df_join_column, in_k1, in_b, in_alpha], outputs=[output_single_text, output_file])
 
     # Fuzzy search functions on click
-    fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query, keyword_data_list_state, keyword_data_state, in_bm25_column, join_data_state, search_df_join_column, in_join_column, no_spelling_mistakes], outputs=[output_single_text, output_file], api_name="fuzzy_search")
+    fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query, tokenised_prepared_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, search_df_join_column, in_join_column, no_spelling_mistakes], outputs=[output_single_text, output_file], api_name="fuzzy_search")
 
     ### SEMANTIC SEARCH ###
 
     # Load in a csv/excel file for semantic search
-    in_semantic_file.change(initial_data_load, inputs=[in_semantic_file], outputs=[in_semantic_column, search_df_join_column, semantic_data_state, orig_semantic_data_state, search_index_state, embeddings_state, tokenised_state, semantic_load_progress, current_source_semantic])
+    in_semantic_file.change(initial_data_load, inputs=[in_semantic_file], outputs=[in_semantic_column, search_df_join_column, semantic_data_state, orig_semantic_data_state, bm25_search_index_state, embeddings_state, tokenised_prepared_keyword_data_state, semantic_load_progress, current_source_semantic])
     load_semantic_data_button.click(
-        csv_excel_text_to_docs, inputs=[semantic_data_state, in_semantic_file, in_semantic_column, in_clean_data, return_intermediate_files], outputs=[ingest_docs, semantic_load_progress, output_file_state]).\
-        then(docs_to_bge_embed_np_array, inputs=[ingest_docs, in_semantic_file, embeddings_state, output_file_state, in_clean_data, return_intermediate_files, embedding_super_compress], outputs=[semantic_load_progress, vectorstore_state, semantic_output_file, output_file_state])
+        csv_excel_text_to_docs, inputs=[semantic_data_state, in_semantic_file, in_semantic_column, in_clean_data, return_intermediate_files], outputs=[semantic_input_document_format, semantic_load_progress, output_file_state]).\
+        then(docs_to_bge_embed_np_array, inputs=[semantic_input_document_format, in_semantic_file, embeddings_state, output_file_state, in_clean_data, return_intermediate_files, embedding_super_compress], outputs=[semantic_load_progress, embeddings_state, semantic_output_file, output_file_state])
 
     # Semantic search query
-    semantic_submit.click(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file], api_name="semantic_search")
-    semantic_query.submit(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file])
+    semantic_submit.click(bge_semantic_search, inputs=[semantic_query, embeddings_state, semantic_input_document_format, semantic_k_val, semantic_min_distance, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file], api_name="semantic_search")
+    semantic_query.submit(bge_semantic_search, inputs=[semantic_query, embeddings_state, semantic_input_document_format, semantic_k_val, semantic_min_distance, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file])
 
-    block.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state])
+    app.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state])
 
 # Launch the Gradio app
 if __name__ == "__main__":
-    block.queue().launch(show_error=True) # root_path="/data-text-search" # server_name="0.0.0.0",
+    app.queue().launch(show_error=True) # root_path="/data-text-search" # server_name="0.0.0.0",
 
 # Running on local server with https: https://discuss.huggingface.co/t/how-to-run-gradio-with-0-0-0-0-and-https/38003 or https://dev.to/rajshirolkar/fastapi-over-https-for-development-on-windows-2p7d # Need to download OpenSSL and create own keys
-# block.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
+# app.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
 #                     ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
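
The rewiring above leans entirely on Gradio's session-state pattern. For readers unfamiliar with it, here is a minimal, self-contained sketch of how gr.State values thread through a chained .click().then() pipeline, the same pattern app.py uses for load_bm25_data_button and the semantic search events. This is not code from this repository: the handler and component names are invented for illustration.

import gradio as gr

# Hypothetical handlers, standing in for prepare_bm25_input_data / prepare_bm25 / bm25_search.
def prepare(texts):
    # Pretend preparation: split the stored texts into tokens and keep them in state.
    tokenised = [t.split() for t in texts]
    return tokenised, "Data prepared"

def search(tokenised, query):
    # Pretend search: naive term-overlap count per document.
    hits = [(i, sum(tok == query for tok in doc)) for i, doc in enumerate(tokenised)]
    best = max(hits, key=lambda h: h[1])
    return f"Best match: document {best[0]} (score {best[1]})"

app = gr.Blocks()

with app:
    # gr.State holds per-session values that are never rendered as UI.
    texts_state = gr.State(["the cat sat", "the dog ran"])
    tokenised_state = gr.State([])

    query = gr.Textbox(label="Query")
    prepare_btn = gr.Button("Prepare and search")
    message = gr.Textbox(label="Status")
    result = gr.Textbox(label="Result")

    # .click().then() chains handlers, passing the updated state from one step
    # to the next, mirroring load_bm25_data_button.click(...).then(...) above.
    prepare_btn.click(fn=prepare, inputs=[texts_state], outputs=[tokenised_state, message]).\
        then(fn=search, inputs=[tokenised_state, query], outputs=[result])

if __name__ == "__main__":
    app.queue().launch()

The key design point, visible in the commit as well, is that renaming a state object (e.g. corpus_state to tokenised_prepared_keyword_data_state) must be applied consistently to every inputs/outputs list that references it.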
requirements.txt CHANGED
@@ -1,12 +1,11 @@
 pandas==2.2.2
 polars==0.20.3
 pyarrow==14.0.2
-openpyxl==3.1.2
+openpyxl==3.1.3
 torch==2.3.1
-transformers==4.41.2
 spacy
 en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1.tar.gz
 gradio
 sentence_transformers==3.0.1
-lxml==5.1.0
+lxml==5.2.2
 boto3==1.34.103
requirements_gpu.txt CHANGED
@@ -1,11 +1,11 @@
 pandas==2.2.2
 polars==0.20.3
 pyarrow==14.0.2
-openpyxl==3.1.2
+openpyxl==3.1.3
 torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121
 spacy
 en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1.tar.gz
 gradio
-sentence_transformers==2.3.1
-lxml==5.1.0
+sentence_transformers==3.0.1
+lxml==5.2.2
 boto3==1.34.103
search_funcs/bm25_functions.py CHANGED
@@ -8,6 +8,7 @@ import time
 import pandas as pd
 from numpy import inf
 import gradio as gr
+from typing import List
 
 from datetime import datetime
 
@@ -165,7 +166,7 @@ class BM25:
         return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]
 
 
-    def get_top_n_with_score(self, query, documents, n=5):
+    def get_top_n_with_score(self, query: str, documents: List[str], n=5):
         """
         Retrieve the top n documents for the query along with their scores.
 
@@ -229,15 +230,47 @@ class BM25:
         with open(f"{output_folder}{filename}.pkl", "rb") as fsave:
             return pickle.load(fsave)
 
-# These following functions are my own work
-
-def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, clean="No", return_intermediate_files = "No", progress=gr.Progress(track_tqdm=True)):
-    #print(in_file)
+def prepare_bm25_input_data(
+    in_file: list,
+    text_column: str,
+    data_state: pd.DataFrame,
+    tokenised_state: list,
+    clean: str = "No",
+    return_intermediate_files: str = "No",
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Prepare BM25 input data by loading, cleaning, and tokenizing the text data.
+
+    Parameters
+    ----------
+    in_file: list
+        List of input files to be processed.
+    text_column: str
+        The name of the text column in the data file to search.
+    data_state: pd.DataFrame
+        The current state of the data.
+    tokenised_state: list
+        The current state of the tokenized data.
+    clean: str, optional
+        Whether to clean the text data (default is "No").
+    return_intermediate_files: str, optional
+        Whether to return intermediate processing files (default is "No").
+    progress: gr.Progress, optional
+        Progress tracker for the function (default is gr.Progress(track_tqdm=True)).
+
+    Returns
+    -------
+    tuple
+        A tuple containing the prepared search text list, a message, the updated data state,
+        the tokenized data, the search index, and a dropdown component for the text column.
+    """
     ensure_output_folder_exists(output_folder)
 
     if not in_file:
         print("No input file found. Please load in at least one file.")
-        return None, "No input file found. Please load in at least one file.", data_state, None, None, [], gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
+        return None, "No input file found. Please load in at least one file.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     progress(0, desc = "Loading in data")
     file_list = [string.name for string in in_file]
@@ -247,25 +280,24 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
     data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower() and "gz" not in string.lower()]
 
     if not data_file_names:
-        return None, "Please load in at least one csv/Excel/parquet data file.", data_state, None, None, [], gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
+        return None, "Please load in at least one csv/Excel/parquet data file.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     if not text_column:
-        return None, "Please enter a column name to search.", data_state, None, None, [], gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
+        return None, "Please enter a column name to search.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     data_file_name = data_file_names[0]
 
     df = data_state #read_file(data_file_name)
-    data_file_out_name = get_file_path_end_with_ext(data_file_name)
     data_file_out_name_no_ext = get_file_path_end(data_file_name)
 
-    ## Load in pre-tokenised corpus if exists
-    tokenised_df = pd.DataFrame()
-
-    tokenised_file_names = [string for string in file_list if "tokenised" in string.lower()]
+    ## Load in pre-tokenised prepared_search_text_list if it exists
     search_index_file_names = [string for string in file_list if "gz" in string.lower()]
 
-    print("Dataframe columns:", df.columns)
-
+    # Set all search text to lower case
     df[text_column] = df[text_column].astype(str).str.lower()
 
     if "copy_of_case_note_id" in df.columns:
@@ -273,10 +305,10 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
         df.loc[~df["copy_of_case_note_id"].isna(), text_column] = ""
 
     if search_index_file_names:
-        corpus = list(df[text_column])
+        prepared_search_text_list = list(df[text_column])
         message = "Tokenisation skipped - loading search index from file."
         print(message)
-        return corpus, message, df, None, None, [], gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
+        return prepared_search_text_list, message, df, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
 
     if clean == "Yes":
@@ -285,11 +317,11 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
         print("Starting data clean.")
 
         #df = df.drop_duplicates(text_column)
-        df_list = list(df[text_column])
-        df_list = initial_clean(df_list)
+        prepared_text_as_list = list(df[text_column])
+        prepared_text_as_list = initial_clean(prepared_text_as_list)
 
         # Save to file if you have cleaned the data
-        out_file_name, text_column, df = save_prepared_bm25_data(data_file_name, df_list, df, text_column)
+        out_file_name, text_column, df = save_prepared_bm25_data(data_file_name, prepared_text_as_list, df, text_column)
 
         clean_toc = time.perf_counter()
         clean_time_out = f"Cleaning the text took {clean_toc - clean_tic:0.1f} seconds."
@@ -297,7 +329,7 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
 
     else:
         # Don't clean or save file to disk
-        df_list = list(df[text_column])
+        prepared_text_as_list = list(df[text_column])
         print("No data cleaning performed")
         out_file_name = None
 
@@ -305,24 +337,27 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
 
     progress(0.4, desc = "Tokenising text")
 
+    print("Tokenised state:", tokenised_state)
+
     if tokenised_state:
-        tokenised_df = tokenised_state
-        corpus = tokenised_df.iloc[:,0].tolist()
+        prepared_search_text_list = tokenised_state.iloc[:,0].tolist()
         print("Tokenised data loaded from file")
 
     else:
         tokeniser_tic = time.perf_counter()
-        corpus = []
+        prepared_search_text_list = []
         batch_size = 256
-        for doc in tokenizer.pipe(progress.tqdm(df_list, desc = "Tokenising text", unit = "rows"), batch_size=batch_size):
-            corpus.append([token.text for token in doc])
+        for doc in tokenizer.pipe(progress.tqdm(prepared_text_as_list, desc = "Tokenising text", unit = "rows"), batch_size=batch_size):
+            prepared_search_text_list.append([token.text for token in doc])
 
         tokeniser_toc = time.perf_counter()
         tokenizer_time_out = f"Tokenising the text took {tokeniser_toc - tokeniser_tic:0.1f} seconds."
         print(tokenizer_time_out)
 
-    if len(df_list) >= 20:
+    if len(prepared_text_as_list) >= 20:
         message = "Data loaded"
     else:
         message = "Data loaded. Warning: dataset may be too short to get consistent search results."
@@ -334,13 +369,29 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
         else:
             tokenised_data_file_name = output_folder + data_file_out_name_no_ext + "_tokenised.parquet"
 
-        pd.DataFrame(data={"Corpus":corpus}).to_parquet(tokenised_data_file_name)
+        pd.DataFrame(data={"prepared_search_text_list":prepared_search_text_list}).to_parquet(tokenised_data_file_name)
 
-        return corpus, message, df, out_file_name, tokenised_data_file_name, df_list, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
+        return prepared_search_text_list, message, df, out_file_name, tokenised_data_file_name, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
-    return corpus, message, df, out_file_name, None, df_list, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
+    return prepared_search_text_list, message, df, out_file_name, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
-def save_prepared_bm25_data(in_file_name, prepared_text_list, in_df, in_bm25_column, progress=gr.Progress(track_tqdm=True)):
+def save_prepared_bm25_data(in_file_name: str, prepared_text_list: list, in_df: pd.DataFrame, in_bm25_column: str, progress: gr.Progress = gr.Progress(track_tqdm=True)) -> tuple:
+    """
+    Save the prepared BM25 data to a file.
+
+    This function ensures the output folder exists, checks if the length of the prepared text list matches the input dataframe,
+    and saves the prepared data to a file in the specified format. The original column in the input dataframe is dropped to reduce file size.
+
+    Parameters:
+    - in_file_name (str): The name of the input file.
+    - prepared_text_list (list): The list of prepared text.
+    - in_df (pd.DataFrame): The input dataframe.
+    - in_bm25_column (str): The name of the column to be processed.
+    - progress (gr.Progress, optional): The progress tracker for the operation.
+
+    Returns:
+    - tuple: A tuple containing the file name, new text column name, and the prepared dataframe.
+    """
 
     ensure_output_folder_exists(output_folder)
 
@@ -368,26 +419,54 @@ def save_prepared_bm25_data(in_file_name, prepared_text_list, in_df, in_bm25_col
 
     return file_name, new_text_column, prepared_df
 
-def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_intermediate_files, k1=1.5, b = 0.75, alpha=-5, progress=gr.Progress(track_tqdm=True)):
-    #bm25.save("saved_df_bm25")
-    #bm25 = BM25.load(re.sub(r'\.pkl$', '', file_in.name))
-
+def prepare_bm25(
+    prepared_search_text_list: List[str],
+    in_file: List[gr.File],
+    text_column: str,
+    search_index: BM25,
+    clean: str,
+    return_intermediate_files: str,
+    k1: float = 1.5,
+    b: float = 0.75,
+    alpha: float = -5,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Prepare the BM25 search index.
+
+    This function prepares the BM25 search index from the provided text list and input file. It ensures the necessary
+    files and columns are present, processes the data, and optionally saves intermediate files.
+
+    Parameters:
+    - prepared_search_text_list (List[str]): The list of prepared search text.
+    - in_file (List[gr.File]): The list of input files.
+    - text_column (str): The name of the column to search.
+    - search_index (BM25): The BM25 search index.
+    - clean (str): Indicates whether to clean the data.
+    - return_intermediate_files (str): Indicates whether to return intermediate files.
+    - k1 (float, optional): The k1 parameter for BM25. Default is 1.5.
+    - b (float, optional): The b parameter for BM25. Default is 0.75.
+    - alpha (float, optional): The alpha parameter for BM25. Default is -5.
+    - progress (gr.Progress, optional): The progress tracker for the operation.
+
+    Returns:
+    - tuple: A tuple containing the output message, BM25 search index, and other relevant information.
+    """
 
     if not in_file:
         out_message = "No input file found. Please load in at least one file."
         print(out_message)
-        return out_message, None
+        return out_message, None, None, None
 
-    if not corpus:
+    if not prepared_search_text_list:
         out_message = "No data file found. Please load in at least one csv/Excel/Parquet file."
         print(out_message)
-        return out_message, None
+        return out_message, None, None, None
 
     if not text_column:
         out_message = "Please enter a column name to search."
         print(out_message)
-        return out_message, None
+        return out_message, None, None, None
 
     file_list = [string.name for string in in_file]
 
@@ -397,36 +476,23 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
     data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower() and "gz" not in string.lower()]
 
     if not data_file_names:
-        return "Please load in at least one csv/Excel/parquet data file.", None
+        return "Please load in at least one csv/Excel/parquet data file.", None, None, None
 
     data_file_name = data_file_names[0]
     data_file_out_name = get_file_path_end_with_ext(data_file_name)
     data_file_name_no_ext = get_file_path_end(data_file_name)
 
-    # Check if there is a search index file already
-    #index_file_names = [string for string in file_list if "gz" in string.lower()]
-
     progress(0.6, desc = "Preparing search index")
 
-    #if index_file_names:
     if search_index:
-        #index_file_name = index_file_names[0]
-
-        #print(index_file_name)
-
-        bm25_load = search_index
-
-        #index_file_out_name = get_file_path_end_with_ext(index_file_name)
-        #index_file_name_no_ext = get_file_path_end(index_file_name)
-
+        bm25 = search_index
     else:
-        print("Preparing BM25 corpus")
+        print("Preparing BM25 search corpus")
 
-        bm25_load = BM25(corpus, k1=k1, b=b, alpha=alpha)
+        bm25 = BM25(prepared_search_text_list, k1=k1, b=b, alpha=alpha)
 
-    global bm25
-    bm25 = bm25_load
 
     if return_intermediate_files == "Yes":
         print("Saving search index file")
@@ -451,7 +517,7 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
 
     print(message)
 
-    return message, None, bm25
+    return message, None, bm25, prepared_search_text_list
 
 def convert_bm25_query_to_tokens(free_text_query, clean="No"):
     '''
@@ -474,9 +540,75 @@ def convert_bm25_query_to_tokens(free_text_query, clean="No"):
 
     return out_query
 
-def bm25_search(free_text_query, in_no_search_results, original_data, searched_data, text_column, in_join_file, clean, bm25, in_join_column = "", search_df_join_column = "", progress=gr.Progress(track_tqdm=True)):
+def bm25_search(
+    free_text_query: str,
+    in_no_search_results: int,
+    original_data: pd.DataFrame,
+    searched_data: pd.DataFrame,
+    text_column: str,
+    in_join_file: pd.DataFrame,
+    clean: str,
+    bm25: BM25,
+    prepared_search_text_list_state: list,
+    in_join_column: str = "",
+    search_df_join_column: str = "",
+    k1: float = 1.5,
+    b: float = 0.75,
+    alpha: float = -5,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Perform a BM25 search on the provided text data.
+
+    Parameters
+    ----------
+    free_text_query : str
+        The query text to search for.
+    in_no_search_results : int
+        The number of search results to return.
+    original_data : pd.DataFrame
+        The original data containing the text to be searched.
+    searched_data : pd.DataFrame
+        The data that has been prepared for searching.
+    text_column : str
+        The name of the column in the data to search.
+    in_join_file : pd.DataFrame
+        The dataframe to join the search results with.
+    clean : str
+        Whether to clean the text data.
+    bm25 : BM25
+        The BM25 object used for searching.
+    prepared_search_text_list_state : list
+        The state of the prepared search text list.
+    in_join_column : str, optional
+        The column to join on in the input file (default is "").
+    search_df_join_column : str, optional
+        The column to join on in the search dataframe (default is "").
+    k1 : float, optional
+        The k1 parameter for BM25 (default is 1.5).
+    b : float, optional
+        The b parameter for BM25 (default is 0.75).
+    alpha : float, optional
+        The alpha parameter for BM25 (default is -5).
+    progress : gr.Progress, optional
+        Progress tracker for the function (default is gr.Progress(track_tqdm=True)).
+
+    Returns
+    -------
+    tuple
+        A tuple containing the search results summary text and the search results output file (if any).
+    """
 
     progress(0, desc = "Conducting keyword search")
+
+    print("in_join_file at start of bm25_search:", in_join_file)
+
+    # Rebuild the search index if it has not already been prepared
+    if not bm25:
+        print("Preparing BM25 search corpus")
+
+        bm25 = BM25(prepared_search_text_list_state, k1=k1, b=b, alpha=alpha)
 
     # Prepare query
     if (clean == "Yes") | (text_column.endswith("_cleaned")):
@@ -484,8 +616,6 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
     else:
         token_query = convert_bm25_query_to_tokens(free_text_query, clean="No")
 
-    #print(token_query)
-
     # Perform search
     print("Searching")
 
@@ -504,7 +634,6 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
 
     # Join scores onto searched data
     results_df_out = results_df[['index', 'search_text', 'search_score_abs']].merge(searched_data,left_on="index", right_index=True, how="left", suffixes = ("", "_y")).drop("index_y", axis=1, errors="ignore")
-
 
 
     # Join on data from duplicate case notes
@@ -516,33 +645,27 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
         print("Clean is yes")
         orig_text_column = text_column.replace("_cleaned", "")
 
-        #print(orig_text_column)
-        #print(original_data.columns)
-
         original_data["original_note_id"] = original_data["copy_of_case_note_id"]
         original_data["original_note_id"] = original_data["original_note_id"].combine_first(original_data["note_id"])
 
        results_df_out = results_df_out.merge(original_data[["original_note_id", "note_id", "copy_of_case_note_id", "person_id"]],left_on="note_id", right_on="original_note_id", how="left", suffixes=("_primary", "")) # .drop(orig_text_column, axis = 1)
        results_df_out.loc[~results_df_out["copy_of_case_note_id"].isnull(), "search_text"] = ""
        results_df_out.loc[~results_df_out["copy_of_case_note_id"].isnull(), text_column] = ""
-
-        #results_df_out = pd.concat([results_df_out, original_data[~original_data["copy_of_case_note_id"].isna()][["copy_of_case_note_id", "person_id"]]])
-        # Replace NaN with an empty string
-        # results_df_out.fillna('', inplace=True)
-
 
+    print("in_join_file:", in_join_file)
+
     # Join on additional files
     if not in_join_file.empty:
         progress(0.5, desc = "Joining on additional data file")
-        join_df = in_join_file
-        join_df[in_join_column] = join_df[in_join_column].astype(str).str.replace("\.0$","", regex=True)
+        # Prepare join columns as string and remove .0 at the end of stringified numbers
+        in_join_file[in_join_column] = in_join_file[in_join_column].astype(str).str.replace("\.0$","", regex=True)
         results_df_out[search_df_join_column] = results_df_out[search_df_join_column].astype(str).str.replace("\.0$","", regex=True)
 
         # Duplicates dropped so as not to expand out dataframe
-        join_df = join_df.drop_duplicates(in_join_column)
+        in_join_file = in_join_file.drop_duplicates(in_join_column)
 
-        results_df_out = results_df_out.merge(join_df,left_on=search_df_join_column, right_on=in_join_column, how="left", suffixes=('','_y'))
+        results_df_out = results_df_out.merge(in_join_file,left_on=search_df_join_column, right_on=in_join_column, how="left", suffixes=('','_y'))
 
     # Reorder results by score, and whether there is text
     results_df_out = results_df_out.sort_values(['search_score_abs', "search_text"], ascending=False)
@@ -559,7 +682,7 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
     # Highlight found text and save to file
     results_df_out_wb = create_highlighted_excel_wb(results_df_out, free_text_query, "search_text")
     results_df_out_wb.save(results_df_name)
-    #results_df_out.to_excel(results_df_name, index= None)
+
     results_first_text = results_df_out[text_column].iloc[0]
 
     print("Returning results")
search_funcs/helper_functions.py CHANGED
@@ -9,6 +9,8 @@ import gzip
 import pickle
 import numpy as np
 
+from typing import List
+
 # Openpyxl functions for output
 from openpyxl import Workbook
 from openpyxl.cell.text import InlineFont
@@ -175,15 +177,15 @@ def read_file(filename):
 
     return file
 
-def initial_data_load(in_file):
+def initial_data_load(in_file: List[str]):
     '''
-    When file is loaded, update the column dropdown choices
+    When file is loaded, update the column dropdown choices and relevant state variables
     '''
     new_choices = []
     concat_choices = []
     index_load = None
     embed_load = np.array([])
-    tokenised_load =[]
+    tokenised_load = []
     out_message = ""
     current_source = ""
     df = pd.DataFrame()
@@ -257,7 +259,7 @@ def initial_data_load(in_file):
 
     return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df, df, index_load, embed_load, tokenised_load, out_message, current_source
 
-def put_columns_in_join_df(in_file):
+def put_columns_in_join_df(in_file: str):
     '''
     When file is loaded, update the column dropdown choices
     '''
@@ -354,7 +356,20 @@ def highlight_found_text(search_text: str, full_text: str) -> str:
 
     return "".join(pos_tokens), combined_positions
 
-def create_rich_text_cell_from_positions(full_text, combined_positions):
+def create_rich_text_cell_from_positions(full_text: str, combined_positions: list[tuple[int, int]]) -> CellRichText:
+    """
+    Create a rich text cell with highlighted positions.
+
+    This function takes the full text and a list of combined positions, and creates a rich text cell
+    with the specified positions highlighted in red.
+
+    Parameters:
+    full_text (str): The full text to be processed.
+    combined_positions (list[tuple[int, int]]): A list of tuples representing the start and end positions to be highlighted.
+
+    Returns:
+    CellRichText: The created rich text cell with highlighted positions.
+    """
     # Construct pos_tokens
     red = InlineFont(color='00FF0000')
     rich_text_cell = CellRichText()
@@ -369,7 +384,21 @@ def create_rich_text_cell_from_positions(full_text, combined_positions):
 
     return rich_text_cell
 
-def create_highlighted_excel_wb(df, search_text, column_to_highlight):
+def create_highlighted_excel_wb(df: pd.DataFrame, search_text: str, column_to_highlight: str) -> Workbook:
+    """
+    Create a new Excel workbook with highlighted search text.
+
+    This function takes a DataFrame, a search text, and a column name to highlight. It creates a new Excel workbook,
+    highlights the occurrences of the search text in the specified column, and returns the workbook.
+
+    Parameters:
+    df (pd.DataFrame): The DataFrame containing the data to be written to the Excel workbook.
+    search_text (str): The text to search for and highlight in the specified column.
+    column_to_highlight (str): The name of the column in which to highlight the search text.
+
+    Returns:
+    Workbook: The created Excel workbook with highlighted search text.
+    """
 
     # Create a new Excel workbook
     wb = Workbook()
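
The two highlighting helpers documented above rely on openpyxl's rich-text support. Here is a minimal sketch of that mechanism, assuming openpyxl 3.1 or later (where CellRichText and TextBlock live in openpyxl.cell.rich_text, alongside the InlineFont import this file already uses). The helper name, example text, and output file name are illustrative, not from this repository.

from openpyxl import Workbook
from openpyxl.cell.text import InlineFont
from openpyxl.cell.rich_text import CellRichText, TextBlock

def highlight_cell(full_text: str, start: int, end: int) -> CellRichText:
    # Plain text before the match, a red TextBlock for the match, plain text after.
    red = InlineFont(color='00FF0000')  # same ARGB red as in helper_functions.py
    parts = (full_text[:start], TextBlock(red, full_text[start:end]), full_text[end:])
    return CellRichText(*[p for p in parts if p != ""])  # drop empty string segments

wb = Workbook()
ws = wb.active
text = "the cat sat on the mat"
match_start = text.find("cat")
ws["A1"] = highlight_cell(text, match_start, match_start + len("cat"))
wb.save("highlight_demo.xlsx")

The real create_rich_text_cell_from_positions generalises this to a list of (start, end) positions, appending alternating plain and red segments into one CellRichText.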
search_funcs/semantic_functions.py CHANGED
@@ -5,11 +5,10 @@ from typing import Type
5
  import gradio as gr
6
  import numpy as np
7
  from datetime import datetime
8
- #from transformers import AutoModel, AutoTokenizer
9
- from search_funcs.helper_functions import get_file_path_end
10
- #import torch
11
- from torch import cuda, backends#, tensor, mm, utils
12
  from sentence_transformers import SentenceTransformer
 
13
 
14
  today_rev = datetime.now().strftime("%Y%m%d")
15
 
@@ -25,22 +24,6 @@ else:
25
 
26
  print("Device used is: ", torch_device)
27
 
28
- from search_funcs.helper_functions import create_highlighted_excel_wb, ensure_output_folder_exists, output_folder
29
-
30
- PandasDataFrame = Type[pd.DataFrame]
31
-
32
- # Load embeddings - Jina - deprecated
33
- # Pinning a Jina revision for security purposes: https://www.baseten.co/blog/pinning-ml-model-revisions-for-compatibility-and-security/
34
- # Save Jina model locally as described here: https://huggingface.co/jinaai/jina-embeddings-v2-base-en/discussions/29
35
- # embeddings_name = "jinaai/jina-embeddings-v2-small-en"
36
- # local_embeddings_location = "model/jina/"
37
- # revision_choice = "b811f03af3d4d7ea72a7c25c802b21fc675a5d99"
38
-
39
- # try:
40
- # embeddings_model = AutoModel.from_pretrained(local_embeddings_location, revision = revision_choice, trust_remote_code=True,local_files_only=True, device_map="auto")
41
- # except:
42
- # embeddings_model = AutoModel.from_pretrained(embeddings_name, revision = revision_choice, trust_remote_code=True, device_map="auto")
43
-
44
  # Load embeddings
45
  embeddings_name = "BAAI/bge-small-en-v1.5"
46
 
@@ -65,32 +48,53 @@ else:
65
  embeddings_model = SentenceTransformer(embeddings_name)
66
  print("Could not find local model installation. Downloading from Huggingface")
67
 
68
- def docs_to_bge_embed_np_array(docs_out, in_file, embeddings_state, output_file_state, clean, return_intermediate_files = "No", embeddings_super_compress = "No", embeddings_model = embeddings_model, progress=gr.Progress(track_tqdm=True)):
69
- '''
70
- Takes a Langchain document class and saves it into a Numpy array.
71
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  ensure_output_folder_exists(output_folder)
74
 
75
  if not in_file:
76
  out_message = "No input file found. Please load in at least one file."
77
  print(out_message)
78
- return out_message, None, None, output_file_state
79
-
80
 
81
  progress(0.6, desc = "Loading/creating embeddings")
82
 
83
  print(f"> Total split documents: {len(docs_out)}")
84
 
85
- #print(docs_out)
86
-
87
  page_contents = [doc.page_content for doc in docs_out]
88
 
89
  ## Load in pre-embedded file if exists
90
  file_list = [string.name for string in in_file]
91
 
92
- #print(file_list)
93
-
94
  embeddings_file_names = [string for string in file_list if "embedding" in string.lower()]
95
  data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower()]# and "gz" not in string.lower()]
96
  data_file_name = data_file_names[0]
@@ -98,22 +102,12 @@ def docs_to_bge_embed_np_array(docs_out, in_file, embeddings_state, output_file_
98
 
99
  out_message = "Document processing complete. Ready to search."
100
 
101
- # print("embeddings loaded: ", embeddings_out)
102
 
103
  if embeddings_state.size == 0:
104
  tic = time.perf_counter()
105
  print("Starting to embed documents.")
106
- #embeddings_list = []
107
- #for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
108
- # embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
109
-
110
-
111
-
112
- #embeddings_out = calc_bge_norm_embeddings(page_contents, embeddings_model, tokenizer)
113
 
114
  embeddings_out = embeddings_model.encode(sentences=page_contents, show_progress_bar = True, batch_size = 32, normalize_embeddings=True) # For BGE
115
- #embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
116
- #embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
117
 
118
  toc = time.perf_counter()
119
  time_out = f"The embedding took {toc - tic:0.1f} seconds"
@@ -147,31 +141,43 @@ def docs_to_bge_embed_np_array(docs_out, in_file, embeddings_state, output_file_
147
 
148
  return out_message, embeddings_out, output_file_state, output_file_state
149
 
150
- def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column, progress = gr.Progress(track_tqdm=True)):
151
-
152
- def create_docs_keep_from_df(df):
153
- dict_out = {'ids' : [df['ids']],
154
- 'documents': [df['documents']],
155
- 'metadatas': [df['metadatas']],
156
- 'distances': [round(df['distances'].astype(float), 4)],
157
- 'embeddings': None
158
- }
159
- return dict_out
160
-
161
- # Prepare the DataFrame by transposing
162
- #df_docs = df#.apply(lambda x: x.explode()).reset_index(drop=True)
163
-
164
- # Keep only documents with a certain score
165
-
166
- #print(df_docs)
167
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  docs_scores = df_docs["distances"] #.astype(float)
169
 
170
  # Only keep sources that are sufficiently relevant (i.e. similarity search score below threshold below)
171
  score_more_limit = df_docs.loc[docs_scores > vec_score_cut_off, :]
172
- #docs_keep = create_docs_keep_from_df(score_more_limit) #list(compress(docs, score_more_limit))
173
-
174
- #print(docs_keep)
175
 
176
  if score_more_limit.empty:
177
  return pd.DataFrame()
@@ -179,26 +185,17 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c
179
  # Only keep sources that are at least 100 characters long
180
  docs_len = score_more_limit["documents"].str.len() >= 100
181
 
182
- #print(docs_len)
183
-
184
  length_more_limit = score_more_limit.loc[docs_len == True, :] #pd.Series(docs_len) >= 100
185
- #docs_keep = create_docs_keep_from_df(length_more_limit) #list(compress(docs_keep, length_more_limit))
186
-
187
- #print(length_more_limit)
188
 
189
  if length_more_limit.empty:
190
  return pd.DataFrame()
191
 
192
  length_more_limit['ids'] = length_more_limit['ids'].astype(int)
193
 
194
- #length_more_limit.to_csv("length_more_limit.csv", index = None)
195
 
196
  # Explode the 'metadatas' dictionary into separate columns
197
  df_metadata_expanded = length_more_limit['metadatas'].apply(pd.Series)
198
 
199
- #print(length_more_limit)
200
- #print(df_metadata_expanded)
201
-
202
  # Concatenate the original DataFrame with the expanded metadata DataFrame
203
  results_df_out = pd.concat([length_more_limit.drop('metadatas', axis=1), df_metadata_expanded], axis=1)
204
 
@@ -208,9 +205,6 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c
208
  results_df_out['distances'] = round(results_df_out['distances'].astype(float), 3)
209
 
210
 
211
- # Join back to original df
212
- # results_df_out = orig_df.merge(length_more_limit[['ids', 'distances']], left_index = True, right_on = "ids", how="inner").sort_values("distances")
213
-
214
  # Join on additional files
215
  if not in_join_file.empty:
216
  progress(0.5, desc = "Joining on additional data file")
@@ -227,68 +221,73 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c
227
 
228
  return results_df_out
229
 
230
- def bge_simple_retrieval(query_str:str, vectorstore, docs, orig_df_col:str, k_val:int, out_passages:int,
231
- vec_score_cut_off:float, vec_weight:float, in_join_file, in_join_column = None, search_df_join_column = None, device = torch_device, embeddings = embeddings_model, progress=gr.Progress(track_tqdm=True)): # ,vectorstore, embeddings
232
 
233
- # print("vectorstore loaded: ", vectorstore)
234
  progress(0, desc = "Conducting semantic search")
235
 
236
  ensure_output_folder_exists(output_folder)
237
 
238
  print("Searching")
239
 
240
- # Convert it to a PyTorch tensor and transfer to GPU
241
- #vectorstore_tensor = tensor(vectorstore).to(device)
242
-
243
  # Load the sentence transformer model and move it to GPU
244
- embeddings = embeddings.to(device)
245
 
246
  # Encode the query using the sentence transformer and convert to a PyTorch tensor
247
- query = embeddings.encode(query_str, normalize_embeddings=True)
248
-
249
- # query = calc_bge_norm_embeddings(query_str, embeddings_model=embeddings_model, tokenizer=tokenizer)
250
-
251
- #query_tensor = tensor(query).to(device)
252
-
253
- # if query_tensor.dim() == 1:
254
- # query_tensor = query_tensor.unsqueeze(0) # Reshape to 2D with one row
255
 
256
  # Sentence transformers method, not used:
257
- cosine_similarities = query @ vectorstore.T
258
- #cosine_similarities = util.cos_sim(query_tensor, vectorstore_tensor)[0]
259
- #top_results = torch.topk(cos_scores, k=top_k)
260
-
261
-
262
- # Normalize the query tensor and vectorstore tensor
263
- #query_norm = query_tensor / query_tensor.norm(dim=1, keepdim=True)
264
- #vectorstore_norm = vectorstore_tensor / vectorstore_tensor.norm(dim=1, keepdim=True)
265
-
266
- # Calculate cosine similarities (batch processing)
267
- #cosine_similarities = mm(query_norm, vectorstore_norm.T)
268
- #cosine_similarities = mm(query_tensor, vectorstore_tensor.T)
269
 
270
  # Flatten the tensor to a 1D array
271
  cosine_similarities = cosine_similarities.flatten()
272
 
273
- # Convert to a NumPy array if it's still a PyTorch tensor
274
- #cosine_similarities = cosine_similarities.cpu().numpy()
275
-
276
  # Create a Pandas Series
277
  cosine_similarities_series = pd.Series(cosine_similarities)
278
 
279
- # Pull out relevent info from docs
280
- page_contents = [doc.page_content for doc in docs]
281
- page_meta = [doc.metadata for doc in docs]
282
  ids_range = range(0,len(page_contents))
283
  ids = [str(element) for element in ids_range]
284
 
285
- df_docs = pd.DataFrame(data={"ids": ids,
286
  "documents": page_contents,
287
  "metadatas":page_meta,
288
  "distances":cosine_similarities_series}).sort_values("distances", ascending=False).iloc[0:k_val,:]
289
 
290
 
291
- results_df_out = process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column)
292
 
293
  print("Search complete")
294
 
@@ -312,291 +311,4 @@ def bge_simple_retrieval(query_str:str, vectorstore, docs, orig_df_col:str, k_va
312
 
313
  print("Returning results")
314
 
315
- return results_first_text, results_df_name
316
-
317
-
318
- def docs_to_jina_embed_np_array_deprecated(docs_out, in_file, embeddings_state, return_intermediate_files = "No", embeddings_super_compress = "No", embeddings = embeddings_model, progress=gr.Progress(track_tqdm=True)):
319
- '''
320
- Takes a Langchain document class and saves it into a Chroma sqlite file.
321
- '''
322
- if not in_file:
323
- out_message = "No input file found. Please load in at least one file."
324
- print(out_message)
325
- return out_message, None, None
326
-
327
-
328
- progress(0.6, desc = "Loading/creating embeddings")
329
-
330
- print(f"> Total split documents: {len(docs_out)}")
331
-
332
- #print(docs_out)
333
-
334
- page_contents = [doc.page_content for doc in docs_out]
335
-
336
- ## Load in pre-embedded file if exists
337
- file_list = [string.name for string in in_file]
338
-
339
- #print(file_list)
340
-
341
- embeddings_file_names = [string for string in file_list if "embedding" in string.lower()]
342
- data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower()]# and "gz" not in string.lower()]
343
- data_file_name = data_file_names[0]
344
- data_file_name_no_ext = get_file_path_end(data_file_name)
345
-
346
- out_message = "Document processing complete. Ready to search."
347
-
348
- # print("embeddings loaded: ", embeddings_out)
349
-
350
- if embeddings_state.size == 0:
351
- tic = time.perf_counter()
352
- print("Starting to embed documents.")
353
- #embeddings_list = []
354
- #for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
355
- # embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
356
-
357
- embeddings_out = embeddings.encode(sentences=page_contents, max_length=1024, show_progress_bar = True, batch_size = 32) # For Jina embeddings
358
- #embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
359
- #embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
360
-
361
- toc = time.perf_counter()
362
- time_out = f"The embedding took {toc - tic:0.1f} seconds"
363
- print(time_out)
364
-
365
- # If you want to save your files for next time
366
- if return_intermediate_files == "Yes":
367
- progress(0.9, desc = "Saving embeddings to file")
368
- if embeddings_super_compress == "No":
369
- semantic_search_file_name = data_file_name_no_ext + '_' + 'embeddings.npz'
370
- np.savez_compressed(semantic_search_file_name, embeddings_out)
371
- else:
372
- semantic_search_file_name = data_file_name_no_ext + '_' + 'embedding_compress.npz'
373
- embeddings_out_round = np.round(embeddings_out, 3)
374
- embeddings_out_round *= 100 # Rounding not currently used
375
- np.savez_compressed(semantic_search_file_name, embeddings_out_round)
376
-
377
- return out_message, embeddings_out, semantic_search_file_name
378
-
379
- return out_message, embeddings_out, None
380
- else:
381
- # Just return existing embeddings if already exist
382
- embeddings_out = embeddings_state
383
-
384
- print(out_message)
385
-
386
- return out_message, embeddings_out, None#, None
387
-
388
- def jina_simple_retrieval_deprecated(query_str:str, vectorstore, docs, orig_df_col:str, k_val:int, out_passages:int,
389
- vec_score_cut_off:float, vec_weight:float, in_join_file, in_join_column = None, search_df_join_column = None, device = torch_device, embeddings = embeddings_model, progress=gr.Progress(track_tqdm=True)): # ,vectorstore, embeddings
390
-
391
- # print("vectorstore loaded: ", vectorstore)
392
- progress(0, desc = "Conducting semantic search")
393
-
394
- print("Searching")
395
-
396
- # Convert it to a PyTorch tensor and transfer to GPU
397
- vectorstore_tensor = tensor(vectorstore).to(device)
398
-
399
- # Load the sentence transformer model and move it to GPU
400
- embeddings = embeddings.to(device)
401
-
402
- # Encode the query using the sentence transformer and convert to a PyTorch tensor
403
- query = embeddings.encode(query_str)
404
- query_tensor = tensor(query).to(device)
405
-
406
- if query_tensor.dim() == 1:
407
- query_tensor = query_tensor.unsqueeze(0) # Reshape to 2D with one row
408
-
409
- # Normalize the query tensor and vectorstore tensor
410
- query_norm = query_tensor / query_tensor.norm(dim=1, keepdim=True)
411
- vectorstore_norm = vectorstore_tensor / vectorstore_tensor.norm(dim=1, keepdim=True)
412
-
413
- # Calculate cosine similarities (batch processing)
414
- cosine_similarities = mm(query_norm, vectorstore_norm.T)
415
-
416
- # Flatten the tensor to a 1D array
417
- cosine_similarities = cosine_similarities.flatten()
418
-
419
- # Convert to a NumPy array if it's still a PyTorch tensor
420
- cosine_similarities = cosine_similarities.cpu().numpy()
421
-
422
- # Create a Pandas Series
423
- cosine_similarities_series = pd.Series(cosine_similarities)
424
-
425
- # Pull out relevent info from docs
426
- page_contents = [doc.page_content for doc in docs]
427
- page_meta = [doc.metadata for doc in docs]
428
- ids_range = range(0,len(page_contents))
429
- ids = [str(element) for element in ids_range]
430
-
431
- df_docs = pd.DataFrame(data={"ids": ids,
432
- "documents": page_contents,
433
- "metadatas":page_meta,
434
- "distances":cosine_similarities_series}).sort_values("distances", ascending=False).iloc[0:k_val,:]
435
-
436
-
437
- results_df_out = process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column)
438
-
439
- print("Search complete")
440
-
441
- # If nothing found, return error message
442
- if results_df_out.empty:
443
- return 'No result found!', None
444
-
445
- query_str_file = query_str.replace(" ", "_")
446
-
447
- results_df_name = "semantic_search_result_" + today_rev + "_" + query_str_file + ".xlsx"
448
-
449
- print("Saving search output to file")
450
- progress(0.7, desc = "Saving search output to file")
451
-
452
- results_df_out.to_excel(results_df_name, index= None)
453
- results_first_text = results_df_out.iloc[0, 1]
454
-
455
- print("Returning results")
456
-
457
- return results_first_text, results_df_name
458
-
459
- # Deprecated Chroma functions - kept just in case needed in future.
460
- # Chroma support is currently deprecated
461
- # Import Chroma and instantiate a client. The default Chroma client is ephemeral, meaning it will not save to disk.
462
- #import chromadb
463
- #from chromadb.config import Settings
464
- #from typing_extensions import Protocol
465
- #from chromadb import Documents, EmbeddingFunction, Embeddings
466
-
467
- # Remove Chroma database file. If it exists as it can cause issues
468
- #chromadb_file = "chroma.sqlite3"
469
-
470
- #if os.path.isfile(chromadb_file):
471
- # os.remove(chromadb_file)
472
-
473
-
474
- def docs_to_chroma_save_deprecated(docs_out, embeddings = embeddings_model, progress=gr.Progress()):
475
- '''
476
- Takes a Langchain document class and saves it into a Chroma sqlite file. Not currently used.
477
- '''
478
-
479
- print(f"> Total split documents: {len(docs_out)}")
480
-
481
- #print(docs_out)
482
-
483
- page_contents = [doc.page_content for doc in docs_out]
484
- page_meta = [doc.metadata for doc in docs_out]
485
- ids_range = range(0,len(page_contents))
486
- ids = [str(element) for element in ids_range]
487
-
488
- tic = time.perf_counter()
489
- #embeddings_list = []
490
- #for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
491
- # embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
492
-
493
- embeddings_list = embeddings.encode(sentences=page_contents, max_length=256, show_progress_bar = True, batch_size = 32).tolist() # For Jina embeddings
494
- #embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
495
- #embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
496
-
497
- toc = time.perf_counter()
498
- time_out = f"The embedding took {toc - tic:0.1f} seconds"
499
-
500
- #pd.Series(embeddings_list).to_csv("embeddings_out.csv")
501
-
502
- # Jina tiny
503
- # This takes about 300 seconds for 240,000 records = 800 / second, 1024 max length
504
- # For 50k records:
505
- # 61 seconds at 1024 max length
506
- # 55 seconds at 512 max length
507
- # 43 seconds at 256 max length
508
- # 31 seconds at 128 max length
509
-
510
- # The embedding took 1372.5 seconds at 256 max length for 655,020 case notes
511
-
512
- # BGE small
513
- # 96 seconds for 50k records at 512 length
514
-
515
- # all-MiniLM-L6-v2
516
- # 42.5 seconds at (256?) max length
517
-
518
- # paraphrase-MiniLM-L3-v2
519
- # 22 seconds for 128 max length
520
-
521
-
522
- print(time_out)
523
-
524
- chroma_tic = time.perf_counter()
525
-
526
- # Create a new Chroma collection to store the documents and metadata. We don't need to specify an embedding fuction, and the default will be used.
527
- client = chromadb.PersistentClient(path="./last_year", settings=Settings(
528
- anonymized_telemetry=False))
529
-
530
- try:
531
- print("Deleting existing collection.")
532
- #collection = client.get_collection(name="my_collection")
533
- client.delete_collection(name="my_collection")
534
- print("Creating new collection.")
535
- collection = client.create_collection(name="my_collection")
536
- except:
537
- print("Creating new collection.")
538
- collection = client.create_collection(name="my_collection")
539
-
540
- # Match batch size is about 40,000, so add that amount in a loop
541
- def create_batch_ranges(in_list, batch_size=40000):
542
- total_rows = len(in_list)
543
- ranges = []
544
-
545
- for start in range(0, total_rows, batch_size):
546
- end = min(start + batch_size, total_rows)
547
- ranges.append(range(start, end))
548
-
549
- return ranges
550
-
551
- batch_ranges = create_batch_ranges(embeddings_list)
552
- print(batch_ranges)
553
-
554
- for row_range in progress.tqdm(batch_ranges, desc = "Creating vector database", unit = "batches of 40,000 rows"):
555
-
556
- collection.add(
557
- documents = page_contents[row_range[0]:row_range[-1]],
558
- embeddings = embeddings_list[row_range[0]:row_range[-1]],
559
- metadatas = page_meta[row_range[0]:row_range[-1]],
560
- ids = ids[row_range[0]:row_range[-1]])
561
- #print("Here")
562
-
563
- # print(collection.count())
564
-
565
-
566
- #chatf.vectorstore = vectorstore_func
567
-
568
- chroma_toc = time.perf_counter()
569
-
570
- chroma_time_out = f"Loading to Chroma db took {chroma_toc - chroma_tic:0.1f} seconds"
571
- print(chroma_time_out)
572
-
573
- out_message = "Document processing complete"
574
-
575
- return out_message, collection
576
-
577
- def chroma_retrieval_deprecated(query_str:str, vectorstore, docs, orig_df_col:str, k_val:int, out_passages:int,
578
- vec_score_cut_off:float, vec_weight:float, in_join_file = None, in_join_column = None, search_df_join_column = None, embeddings = embeddings_model): # ,vectorstore, embeddings
579
-
580
- query = embeddings.encode(query_str).tolist()
581
-
582
- docs = vectorstore.query(
583
- query_embeddings=query,
584
- n_results= k_val # No practical limit on number of responses returned
585
- #where={"metadata_field": "is_equal_to_this"},
586
- #where_document={"$contains":"search_string"}
587
- )
588
-
589
- df_docs = pd.DataFrame(data={'ids': docs['ids'][0],
590
- 'documents': docs['documents'][0],
591
- 'metadatas':docs['metadatas'][0],
592
- 'distances':docs['distances'][0]#,
593
- #'embeddings': docs['embeddings']
594
- })
595
-
596
- results_df_out = process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column)
597
-
598
- results_df_name = output_folder + "semantic_search_result.csv"
599
- results_df_out.to_csv(results_df_name, index= None)
600
- results_first_text = results_df_out[orig_df_col].iloc[0]
601
-
602
- return results_first_text, results_df_name
 
5
  import gradio as gr
6
  import numpy as np
7
  from datetime import datetime
8
+ from search_funcs.helper_functions import get_file_path_end, create_highlighted_excel_wb, ensure_output_folder_exists, output_folder
9
+ from torch import cuda, backends
 
 
10
  from sentence_transformers import SentenceTransformer
11
+ PandasDataFrame = Type[pd.DataFrame]
12
 
13
  today_rev = datetime.now().strftime("%Y%m%d")
14
 
 
24
 
25
  print("Device used is: ", torch_device)
26
 
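
The device-selection lines are collapsed in this view. A minimal sketch of how torch_device is plausibly derived from the cuda and backends imports above — the priority order here is an assumption; only the variable name and the print above are confirmed by the diff:

from torch import cuda, backends

# Assumed priority: CUDA GPU, then Apple Metal (MPS), then CPU fallback
if cuda.is_available():
    torch_device = "cuda"
elif backends.mps.is_available():
    torch_device = "mps"
else:
    torch_device = "cpu"
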
27
  # Load embeddings
28
  embeddings_name = "BAAI/bge-small-en-v1.5"
29
 
 
48
  embeddings_model = SentenceTransformer(embeddings_name)
49
  print("Could not find local model installation. Downloading from Huggingface")
50
 
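
Only the download fallback survives the collapse above. A plausible reconstruction of the local-first load it belongs to; local_embeddings_location is a hypothetical path name not confirmed by the diff:

import os
from sentence_transformers import SentenceTransformer

embeddings_name = "BAAI/bge-small-en-v1.5"
local_embeddings_location = "model/bge/"  # hypothetical local cache path

if os.path.exists(local_embeddings_location):
    embeddings_model = SentenceTransformer(local_embeddings_location)
    print("Found local model installation at " + local_embeddings_location)
else:
    embeddings_model = SentenceTransformer(embeddings_name)
    print("Could not find local model installation. Downloading from Huggingface")
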
51
+
52
+ def docs_to_bge_embed_np_array(
53
+ docs_out: list,
54
+ in_file: list,
55
+ embeddings_state: np.ndarray,
56
+ output_file_state: str,
57
+ clean: str,
58
+ return_intermediate_files: str = "No",
59
+ embeddings_super_compress: str = "No",
60
+ embeddings_model: SentenceTransformer = embeddings_model,
61
+ progress: gr.Progress = gr.Progress(track_tqdm=True)
62
+ ) -> tuple:
63
+ """
64
+ Process documents to create BGE embeddings and save them as a numpy array.
65
+
66
+ Parameters:
67
+ - docs_out (list): List of documents to be embedded.
68
+ - in_file (list): List of input files.
69
+ - embeddings_state (np.ndarray): Current state of embeddings.
70
+ - output_file_state (str): State of the output file.
71
+ - clean (str): Indicates if the data should be cleaned.
72
+ - return_intermediate_files (str, optional): Whether to return intermediate files. Default is "No".
73
+ - embeddings_super_compress (str, optional): Whether to super compress the embeddings. Default is "No".
74
+ - embeddings_model (SentenceTransformer, optional): The embeddings model to use. Default is embeddings_model.
75
+ - progress (gr.Progress, optional): Progress tracker for the function. Default is gr.Progress(track_tqdm=True).
76
+
77
+ Returns:
78
+ - tuple: A tuple containing the output message, the embeddings array, and the output file state (returned twice, to feed two output components).
79
+ """
80
+
81
 
82
  ensure_output_folder_exists(output_folder)
83
 
84
  if not in_file:
85
  out_message = "No input file found. Please load in at least one file."
86
  print(out_message)
87
+ return out_message, None, None, output_file_state
 
88
 
89
  progress(0.6, desc = "Loading/creating embeddings")
90
 
91
  print(f"> Total split documents: {len(docs_out)}")
92
 
 
 
93
  page_contents = [doc.page_content for doc in docs_out]
94
 
95
  ## Load in a pre-embedded file if one exists
96
  file_list = [string.name for string in in_file]
97
 
 
 
98
  embeddings_file_names = [string for string in file_list if "embedding" in string.lower()]
99
  data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower()]
100
  data_file_name = data_file_names[0]
 
102
 
103
  out_message = "Document processing complete. Ready to search."
104
 
 
105
 
106
  if embeddings_state.size == 0:
107
  tic = time.perf_counter()
108
  print("Starting to embed documents.")
109
 
110
  embeddings_out = embeddings_model.encode(sentences=page_contents, show_progress_bar = True, batch_size = 32, normalize_embeddings=True) # For BGE
 
 
111
 
112
  toc = time.perf_counter()
113
  time_out = f"The embedding took {toc - tic:0.1f} seconds"
 
141
 
142
  return out_message, embeddings_out, output_file_state, output_file_state
143
 
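
The save branch (original lines 114-140) is collapsed here, but the deleted Jina twin earlier in this diff shows the pattern: an optional compressed .npz dump, with a lossy "super compress" variant that rounds and rescales. A standalone sketch of that round-trip with stand-in data (file names are illustrative):

import numpy as np

embeddings_out = np.random.rand(3, 384).astype(np.float32)  # stand-in embeddings

# Full-precision save
np.savez_compressed('example_embeddings.npz', embeddings_out)

# "Super compress" variant: round to 3 d.p. and rescale so the bytes compress harder
embeddings_out_round = np.round(embeddings_out, 3)
embeddings_out_round *= 100
np.savez_compressed('example_embedding_compress.npz', embeddings_out_round)

# Loading back: the archive stores the unnamed array under 'arr_0'
embeddings_loaded = np.load('example_embeddings.npz')['arr_0']
assert embeddings_loaded.shape == (3, 384)
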
144
+ def process_data_from_scores_df(
145
+ df_docs: pd.DataFrame,
146
+ in_join_file: pd.DataFrame,
147
+ vec_score_cut_off: float,
148
+ in_join_column: str,
149
+ search_df_join_column: str,
150
+ progress: gr.Progress = gr.Progress(track_tqdm=True)
151
+ ) -> pd.DataFrame:
152
+ """
153
+ Process the data from the scores DataFrame by filtering based on score cutoff and document length,
154
+ and optionally joining with an additional file.
155
+
156
+ Parameters
157
+ ----------
158
+ df_docs : pd.DataFrame
159
+ DataFrame containing document scores and metadata.
160
+ in_join_file : pd.DataFrame
161
+ DataFrame to join with the results based on specified columns.
162
+ vec_score_cut_off : float
163
+ Cutoff value for the vector similarity score.
164
+ in_join_column : str
165
+ Column name in the join file to join on.
166
+ search_df_join_column : str
167
+ Column name in the search DataFrame to join on.
168
+ progress : gr.Progress, optional
169
+ Progress tracker for the function (default is gr.Progress(track_tqdm=True)).
170
+
171
+ Returns
172
+ -------
173
+ pd.DataFrame
174
+ Processed DataFrame with filtered and joined data.
175
+ """
176
+
177
  docs_scores = df_docs["distances"]
178
 
179
  # Only keep sources that are sufficiently relevant (i.e. similarity score above the cutoff threshold)
180
  score_more_limit = df_docs.loc[docs_scores > vec_score_cut_off, :]
181
 
182
  if score_more_limit.empty:
183
  return pd.DataFrame()
 
185
  # Only keep sources that are at least 100 characters long
186
  docs_len = score_more_limit["documents"].str.len() >= 100
187
 
 
 
188
  length_more_limit = score_more_limit.loc[docs_len, :]
 
 
 
189
 
190
  if length_more_limit.empty:
191
  return pd.DataFrame()
192
 
193
  length_more_limit['ids'] = length_more_limit['ids'].astype(int)
194
 
 
195
 
196
  # Explode the 'metadatas' dictionary into separate columns
197
  df_metadata_expanded = length_more_limit['metadatas'].apply(pd.Series)
198
 
199
  # Concatenate the original DataFrame with the expanded metadata DataFrame
200
  results_df_out = pd.concat([length_more_limit.drop('metadatas', axis=1), df_metadata_expanded], axis=1)
201
 
 
205
  results_df_out['distances'] = round(results_df_out['distances'].astype(float), 3)
206
 
207
 
208
  # Join on additional files
209
  if not in_join_file.empty:
210
  progress(0.5, desc = "Joining on additional data file")
 
221
 
222
  return results_df_out
223
 
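
To make the filtering and join behaviour concrete, a toy run through the same steps (the data and the 'page'/'title' join keys are illustrative; the real function merges on in_join_column and search_df_join_column):

import pandas as pd

df_docs = pd.DataFrame({
    "ids": ["0", "1", "2"],
    "documents": ["too short", "a" * 120, "b" * 150],
    "metadatas": [{"page": 1}, {"page": 2}, {"page": 3}],
    "distances": [0.2, 0.85, 0.9],
})

kept = df_docs.loc[df_docs["distances"] > 0.7, :]       # similarity cutoff
kept = kept.loc[kept["documents"].str.len() >= 100, :]  # minimum document length
meta = kept["metadatas"].apply(pd.Series)               # dict column -> real columns
results = pd.concat([kept.drop("metadatas", axis=1), meta], axis=1)

in_join_file = pd.DataFrame({"page": [2, 3], "title": ["Doc B", "Doc C"]})
results = results.merge(in_join_file, on="page", how="left")
print(results)  # rows 1 and 2 survive, now carrying 'page' and 'title' columns
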
224
+ def bge_semantic_search(
225
+ query_str: str,
226
+ embeddings: np.ndarray,
227
+ documents: list,
228
+ k_val: int,
229
+ vec_score_cut_off: float,
230
+ in_join_file: pd.DataFrame,
231
+ in_join_column: str = None,
232
+ search_df_join_column: str = None,
233
+ device: str = torch_device,
234
+ embeddings_model: SentenceTransformer = embeddings_model,
235
+ progress: gr.Progress = gr.Progress(track_tqdm=True)
236
+ ) -> tuple:
237
+ """
238
+ Perform a semantic search using the BGE model.
239
+
240
+ Parameters:
241
+ - query_str (str): The query string to search for.
242
+ - embeddings (np.ndarray): The embeddings to search within.
243
+ - documents (list): The list of documents to search.
244
+ - k_val (int): The number of top results to return.
245
+ - vec_score_cut_off (float): The score cutoff for filtering results.
246
+ - in_join_file (pd.DataFrame): The DataFrame to join with the search results.
247
+ - in_join_column (str, optional): The column name in the join DataFrame to join on. Default is None.
248
+ - search_df_join_column (str, optional): The column name in the search DataFrame to join on. Default is None.
249
+ - device (str, optional): The device to run the model on. Default is torch_device.
250
+ - embeddings_model (SentenceTransformer, optional): The embeddings model to use. Default is embeddings_model.
251
+ - progress (gr.Progress, optional): Progress tracker for the function. Default is gr.Progress(track_tqdm=True).
252
+
253
+ Returns:
254
+ - tuple: The first result's text and the file name of the saved search results.
255
+ """
256
 
 
257
  progress(0, desc = "Conducting semantic search")
258
 
259
  ensure_output_folder_exists(output_folder)
260
 
261
  print("Searching")
262
 
263
  # Load the sentence transformer model and move it to GPU
264
+ embeddings_model = embeddings_model.to(device)
265
 
266
  # Encode the query using the sentence transformer and convert to a PyTorch tensor
267
+ query = embeddings_model.encode(query_str, normalize_embeddings=True)
268
 
269
  # The embeddings and query are L2-normalised, so a dot product gives cosine similarity:
270
+ cosine_similarities = query @ embeddings.T
271
 
272
  # Flatten the similarity scores to a 1D array
273
  cosine_similarities = cosine_similarities.flatten()
274
 
275
  # Create a Pandas Series
276
  cosine_similarities_series = pd.Series(cosine_similarities)
277
 
278
+ # Pull out relevant info from documents
279
+ page_contents = [doc.page_content for doc in documents]
280
+ page_meta = [doc.metadata for doc in documents]
281
  ids_range = range(0,len(page_contents))
282
  ids = [str(element) for element in ids_range]
283
 
284
+ df_documents = pd.DataFrame(data={"ids": ids,
285
  "documents": page_contents,
286
  "metadatas":page_meta,
287
  "distances":cosine_similarities_series}).sort_values("distances", ascending=False).iloc[0:k_val,:]
288
 
289
 
290
+ results_df_out = process_data_from_scores_df(df_documents, in_join_file, vec_score_cut_off, in_join_column, search_df_join_column)
291
 
292
  print("Search complete")
293
 
 
311
 
312
  print("Returning results")
313
 
314
+ return results_first_text, results_df_name
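
Because both the stored embeddings and the query are L2-normalised, the single matrix product in bge_semantic_search is an exact cosine-similarity search. A self-contained sketch of the same ranking step (random vectors stand in for real BGE embeddings):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
doc_embeddings = rng.normal(size=(5, 8)).astype(np.float32)
doc_embeddings /= np.linalg.norm(doc_embeddings, axis=1, keepdims=True)  # unit rows

query = rng.normal(size=8).astype(np.float32)
query /= np.linalg.norm(query)  # unit query

# For unit vectors, dot product == cosine similarity
cosine_similarities = (query @ doc_embeddings.T).flatten()

k_val = 3
top_k = pd.Series(cosine_similarities).sort_values(ascending=False).iloc[0:k_val]
print(top_k)  # index = document position, value = similarity in [-1, 1]
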
search_funcs/spacy_search_funcs.py CHANGED
@@ -27,9 +27,14 @@ except:
27
  nlp = spacy.load("en_core_web_sm")
28
  print("Successfully imported spaCy model")
29
 
30
- def spacy_fuzzy_search(string_query:str, df_list: List[str], original_data: PandasDataFrame, text_column:str, in_join_file: PandasDataFrame, search_df_join_column:str, in_join_column:str, no_spelling_mistakes:int = 1, progress=gr.Progress(track_tqdm=True)):
31
  ''' Conduct fuzzy match on a list of data.'''
32
 
 
33
  if len(df_list) > 10000:
34
  out_message = "Your data has more than 10,000 rows and will take more than three minutes to do a fuzzy search. Please try keyword or semantic search for data of this size."
35
  return out_message, None
 
27
  nlp = spacy.load("en_core_web_sm")
28
  print("Successfully imported spaCy model")
29
 
30
+ def spacy_fuzzy_search(string_query:str, tokenised_data: List[List[str]], original_data: PandasDataFrame, text_column:str, in_join_file: PandasDataFrame, search_df_join_column:str, in_join_column:str, no_spelling_mistakes:int = 1, progress=gr.Progress(track_tqdm=True)):
31
  ''' Conduct fuzzy match on a list of data.'''
32
 
33
35
+ # Convert tokenised data back into a list of strings
36
+ df_list = list(map(" ".join, tokenised_data))
37
+
38
  if len(df_list) > 10000:
39
  out_message = "Your data has more than 10,000 rows and will take more than three minutes to do a fuzzy search. Please try keyword or semantic search for data of this size."
40
  return out_message, None
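
The matching itself happens further down in spacy_fuzzy_search, but the reconstruction step above is easy to exercise in isolation. A minimal sketch of the rebuild-then-score idea, using stdlib difflib as a stand-in scorer (the real function uses spaCy; the data and threshold are illustrative):

import difflib

string_query = "aple"  # one spelling mistake
tokenised_data = [["red", "apple"], ["green", "pear"], ["apple", "pie"]]

# Same reconstruction as above: tokenised rows back into strings
df_list = list(map(" ".join, tokenised_data))

def best_word_ratio(query: str, text: str) -> float:
    # Highest similarity between the query and any single word in the row
    return max(difflib.SequenceMatcher(None, query, word).ratio()
               for word in text.split())

matches = [text for text in df_list if best_word_ratio(string_query, text) >= 0.8]
print(matches)  # ['red apple', 'apple pie']
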