seanpedrickcase committed
Commit: 2393537
Parent(s): 2754a2b
Set bm25 in functions explicitly. Some API updates. Now can get connection params on startup.
Files changed:
- app.py (+16 -28)
- search_funcs/bm25_functions.py (+6 -5)
- search_funcs/helper_functions.py (+37 -0)
app.py
CHANGED
@@ -8,7 +8,7 @@ PandasDataFrame = Type[pd.DataFrame]
 from search_funcs.bm25_functions import prepare_bm25_input_data, prepare_bm25, bm25_search
 from search_funcs.semantic_ingest_functions import csv_excel_text_to_docs
 from search_funcs.semantic_functions import docs_to_bge_embed_np_array, bge_simple_retrieval
-from search_funcs.helper_functions import display_info, initial_data_load, put_columns_in_join_df, get_temp_folder_path, empty_folder, output_folder
+from search_funcs.helper_functions import display_info, initial_data_load, put_columns_in_join_df, get_temp_folder_path, empty_folder, get_connection_params, output_folder
 from search_funcs.spacy_search_funcs import spacy_fuzzy_search
 from search_funcs.aws_functions import load_data_from_aws
 
@@ -30,6 +30,7 @@ with block:
     embeddings_state = gr.State(np.array([])) # globals()["embeddings"]
     search_index_state = gr.State()
     tokenised_state = gr.State()
+    bm25_search_object_state = gr.State()
 
     k_val = gr.State(9999)
     out_passages = gr.State(9999)
@@ -46,6 +47,9 @@ with block:
     orig_semantic_data_state = gr.State(pd.DataFrame())
     semantic_data_state = gr.State(pd.DataFrame())
 
+    session_hash_state = gr.State("")
+    s3_output_folder_state = gr.State("")
+
     in_k1_info = gr.State("""k1: Constant used for influencing the term frequency saturation. After saturation is reached, additional
 presence for the term adds a significantly less additional score. According to [1]_, experiments suggest
 that 1.2 < k1 < 2 yields reasonably good results, although the optimal value depends on factors such as
@@ -92,7 +96,6 @@ depends on factors such as the type of documents or queries. Information taken f
             output_single_text = gr.Textbox(label="Top result")
             output_file = gr.File(label="File output")
 
-
     with gr.Tab("Semantic search"):
         gr.Markdown(
     """
@@ -179,20 +182,20 @@ depends on factors such as the type of documents or queries. Information taken f
 
     ### BM25 SEARCH ###
     # Update dropdowns upon initial file load
-    in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column, keyword_data_state, orig_keyword_data_state, search_index_state, embeddings_state, tokenised_state, load_finished_message, current_source])
+    in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column, keyword_data_state, orig_keyword_data_state, search_index_state, embeddings_state, tokenised_state, load_finished_message, current_source], api_name="initial_load")
     in_join_file.change(put_columns_in_join_df, inputs=[in_join_file], outputs=[in_join_column, join_data_state, in_join_message])
 
     # Load in BM25 data
     load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column, keyword_data_state, tokenised_state, in_clean_data, return_intermediate_files], outputs=[corpus_state, load_finished_message, keyword_data_state, output_file, output_file, keyword_data_list_state, in_bm25_column], api_name="load_keyword").\
-    then(fn=prepare_bm25, inputs=[corpus_state, in_bm25_file, in_bm25_column, search_index_state, in_clean_data, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message, output_file], api_name="prepare_keyword")
+    then(fn=prepare_bm25, inputs=[corpus_state, in_bm25_file, in_bm25_column, search_index_state, in_clean_data, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message, output_file, bm25_search_object_state], api_name="prepare_keyword")
 
 
     # BM25 search functions on click or enter
-    keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, in_join_column, search_df_join_column], outputs=[output_single_text, output_file], api_name="keyword_search")
-    keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, in_join_column, search_df_join_column], outputs=[output_single_text, output_file])
+    keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_object_state, in_join_column, search_df_join_column], outputs=[output_single_text, output_file], api_name="keyword_search")
+    keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_object_state, in_join_column, search_df_join_column], outputs=[output_single_text, output_file])
 
     # Fuzzy search functions on click
-    fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query, keyword_data_list_state, keyword_data_state, in_bm25_column, join_data_state, search_df_join_column, in_join_column, no_spelling_mistakes], outputs=[output_single_text, output_file], api_name="
+    fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query, keyword_data_list_state, keyword_data_state, in_bm25_column, join_data_state, search_df_join_column, in_join_column, no_spelling_mistakes], outputs=[output_single_text, output_file], api_name="fuzzy_search")
 
     ### SEMANTIC SEARCH ###
 
@@ -203,30 +206,15 @@ depends on factors such as the type of documents or queries. Information taken f
     then(docs_to_bge_embed_np_array, inputs=[ingest_docs, in_semantic_file, embeddings_state, output_file_state, in_clean_data, return_intermediate_files, embedding_super_compress], outputs=[semantic_load_progress, vectorstore_state, semantic_output_file, output_file_state])
 
     # Semantic search query
-    semantic_submit.click(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file], api_name="
+    semantic_submit.click(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file], api_name="semantic_search")
     semantic_query.submit(bge_simple_retrieval, inputs=[semantic_query, vectorstore_state, ingest_docs, in_semantic_column, k_val, out_passages, semantic_min_distance, vec_weight, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file])
 
-
-#block.queue().launch(debug=True)
-
-# def get_params(request: gr.Request):
-#     if request:
-#         print("Request headers dictionary:", request.headers)
-#         print("IP address:", request.client.host)
-#         print("Query parameters:", dict(request.query_params))
-#         return request.query_params
-
-# request_params = get_params()
-# print(request_params)
+    block.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state])
 
-#
-
-
-
-# Running on local server without https
-#block.queue().launch(server_name="0.0.0.0", server_port=7861, ssl_verify=False)
+# Launch the Gradio app
+if __name__ == "__main__":
+    block.queue().launch(show_error=True) # root_path="/data-text-search" # server_name="0.0.0.0",
 
 # Running on local server with https: https://discuss.huggingface.co/t/how-to-run-gradio-with-0-0-0-0-and-https/38003 or https://dev.to/rajshirolkar/fastapi-over-https-for-development-on-windows-2p7d # Need to download OpenSSL and create own keys
 # block.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
-#                     ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
-
+#                      ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
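The block.load wiring above is what makes connection parameters available on startup: Gradio runs get_connection_params once per browser session and stores the results in the new session_hash_state and s3_output_folder_state components. Below is a minimal, self-contained sketch of the same pattern; only gr.Request, session_hash and the x-cognito-id handling come from this commit, while the component names and the demo output are illustrative.

# Minimal sketch of capturing per-session parameters at startup.
# Component and variable names here are illustrative, not the app's own.
import gradio as gr

def get_connection_params(request: gr.Request):
    if request is None:
        return "", ""
    # Prefer an authenticated Cognito ID when a proxy supplies one, otherwise
    # fall back to Gradio's per-browser session hash (as in helper_functions.py).
    if "x-cognito-id" in request.headers:
        session_id = request.headers["x-cognito-id"]
        base_folder = "user-files/"
    else:
        session_id = request.session_hash
        base_folder = "temp-files/"
    return session_id, base_folder + session_id + "/"

with gr.Blocks() as demo:
    session_hash_state = gr.State("")
    output_folder_state = gr.State("")
    session_box = gr.Textbox(label="Session")

    # Runs once when the page loads, before any user interaction, then
    # surfaces the captured session hash purely for demonstration.
    demo.load(get_connection_params, inputs=None,
              outputs=[session_hash_state, output_folder_state]).then(
        lambda s: s, inputs=session_hash_state, outputs=session_box)

if __name__ == "__main__":
    demo.queue().launch(show_error=True)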
search_funcs/bm25_functions.py
CHANGED
@@ -40,6 +40,7 @@ tokenizer = nlp.tokenizer
 PARAM_K1 = 1.5
 PARAM_B = 0.75
 IDF_CUTOFF = -inf
+bm25 = "" # Placeholder just so initial load doesn't fail
 
 # Class built off https://github.com/Inspirateur/Fast-BM25
 
@@ -263,6 +264,8 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
     tokenised_file_names = [string for string in file_list if "tokenised" in string.lower()]
     search_index_file_names = [string for string in file_list if "gz" in string.lower()]
 
+    print("Dataframe columns:", df.columns)
+
     df[text_column] = df[text_column].astype(str).str.lower()
 
     if "copy_of_case_note_id" in df.columns:
@@ -386,8 +389,6 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
         print(out_message)
         return out_message, None
 
-
-
     file_list = [string.name for string in in_file]
 
     #print(file_list)
@@ -444,13 +445,13 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
 
         message = "Search parameters loaded."
 
-        return message, bm25_search_file_name
+        return message, bm25_search_file_name, bm25
 
     message = "Search parameters loaded."
 
     print(message)
 
-    return message, None
+    return message, None, bm25
 
 def convert_bm25_query_to_tokens(free_text_query, clean="No"):
     '''
@@ -473,7 +474,7 @@ def convert_bm25_query_to_tokens(free_text_query, clean="No"):
 
     return out_query
 
-def bm25_search(free_text_query, in_no_search_results, original_data, searched_data, text_column, in_join_file, clean,
+def bm25_search(free_text_query, in_no_search_results, original_data, searched_data, text_column, in_join_file, clean, bm25, in_join_column = "", search_df_join_column = "", progress=gr.Progress(track_tqdm=True)):
 
     progress(0, desc = "Conducting keyword search")
 
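These changes carry the "set bm25 in functions explicitly" part of the commit: prepare_bm25 now returns the bm25 object as a third output (which app.py stores in bm25_search_object_state), and bm25_search accepts it as an explicit bm25 argument instead of relying on the module-level placeholder. The sketch below shows the same pattern in isolation; SimpleBM25 and the component names are stand-ins, not the repo's Fast-BM25-derived class.

# Sketch of passing a prepared search object through gr.State rather than a
# module-level global. SimpleBM25 is a toy stand-in for the real BM25 class.
import gradio as gr

class SimpleBM25:
    def __init__(self, corpus):
        self.corpus = corpus

    def get_top_n(self, query, n=1):
        # Toy scoring: rank documents by how many words they share with the query.
        scored = sorted(self.corpus,
                        key=lambda d: -len(set(d.split()) & set(query.split())))
        return scored[:n]

def prepare_bm25(corpus):
    bm25 = SimpleBM25(corpus)
    return "Search parameters loaded.", bm25   # bm25 becomes a State output

def bm25_search(query, bm25):
    if bm25 is None:
        return "Prepare the keyword index first."
    return bm25.get_top_n(query, n=1)[0]

with gr.Blocks() as demo:
    bm25_search_object_state = gr.State()      # holds the prepared object per session
    corpus_state = gr.State(["the cat sat", "dogs bark loudly", "cats purr"])
    status = gr.Textbox(label="Status")
    query = gr.Textbox(label="Query")
    top_result = gr.Textbox(label="Top result")
    prepare_btn = gr.Button("Prepare keyword search")
    search_btn = gr.Button("Search")

    prepare_btn.click(prepare_bm25, inputs=corpus_state,
                      outputs=[status, bm25_search_object_state])
    search_btn.click(bm25_search, inputs=[query, bm25_search_object_state],
                     outputs=top_result)

if __name__ == "__main__":
    demo.queue().launch()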
search_funcs/helper_functions.py
CHANGED
@@ -15,6 +15,8 @@ from openpyxl.cell.text import InlineFont
 from openpyxl.cell.rich_text import TextBlock, CellRichText
 from openpyxl.styles import Font, Alignment
 
+from search_funcs.aws_functions import bucket_name
+
 megabyte = 1024 * 1024 # Bytes in a megabyte
 file_size_mb = 500 # Size in megabytes
 file_size_bytes_500mb = megabyte * file_size_mb
@@ -49,6 +51,41 @@ def ensure_output_folder_exists(output_folder):
     else:
         print(f"The output folder already exists:", folder_name)
 
+def get_connection_params(request: gr.Request):
+    if request:
+        #request_data = request.json() # Parse JSON body
+        #print("All request data:", request_data)
+        #context_value = request_data.get('context')
+        #if 'context' in request_data:
+        #    print("Request context dictionary:", request_data['context'])
+
+        #print("Request headers dictionary:", request.headers)
+        #print("All host elements", request.client)
+        #print("IP address:", request.client.host)
+        #print("Query parameters:", dict(request.query_params))
+        # To get the underlying FastAPI items you would need to use await and some fancy @ stuff for a live query: https://fastapi.tiangolo.com/vi/reference/request/
+        #print("Request dictionary to object:", request.request.body())
+        print("Session hash:", request.session_hash)
+
+        if 'x-cognito-id' in request.headers:
+            out_session_hash = request.headers['x-cognito-id']
+            base_folder = "user-files/"
+            print("Cognito ID found:", out_session_hash)
+
+        else:
+            out_session_hash = request.session_hash
+            base_folder = "temp-files/"
+            print("Cognito ID not found. Using session hash as save folder.")
+
+        output_folder = base_folder + out_session_hash + "/"
+        if bucket_name:
+            print("S3 output folder is: " + "s3://" + bucket_name + "/" + output_folder)
+
+        return out_session_hash, output_folder
+    else:
+        print("No session parameters found.")
+        return "", ""
+
 # Attempt to delete content of gradio temp folder
 def get_temp_folder_path():
     username = getpass.getuser()
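In the new get_connection_params, the folder a session writes to is decided entirely by whether the request carries an x-cognito-id header. The helper below is a hypothetical, request-free restatement of that branch (gr.Request is only populated inside a live Gradio event), useful mainly to show the expected outputs.

# Hypothetical restatement of the folder-selection branch in
# get_connection_params, so it can be exercised without a running server.
def derive_output_folder(headers: dict, session_hash: str) -> str:
    # Mirrors helper_functions.get_connection_params: a Cognito ID wins,
    # otherwise the Gradio session hash names the per-user folder.
    if "x-cognito-id" in headers:
        return "user-files/" + headers["x-cognito-id"] + "/"
    return "temp-files/" + session_hash + "/"

print(derive_output_folder({"x-cognito-id": "user-123"}, "abcd1234"))  # user-files/user-123/
print(derive_output_folder({}, "abcd1234"))                            # temp-files/abcd1234/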