seanpedrickcase committed on
Commit
2393537
1 Parent(s): 2754a2b

Set bm25 in functions explicitly. Some API updates. Now can get connection params on startup.

Browse files
app.py CHANGED
@@ -8,7 +8,7 @@ PandasDataFrame = Type[pd.DataFrame]
8
  from search_funcs.bm25_functions import prepare_bm25_input_data, prepare_bm25, bm25_search
9
  from search_funcs.semantic_ingest_functions import csv_excel_text_to_docs
10
  from search_funcs.semantic_functions import docs_to_bge_embed_np_array, bge_simple_retrieval
11
- from search_funcs.helper_functions import display_info, initial_data_load, put_columns_in_join_df, get_temp_folder_path, empty_folder, output_folder
12
  from search_funcs.spacy_search_funcs import spacy_fuzzy_search
13
  from search_funcs.aws_functions import load_data_from_aws
14
 
@@ -30,6 +30,7 @@ with block:
30
  embeddings_state = gr.State(np.array([])) # globals()["embeddings"]
31
  search_index_state = gr.State()
32
  tokenised_state = gr.State()
 
33
 
34
  k_val = gr.State(9999)
35
  out_passages = gr.State(9999)
@@ -46,6 +47,9 @@ with block:
46
  orig_semantic_data_state = gr.State(pd.DataFrame())
47
  semantic_data_state = gr.State(pd.DataFrame())
48
 
 
 
 
49
  in_k1_info = gr.State("""k1: Constant used for influencing the term frequency saturation. After saturation is reached, additional
50
  presence for the term adds a significantly less additional score. According to [1]_, experiments suggest
51
  that 1.2 < k1 < 2 yields reasonably good results, although the optimal value depends on factors such as
@@ -92,7 +96,6 @@ depends on factors such as the type of documents or queries. Information taken f
92
  output_single_text = gr.Textbox(label="Top result")
93
  output_file = gr.File(label="File output")
94
 
95
-
96
  with gr.Tab("Semantic search"):
97
  gr.Markdown(
98
  """
@@ -179,20 +182,20 @@ depends on factors such as the type of documents or queries. Information taken f
179
 
180
  ### BM25 SEARCH ###
181
  # Update dropdowns upon initial file load
182
- in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column, keyword_data_state, orig_keyword_data_state, search_index_state, embeddings_state, tokenised_state, load_finished_message, current_source])
183
  in_join_file.change(put_columns_in_join_df, inputs=[in_join_file], outputs=[in_join_column, join_data_state, in_join_message])
184
 
185
  # Load in BM25 data
186
  load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column, keyword_data_state, tokenised_state, in_clean_data, return_intermediate_files], outputs=[corpus_state, load_finished_message, keyword_data_state, output_file, output_file, keyword_data_list_state, in_bm25_column], api_name="load_keyword").\
187
- then(fn=prepare_bm25, inputs=[corpus_state, in_bm25_file, in_bm25_column, search_index_state, in_clean_data, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message, output_file], api_name="prepare_keyword")#.\
188
 
189
 
190
  # BM25 search functions on click or enter
191
- keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, in_join_column, search_df_join_column], outputs=[output_single_text, output_file], api_name="keyword_search")
192
- keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, in_join_column, search_df_join_column], outputs=[output_single_text, output_file])
193
 
194
  # Fuzzy search functions on click
195
- fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query, keyword_data_list_state, keyword_data_state, in_bm25_column, join_data_state, search_df_join_column, in_join_column, no_spelling_mistakes], outputs=[output_single_text, output_file], api_name="fuzzy")
196
 
197
  ### SEMANTIC SEARCH ###
198
 
@@ -203,30 +206,15 @@ depends on factors such as the type of documents or queries. Information taken f
203
  then(docs_to_bge_embed_np_array, inputs=[ingest_docs, in_semantic_file, embeddings_state, output_file_state, in_clean_data, return_intermediate_files, embedding_super_compress], outputs=[semantic_load_progress, vectorstore_state, semantic_output_file, output_file_state])
204
 
205
  # Semantic search query
206
- semantic_submit.click(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file], api_name="semantic")
207
  semantic_query.submit(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file])
208
 
209
- # Simple run for HF spaces or local on your computer
210
- #block.queue().launch(debug=True)
211
-
212
- # def get_params(request: gr.Request):
213
- # if request:
214
- # print("Request headers dictionary:", request.headers)
215
- # print("IP address:", request.client.host)
216
- # print("Query parameters:", dict(request.query_params))
217
- # return request.query_params
218
-
219
- # request_params = get_params()
220
- # print(request_params)
221
 
222
- # Running on server (e.g. AWS) without specifying port
223
- block.queue().launch(ssl_verify=False) # root_path="/data-text-search" # server_name="0.0.0.0",
224
-
225
-
226
- # Running on local server without https
227
- #block.queue().launch(server_name="0.0.0.0", server_port=7861, ssl_verify=False)
228
 
229
  # Running on local server with https: https://discuss.huggingface.co/t/how-to-run-gradio-with-0-0-0-0-and-https/38003 or https://dev.to/rajshirolkar/fastapi-over-https-for-development-on-windows-2p7d # Need to download OpenSSL and create own keys
230
  # block.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
231
- # ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
232
-
 
8
  from search_funcs.bm25_functions import prepare_bm25_input_data, prepare_bm25, bm25_search
9
  from search_funcs.semantic_ingest_functions import csv_excel_text_to_docs
10
  from search_funcs.semantic_functions import docs_to_bge_embed_np_array, bge_simple_retrieval
11
+ from search_funcs.helper_functions import display_info, initial_data_load, put_columns_in_join_df, get_temp_folder_path, empty_folder, get_connection_params, output_folder
12
  from search_funcs.spacy_search_funcs import spacy_fuzzy_search
13
  from search_funcs.aws_functions import load_data_from_aws
14
 
 
30
  embeddings_state = gr.State(np.array([])) # globals()["embeddings"]
31
  search_index_state = gr.State()
32
  tokenised_state = gr.State()
33
+ bm25_search_object_state = gr.State()
34
 
35
  k_val = gr.State(9999)
36
  out_passages = gr.State(9999)
 
47
  orig_semantic_data_state = gr.State(pd.DataFrame())
48
  semantic_data_state = gr.State(pd.DataFrame())
49
 
50
+ session_hash_state = gr.State("")
51
+ s3_output_folder_state = gr.State("")
52
+
53
  in_k1_info = gr.State("""k1: Constant used for influencing the term frequency saturation. After saturation is reached, additional
54
  presence for the term adds a significantly less additional score. According to [1]_, experiments suggest
55
  that 1.2 < k1 < 2 yields reasonably good results, although the optimal value depends on factors such as
 
96
  output_single_text = gr.Textbox(label="Top result")
97
  output_file = gr.File(label="File output")
98
 
 
99
  with gr.Tab("Semantic search"):
100
  gr.Markdown(
101
  """
 
182
 
183
  ### BM25 SEARCH ###
184
  # Update dropdowns upon initial file load
185
+ in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column, keyword_data_state, orig_keyword_data_state, search_index_state, embeddings_state, tokenised_state, load_finished_message, current_source], api_name="initial_load")
186
  in_join_file.change(put_columns_in_join_df, inputs=[in_join_file], outputs=[in_join_column, join_data_state, in_join_message])
187
 
188
  # Load in BM25 data
189
  load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column, keyword_data_state, tokenised_state, in_clean_data, return_intermediate_files], outputs=[corpus_state, load_finished_message, keyword_data_state, output_file, output_file, keyword_data_list_state, in_bm25_column], api_name="load_keyword").\
190
+ then(fn=prepare_bm25, inputs=[corpus_state, in_bm25_file, in_bm25_column, search_index_state, in_clean_data, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message, output_file, bm25_search_object_state], api_name="prepare_keyword")
191
 
192
 
193
  # BM25 search functions on click or enter
194
+ keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_object_state, in_join_column, search_df_join_column], outputs=[output_single_text, output_file], api_name="keyword_search")
195
+ keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_object_state, in_join_column, search_df_join_column], outputs=[output_single_text, output_file])
196
 
197
  # Fuzzy search functions on click
198
+ fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query, keyword_data_list_state, keyword_data_state, in_bm25_column, join_data_state, search_df_join_column, in_join_column, no_spelling_mistakes], outputs=[output_single_text, output_file], api_name="fuzzy_search")
199
 
200
  ### SEMANTIC SEARCH ###
201
 
 
206
  then(docs_to_bge_embed_np_array, inputs=[ingest_docs, in_semantic_file, embeddings_state, output_file_state, in_clean_data, return_intermediate_files, embedding_super_compress], outputs=[semantic_load_progress, vectorstore_state, semantic_output_file, output_file_state])
207
 
208
  # Semantic search query
209
+ semantic_submit.click(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file], api_name="semantic_search")
210
  semantic_query.submit(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file])
211
 
212
+ block.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state])
 
 
 
 
 
 
 
 
 
 
 
213
 
214
+ # Launch the Gradio app
215
+ if __name__ == "__main__":
216
+ block.queue().launch(show_error=True) # root_path="/data-text-search" # server_name="0.0.0.0",
 
 
 
217
 
218
  # Running on local server with https: https://discuss.huggingface.co/t/how-to-run-gradio-with-0-0-0-0-and-https/38003 or https://dev.to/rajshirolkar/fastapi-over-https-for-development-on-windows-2p7d # Need to download OpenSSL and create own keys
219
  # block.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
220
+ # ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
 
search_funcs/bm25_functions.py CHANGED
@@ -40,6 +40,7 @@ tokenizer = nlp.tokenizer
40
  PARAM_K1 = 1.5
41
  PARAM_B = 0.75
42
  IDF_CUTOFF = -inf
 
43
 
44
  # Class built off https://github.com/Inspirateur/Fast-BM25
45
 
@@ -263,6 +264,8 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
263
  tokenised_file_names = [string for string in file_list if "tokenised" in string.lower()]
264
  search_index_file_names = [string for string in file_list if "gz" in string.lower()]
265
 
 
 
266
  df[text_column] = df[text_column].astype(str).str.lower()
267
 
268
  if "copy_of_case_note_id" in df.columns:
@@ -386,8 +389,6 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
386
  print(out_message)
387
  return out_message, None
388
 
389
-
390
-
391
  file_list = [string.name for string in in_file]
392
 
393
  #print(file_list)
@@ -444,13 +445,13 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
444
 
445
  message = "Search parameters loaded."
446
 
447
- return message, bm25_search_file_name
448
 
449
  message = "Search parameters loaded."
450
 
451
  print(message)
452
 
453
- return message, None
454
 
455
  def convert_bm25_query_to_tokens(free_text_query, clean="No"):
456
  '''
@@ -473,7 +474,7 @@ def convert_bm25_query_to_tokens(free_text_query, clean="No"):
473
 
474
  return out_query
475
 
476
- def bm25_search(free_text_query, in_no_search_results, original_data, searched_data, text_column, in_join_file, clean, in_join_column = "", search_df_join_column = "", progress=gr.Progress(track_tqdm=True)):
477
 
478
  progress(0, desc = "Conducting keyword search")
479
 
 
40
  PARAM_K1 = 1.5
41
  PARAM_B = 0.75
42
  IDF_CUTOFF = -inf
43
+ bm25 = "" # Placeholder just so initial load doesn't fail
44
 
45
  # Class built off https://github.com/Inspirateur/Fast-BM25
46
 
 
264
  tokenised_file_names = [string for string in file_list if "tokenised" in string.lower()]
265
  search_index_file_names = [string for string in file_list if "gz" in string.lower()]
266
 
267
+ print("Dataframe columns:", df.columns)
268
+
269
  df[text_column] = df[text_column].astype(str).str.lower()
270
 
271
  if "copy_of_case_note_id" in df.columns:
 
389
  print(out_message)
390
  return out_message, None
391
 
 
 
392
  file_list = [string.name for string in in_file]
393
 
394
  #print(file_list)
 
445
 
446
  message = "Search parameters loaded."
447
 
448
+ return message, bm25_search_file_name, bm25
449
 
450
  message = "Search parameters loaded."
451
 
452
  print(message)
453
 
454
+ return message, None, bm25
455
 
456
  def convert_bm25_query_to_tokens(free_text_query, clean="No"):
457
  '''
 
474
 
475
  return out_query
476
 
477
+ def bm25_search(free_text_query, in_no_search_results, original_data, searched_data, text_column, in_join_file, clean, bm25, in_join_column = "", search_df_join_column = "", progress=gr.Progress(track_tqdm=True)):
478
 
479
  progress(0, desc = "Conducting keyword search")
480
 
search_funcs/helper_functions.py CHANGED
@@ -15,6 +15,8 @@ from openpyxl.cell.text import InlineFont
15
  from openpyxl.cell.rich_text import TextBlock, CellRichText
16
  from openpyxl.styles import Font, Alignment
17
 
 
 
18
  megabyte = 1024 * 1024 # Bytes in a megabyte
19
  file_size_mb = 500 # Size in megabytes
20
  file_size_bytes_500mb = megabyte * file_size_mb
@@ -49,6 +51,41 @@ def ensure_output_folder_exists(output_folder):
49
  else:
50
  print(f"The output folder already exists:", folder_name)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  # Attempt to delete content of gradio temp folder
53
  def get_temp_folder_path():
54
  username = getpass.getuser()
 
15
  from openpyxl.cell.rich_text import TextBlock, CellRichText
16
  from openpyxl.styles import Font, Alignment
17
 
18
+ from search_funcs.aws_functions import bucket_name
19
+
20
  megabyte = 1024 * 1024 # Bytes in a megabyte
21
  file_size_mb = 500 # Size in megabytes
22
  file_size_bytes_500mb = megabyte * file_size_mb
 
51
  else:
52
  print(f"The output folder already exists:", folder_name)
53
 
54
+ def get_connection_params(request: gr.Request):
55
+ if request:
56
+ #request_data = request.json() # Parse JSON body
57
+ #print("All request data:", request_data)
58
+ #context_value = request_data.get('context')
59
+ #if 'context' in request_data:
60
+ # print("Request context dictionary:", request_data['context'])
61
+
62
+ #print("Request headers dictionary:", request.headers)
63
+ #print("All host elements", request.client)
64
+ #print("IP address:", request.client.host)
65
+ #print("Query parameters:", dict(request.query_params))
66
+ # To get the underlying FastAPI items you would need to use await and some fancy @ stuff for a live query: https://fastapi.tiangolo.com/vi/reference/request/
67
+ #print("Request dictionary to object:", request.request.body())
68
+ print("Session hash:", request.session_hash)
69
+
70
+ if 'x-cognito-id' in request.headers:
71
+ out_session_hash = request.headers['x-cognito-id']
72
+ base_folder = "user-files/"
73
+ print("Cognito ID found:", out_session_hash)
74
+
75
+ else:
76
+ out_session_hash = request.session_hash
77
+ base_folder = "temp-files/"
78
+ print("Cognito ID not found. Using session hash as save folder.")
79
+
80
+ output_folder = base_folder + out_session_hash + "/"
81
+ if bucket_name:
82
+ print("S3 output folder is: " + "s3://" + bucket_name + "/" + output_folder)
83
+
84
+ return out_session_hash, output_folder
85
+ else:
86
+ print("No session parameters found.")
87
+ return "", ""
88
+
89
  # Attempt to delete content of gradio temp folder
90
  def get_temp_folder_path():
91
  username = getpass.getuser()