Spaces: Running

seanpedrickcase committed
Commit 115b61f
1 Parent(s): a3c7fb0

Added AWS auth, logging, allowed for API call saves

Files changed:
- app.py +68 -23
- tools/auth.py +47 -0
- tools/aws_functions.py +75 -29
- tools/{gradio.py → helper_functions.py} +106 -0
- tools/matcher_funcs.py +51 -32
- tools/preparation.py +1 -101
app.py
CHANGED
@@ -3,11 +3,13 @@ from datetime import datetime
 from pathlib import Path
 import gradio as gr
 import pandas as pd
+import socket
 
 from tools.matcher_funcs import run_matcher
-from tools.
-from tools.aws_functions import load_data_from_aws
+from tools.helper_functions import initial_data_load, ensure_output_folder_exists, get_connection_params, get_or_create_env_var, reveal_feedback_buttons
+from tools.aws_functions import load_data_from_aws, upload_file_to_s3
 from tools.constants import output_folder
+from tools.auth import authenticate_user
 
 import warnings
 # Remove warnings from print statements
@@ -25,6 +27,15 @@ base_folder = Path(os.getcwd())
 
 ensure_output_folder_exists(output_folder)
 
+host_name = socket.gethostname()
+
+feedback_logs_folder = 'feedback/' + today_rev + '/' + host_name + '/'
+access_logs_folder = 'logs/' + today_rev + '/' + host_name + '/'
+usage_logs_folder = 'usage/' + today_rev + '/' + host_name + '/'
+
+# Launch the Gradio app
+ADDRESSBASE_API_KEY = get_or_create_env_var('ADDRESSBASE_API_KEY', '')
+
 # Create the gradio interface
 block = gr.Blocks(theme = gr.themes.Base())
 
@@ -35,6 +46,17 @@ with block:
     results_data_state = gr.State(pd.DataFrame())
     ref_results_data_state =gr.State(pd.DataFrame())
 
+    session_hash_state = gr.State()
+    s3_output_folder_state = gr.State()
+
+    # Logging state
+    feedback_logs_state = gr.State(feedback_logs_folder + 'log.csv')
+    feedback_s3_logs_loc_state = gr.State(feedback_logs_folder)
+    access_logs_state = gr.State(access_logs_folder + 'log.csv')
+    access_s3_logs_loc_state = gr.State(access_logs_folder)
+    usage_logs_state = gr.State(usage_logs_folder + 'log.csv')
+    usage_s3_logs_loc_state = gr.State(usage_logs_folder)
+
     gr.Markdown(
     """
     # Address matcher
@@ -66,7 +88,7 @@ with block:
 
     with gr.Accordion("Use Addressbase API (instead of reference file)", open = True):
         in_api = gr.Dropdown(label="Choose API type", multiselect=False, value=None, choices=["Postcode"])#["Postcode", "UPRN"]) #choices=["Address", "Postcode", "UPRN"])
-        in_api_key = gr.Textbox(label="Addressbase API key", type='password')
+        in_api_key = gr.Textbox(label="Addressbase API key", type='password', value = ADDRESSBASE_API_KEY)
 
     with gr.Accordion("Match against reference file of addresses", open = False):
         in_ref = gr.File(label="Input reference addresses from file", file_count= "multiple")
@@ -81,6 +103,18 @@ with block:
     output_summary = gr.Textbox(label="Output summary")
     output_file = gr.File(label="Output file")
 
+    feedback_title = gr.Markdown(value="## Please give feedback", visible=False)
+    feedback_radio = gr.Radio(choices=["The results were good", "The results were not good"], visible=False)
+    further_details_text = gr.Textbox(label="Please give more detailed feedback about the results:", visible=False)
+    submit_feedback_btn = gr.Button(value="Submit feedback", visible=False)
+
+    with gr.Row():
+        s3_logs_output_textbox = gr.Textbox(label="Feedback submission logs", visible=False)
+        # This keeps track of the time taken to match files for logging purposes.
+        estimated_time_taken_number = gr.Number(value=0.0, precision=1, visible=False)
+        # Invisible text box to hold the session hash/username just for logging purposes
+        session_hash_textbox = gr.Textbox(value="", visible=False)
+
     with gr.Tab(label="Advanced options"):
         with gr.Accordion(label = "AWS data access", open = False):
             aws_password_box = gr.Textbox(label="Password for AWS data access (ask the Data team if you don't have this)")
@@ -90,35 +124,46 @@ with block:
 
             aws_log_box = gr.Textbox(label="AWS data load status")
 
-
    ### Loading AWS data ###
    load_aws_data_button.click(fn=load_data_from_aws, inputs=[in_aws_file, aws_password_box], outputs=[in_ref, aws_log_box])
-
 
    # Updates to components
    in_file.change(fn = initial_data_load, inputs=[in_file], outputs=[output_summary, in_colnames, in_existing, data_state, results_data_state])
    in_ref.change(fn = initial_data_load, inputs=[in_ref], outputs=[output_summary, in_refcol, in_joincol, ref_data_state, ref_results_data_state])
 
    match_btn.click(fn = run_matcher, inputs=[in_text, in_file, in_ref, data_state, results_data_state, ref_data_state, in_colnames, in_refcol, in_joincol, in_existing, in_api, in_api_key],
-                   outputs=[output_summary, output_file], api_name="address")
+                   outputs=[output_summary, output_file, estimated_time_taken_number], api_name="address").\
+                   then(fn = reveal_feedback_buttons, outputs=[feedback_radio, further_details_text, submit_feedback_btn, feedback_title])
 
 
-   #
-
-
-
-
-
+   # Get connection details on app load
+   block.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state, session_hash_textbox])
+
+   # Log usernames and times of access to file (to know who is using the app when running on AWS)
+   access_callback = gr.CSVLogger()
+   access_callback.setup([session_hash_textbox], access_logs_folder)
+   session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
+       then(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
+
+   # User submitted feedback for pdf redactions
+   feedback_callback = gr.CSVLogger()
+   feedback_callback.setup([feedback_radio, further_details_text, in_file], feedback_logs_folder)
+   submit_feedback_btn.click(lambda *args: feedback_callback.flag(list(args)), [feedback_radio, further_details_text, in_file], None, preprocess=False).\
+       then(fn = upload_file_to_s3, inputs=[feedback_logs_state, feedback_s3_logs_loc_state], outputs=[further_details_text])
+
+   # Log processing time/token usage when making a query
+   usage_callback = gr.CSVLogger()
+   usage_callback.setup([session_hash_textbox, in_file, estimated_time_taken_number], usage_logs_folder)
+   estimated_time_taken_number.change(lambda *args: usage_callback.flag(list(args)), [session_hash_textbox, in_file, estimated_time_taken_number], None, preprocess=False).\
+       then(fn = upload_file_to_s3, inputs=[usage_logs_state, usage_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
+
+# Launch the Gradio app
+COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
+print(f'The value of COGNITO_AUTH is {COGNITO_AUTH}')
+
+if __name__ == "__main__":
+    if os.environ['COGNITO_AUTH'] == "1":
+        block.queue().launch(show_error=True, auth=authenticate_user, max_file_size='50mb')
     else:
-        block.queue().launch(
-
-        block.queue().launch(ssl_verify=False)
-
-        # Download OpenSSL from here:
-        # Running on local server with https: https://discuss.huggingface.co/t/how-to-run-gradio-with-0-0-0-0-and-https/38003 or https://dev.to/rajshirolkar/fastapi-over-https-for-development-on-windows-2p7d
-        #block.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
-        # ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
-
-        # Running on local server without https
-        #block.queue().launch(server_name="0.0.0.0", server_port=7861, ssl_verify=False)
+        block.queue().launch(show_error=True, inbrowser=True, max_file_size='50mb')
 
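The new logging wiring above pairs gr.CSVLogger with the upload_file_to_s3 helper. The snippet below is a minimal, self-contained sketch of that pattern, not part of the commit: the bucket name, the "logs/" folder and the small push_log_to_s3 helper are illustrative placeholders.

import gradio as gr
import boto3

access_callback = gr.CSVLogger()

def push_log_to_s3(local_log_path, s3_prefix):
    # Same idea as upload_file_to_s3 in tools/aws_functions.py, reduced to one function.
    s3_client = boto3.client('s3')
    s3_client.upload_file(local_log_path, "example-bucket", s3_prefix + "log.csv")
    return f"Uploaded {local_log_path} to s3://example-bucket/{s3_prefix}log.csv"

with gr.Blocks() as demo:
    session_hash_textbox = gr.Textbox(value="", visible=False)
    upload_status = gr.Textbox(label="Log upload status", visible=False)
    log_file_state = gr.State("logs/log.csv")
    log_prefix_state = gr.State("logs/")

    # CSVLogger writes a log.csv inside the folder given to setup().
    access_callback.setup([session_hash_textbox], "logs/")

    # Each change writes a CSV row locally, then the .then() step copies the file to S3.
    session_hash_textbox.change(lambda *args: access_callback.flag(list(args)),
                                [session_hash_textbox], None, preprocess=False).\
        then(fn=push_log_to_s3, inputs=[log_file_state, log_prefix_state], outputs=[upload_status])

if __name__ == "__main__":
    demo.queue().launch()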
tools/auth.py
ADDED
@@ -0,0 +1,47 @@
+import boto3
+from tools.helper_functions import get_or_create_env_var
+
+client_id = get_or_create_env_var('AWS_CLIENT_ID', '') # This client id is borrowed from async gradio app client
+print(f'The value of AWS_CLIENT_ID is {client_id}')
+
+user_pool_id = get_or_create_env_var('AWS_USER_POOL_ID', '')
+print(f'The value of AWS_USER_POOL_ID is {user_pool_id}')
+
+def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=client_id):
+    """Authenticates a user against an AWS Cognito user pool.
+
+    Args:
+        user_pool_id (str): The ID of the Cognito user pool.
+        client_id (str): The ID of the Cognito user pool client.
+        username (str): The username of the user.
+        password (str): The password of the user.
+
+    Returns:
+        bool: True if the user is authenticated, False otherwise.
+    """
+
+    client = boto3.client('cognito-idp') # Cognito Identity Provider client
+
+    try:
+        response = client.initiate_auth(
+            AuthFlow='USER_PASSWORD_AUTH',
+            AuthParameters={
+                'USERNAME': username,
+                'PASSWORD': password,
+            },
+            ClientId=client_id
+        )
+
+        # If successful, you'll receive an AuthenticationResult in the response
+        if response.get('AuthenticationResult'):
+            return True
+        else:
+            return False
+
+    except client.exceptions.NotAuthorizedException:
+        return False
+    except client.exceptions.UserNotFoundException:
+        return False
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        return False
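A hypothetical way to use the new module is to pass authenticate_user straight to Gradio's auth hook, as app.py now does when COGNITO_AUTH is "1". This sketch assumes AWS_CLIENT_ID and AWS_USER_POOL_ID point at a real user pool and app client with the USER_PASSWORD_AUTH flow enabled, and that boto3 can find credentials and a region; the placeholder values below are not from the repo.

import os
# Placeholders only - set these to a real Cognito app client and user pool before launching.
os.environ.setdefault('AWS_CLIENT_ID', 'your-app-client-id')
os.environ.setdefault('AWS_USER_POOL_ID', 'eu-west-2_ExamplePool')
os.environ.setdefault('AWS_DEFAULT_REGION', 'eu-west-2')

import gradio as gr
from tools.auth import authenticate_user  # reads the two env vars at import time

with gr.Blocks() as demo:
    gr.Markdown("Signed in via Cognito")

# Gradio calls auth(username, password) on each login attempt and expects a bool back.
demo.queue().launch(auth=authenticate_user)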
tools/aws_functions.py
CHANGED
@@ -1,37 +1,46 @@
-from typing import Type
+from typing import Type, List
 import pandas as pd
 import boto3
 import tempfile
 import os
+from tools.helper_functions import get_or_create_env_var
 
 PandasDataFrame = Type[pd.DataFrame]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# Get AWS credentials if required
+bucket_name=""
+aws_var = "RUN_AWS_FUNCTIONS"
+aws_var_default = "0"
+aws_var_val = get_or_create_env_var(aws_var, aws_var_default)
+print(f'The value of {aws_var} is {aws_var_val}')
+
+if aws_var_val == "1":
+    try:
+        session = boto3.Session()
+        bucket_name = os.environ['ADDRESS_MATCHER_BUCKET']
+    except Exception as e:
+        bucket_name = ''
+        print(e)
+
+def get_assumed_role_info():
+    sts = boto3.client('sts', region_name='eu-west-2', endpoint_url='https://sts.eu-west-2.amazonaws.com')
+    response = sts.get_caller_identity()
+
+    # Extract ARN of the assumed role
+    assumed_role_arn = response['Arn']
+
+    # Extract the name of the assumed role from the ARN
+    assumed_role_name = assumed_role_arn.split('/')[-1]
+
+    return assumed_role_arn, assumed_role_name
+
+try:
+    assumed_role_arn, assumed_role_name = get_assumed_role_info()
+
+    print("Assumed Role ARN:", assumed_role_arn)
+    print("Assumed Role Name:", assumed_role_name)
+except Exception as e:
+    print(e)
 
 # Download direct from S3 - requires login credentials
 def download_file_from_s3(bucket_name, key, local_file_path):
@@ -101,8 +110,6 @@ def download_files_from_s3(bucket_name, s3_folder, local_folder, filenames):
     except Exception as e:
         print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)
 
-
-
 def load_data_from_aws(in_aws_keyword_file, aws_password="", bucket_name=bucket_name):
 
     temp_dir = tempfile.mkdtemp()
@@ -154,3 +161,42 @@ def load_data_from_aws(in_aws_keyword_file, aws_password="", bucket_name=bucket_
 
     return files, out_message
 
+def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=bucket_name):
+    """
+    Uploads a file from local machine to Amazon S3.
+
+    Args:
+    - local_file_path: Local file path(s) of the file(s) to upload.
+    - s3_key: Key (path) to the file in the S3 bucket.
+    - s3_bucket: Name of the S3 bucket.
+
+    Returns:
+    - Message as variable/printed to console
+    """
+    final_out_message = []
+
+    s3_client = boto3.client('s3')
+
+    if isinstance(local_file_paths, str):
+        local_file_paths = [local_file_paths]
+
+    for file in local_file_paths:
+        try:
+            # Get file name off file path
+            file_name = os.path.basename(file)
+
+            s3_key_full = s3_key + file_name
+            print("S3 key: ", s3_key_full)
+
+            s3_client.upload_file(file, s3_bucket, s3_key_full)
+            out_message = "File " + file_name + " uploaded successfully!"
+            print(out_message)
+
+        except Exception as e:
+            out_message = f"Error uploading file(s): {e}"
+            print(out_message)
+
+        final_out_message.append(out_message)
+    final_out_message_str = '\n'.join(final_out_message)
+
+    return final_out_message_str
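Illustrative call to the new upload_file_to_s3 helper (not part of the commit). In the app the bucket comes from the ADDRESS_MATCHER_BUCKET environment variable when RUN_AWS_FUNCTIONS=1; the paths and bucket name below are placeholders.

from tools.aws_functions import upload_file_to_s3

# Accepts a single path or a list of paths; each file name is appended to the s3_key prefix.
message = upload_file_to_s3(
    local_file_paths=["output/results.csv"],       # placeholder local file
    s3_key="usage/20240101/example-host/",         # folder-style prefix ending in "/"
    s3_bucket="example-address-matcher-bucket")    # placeholder bucket name
print(message)  # "File results.csv uploaded successfully!" or an error message per file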
tools/{gradio.py → helper_functions.py}
RENAMED
@@ -1,6 +1,25 @@
 import gradio as gr
 import pandas as pd
 import os
+import re
+
+def get_or_create_env_var(var_name, default_value):
+    # Get the environment variable if it exists
+    value = os.environ.get(var_name)
+
+    # If it doesn't exist, set it to the default value
+    if value is None:
+        os.environ[var_name] = default_value
+        value = default_value
+
+    return value
+
+# Retrieving or setting output folder
+env_var_name = 'GRADIO_OUTPUT_FOLDER'
+default_value = 'output/'
+
+output_folder = get_or_create_env_var(env_var_name, default_value)
+print(f'The value of {env_var_name} is {output_folder}')
 
 def detect_file_type(filename):
     """Detect the file type based on its extension."""
@@ -70,7 +89,94 @@ def dummy_function(in_colnames):
     """
     return None
 
+# Upon running a process, the feedback buttons are revealed
+def reveal_feedback_buttons():
+    return gr.Radio(visible=True), gr.Textbox(visible=True), gr.Button(visible=True), gr.Markdown(visible=True)
 
 def clear_inputs(in_file, in_ref, in_text):
     return gr.File(value=[]), gr.File(value=[]), gr.Textbox(value='')
 
+## Get final processing time for logs:
+def sum_numbers_before_seconds(string):
+    """Extracts numbers that precede the word 'seconds' from a string and adds them up.
+
+    Args:
+        string: The input string.
+
+    Returns:
+        The sum of all numbers before 'seconds' in the string.
+    """
+
+    # Extract numbers before 'seconds' using regular expression
+    numbers = re.findall(r'\d+(\.\d+)?\s*seconds', string)
+
+    # Extract the numbers from the matches
+    numbers = [float(num.split()[0]) for num in numbers]
+
+    # Sum up the extracted numbers
+    sum_of_numbers = sum(numbers)
+
+    return sum_of_numbers
+
+async def get_connection_params(request: gr.Request):
+    base_folder = ""
+
+    if request:
+        #print("request user:", request.username)
+
+        #request_data = await request.json() # Parse JSON body
+        #print("All request data:", request_data)
+        #context_value = request_data.get('context')
+        #if 'context' in request_data:
+        #    print("Request context dictionary:", request_data['context'])
+
+        # print("Request headers dictionary:", request.headers)
+        # print("All host elements", request.client)
+        # print("IP address:", request.client.host)
+        # print("Query parameters:", dict(request.query_params))
+        # To get the underlying FastAPI items you would need to use await and some fancy @ stuff for a live query: https://fastapi.tiangolo.com/vi/reference/request/
+        #print("Request dictionary to object:", request.request.body())
+        print("Session hash:", request.session_hash)
+
+        # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER
+        CUSTOM_CLOUDFRONT_HEADER_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER', '')
+        #print(f'The value of CUSTOM_CLOUDFRONT_HEADER is {CUSTOM_CLOUDFRONT_HEADER_var}')
+
+        # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER_VALUE
+        CUSTOM_CLOUDFRONT_HEADER_VALUE_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER_VALUE', '')
+        #print(f'The value of CUSTOM_CLOUDFRONT_HEADER_VALUE_var is {CUSTOM_CLOUDFRONT_HEADER_VALUE_var}')
+
+        if CUSTOM_CLOUDFRONT_HEADER_var and CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
+            if CUSTOM_CLOUDFRONT_HEADER_var in request.headers:
+                supplied_cloudfront_custom_value = request.headers[CUSTOM_CLOUDFRONT_HEADER_var]
+                if supplied_cloudfront_custom_value == CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
+                    print("Custom Cloudfront header found:", supplied_cloudfront_custom_value)
+                else:
+                    raise(ValueError, "Custom Cloudfront header value does not match expected value.")
+
+        # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
+
+        if request.username:
+            out_session_hash = request.username
+            base_folder = "user-files/"
+            print("Request username found:", out_session_hash)
+
+        elif 'x-cognito-id' in request.headers:
+            out_session_hash = request.headers['x-cognito-id']
+            base_folder = "user-files/"
+            print("Cognito ID found:", out_session_hash)
+
+        else:
+            out_session_hash = request.session_hash
+            base_folder = "temp-files/"
+            # print("Cognito ID not found. Using session hash as save folder:", out_session_hash)
+
+        output_folder = base_folder + out_session_hash + "/"
+        #if bucket_name:
+        #    print("S3 output folder is: " + "s3://" + bucket_name + "/" + output_folder)
+
+        return out_session_hash, output_folder, out_session_hash
+    else:
+        print("No session parameters found.")
+        return "",""
+
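A small demonstration of the new get_or_create_env_var helper (illustrative, not part of the commit): the first call writes the default into os.environ, so later lookups, including from other modules such as tools/aws_functions.py, see the same value.

import os
from tools.helper_functions import get_or_create_env_var

# First call: variable unset, so the default is stored and returned.
folder = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/')
print(folder)                              # 'output/' unless it was already set
print(os.environ['GRADIO_OUTPUT_FOLDER'])  # default is now persisted for this process

# Existing values always win over the supplied default.
os.environ['COGNITO_AUTH'] = '1'
print(get_or_create_env_var('COGNITO_AUTH', '0'))  # prints '1'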
tools/matcher_funcs.py
CHANGED
@@ -33,7 +33,7 @@ from tools.standardise import standardise_wrapper_func
 ### Predict function for imported model
 from tools.model_predict import full_predict_func, full_predict_torch, post_predict_clean
 from tools.recordlinkage_funcs import score_based_match
-from tools.
+from tools.helper_functions import initial_data_load, sum_numbers_before_seconds
 
 # API functions
 from tools.addressbase_api_funcs import places_api_query
@@ -108,10 +108,13 @@ def filter_not_matched(
 
     return search_df.iloc[np.where(~matched)[0]]
 
-def
+def query_addressbase_api(in_api_key:str, Matcher:MatcherClass, query_type:str, progress=gr.Progress()):
+
+    final_api_output_file_name = ""
+
     if in_api_key == "":
         print ("No API key provided, please provide one to continue")
-        return Matcher
+        return Matcher, final_api_output_file_name
     else:
         # Call the API
         #Matcher.ref_df = pd.DataFrame()
@@ -119,7 +122,7 @@ def run_all_api_calls(in_api_key:str, Matcher:MatcherClass, query_type:str, prog
     # Check if the ref_df file already exists
     def check_and_create_api_folder():
        # Check if the environmental variable is available
-       file_path = os.environ.get('ADDRESSBASE_API_OUT')
+       file_path = os.environ.get('ADDRESSBASE_API_OUT')
 
        if file_path is None:
            # Environmental variable is not set
@@ -145,11 +148,13 @@ def run_all_api_calls(in_api_key:str, Matcher:MatcherClass, query_type:str, prog
     api_ref_save_loc = api_output_folder + search_file_name_without_extension + "_api_" + today_month_rev + "_" + query_type + "_ckpt"
     print("API reference save location: ", api_ref_save_loc)
 
+    final_api_output_file_name = api_ref_save_loc + ".parquet"
+
     # Allow for csv, parquet and gzipped csv files
     if os.path.isfile(api_ref_save_loc + ".csv"):
         print("API reference CSV file found")
         Matcher.ref_df = pd.read_csv(api_ref_save_loc + ".csv")
-    elif os.path.isfile(
+    elif os.path.isfile(final_api_output_file_name):
         print("API reference Parquet file found")
         Matcher.ref_df = pd.read_parquet(api_ref_save_loc + ".parquet")
     elif os.path.isfile(api_ref_save_loc + ".csv.gz"):
@@ -350,21 +355,23 @@ def run_all_api_calls(in_api_key:str, Matcher:MatcherClass, query_type:str, prog
     # Matcher.ref_df = Matcher.ref_df.loc[Matcher.ref_df["LOCAL_CUSTODIAN_CODE"] != 7655,:]
 
     if save_file:
+        final_api_output_file_name = output_folder + api_ref_save_loc[:-5] + ".parquet"
         print("Saving reference file to: " + api_ref_save_loc[:-5] + ".parquet")
         Matcher.ref_df.to_parquet(output_folder + api_ref_save_loc + ".parquet", index=False) # Save checkpoint as well
-        Matcher.ref_df.to_parquet(
+        Matcher.ref_df.to_parquet(final_api_output_file_name, index=False)
 
     if Matcher.ref_df.empty:
         print ("No reference data found with API")
         return Matcher
 
-    return Matcher
+    return Matcher, final_api_output_file_name
 
-def
+def load_ref_data(Matcher:MatcherClass, ref_data_state:PandasDataFrame, in_ref:List[str], in_refcol:List[str], in_api:List[str], in_api_key:str, query_type:str, progress=gr.Progress()):
     '''
     Check for reference address data, do some preprocessing, and load in from the Addressbase API if required.
     '''
-
+    final_api_output_file_name = ""
+
     # Check if reference data loaded, bring in if already there
     if not ref_data_state.empty:
         Matcher.ref_df = ref_data_state
@@ -382,10 +389,10 @@ def check_ref_data_exists(Matcher:MatcherClass, ref_data_state:PandasDataFrame,
     if not in_ref:
         if in_api==False:
             print ("No reference file provided, please provide one to continue")
-            return Matcher
+            return Matcher, final_api_output_file_name
         # Check if api call required and api key is provided
         else:
-            Matcher =
+            Matcher, final_api_output_file_name = query_addressbase_api(in_api_key, Matcher, query_type)
 
     else:
         Matcher.ref_name = get_file_name(in_ref[0].name)
@@ -402,9 +409,7 @@ def check_ref_data_exists(Matcher:MatcherClass, ref_data_state:PandasDataFrame,
 
     Matcher.ref_df = pd.concat([Matcher.ref_df, temp_ref_file])
 
-    # For the neural net model to work, the llpg columns have to be in the LPI format (e.g. with columns SaoText, SaoStartNumber etc. Here we check if we have that format.
-
-
+    # For the neural net model to work, the llpg columns have to be in the LPI format (e.g. with columns SaoText, SaoStartNumber etc. Here we check if we have that format.
 
     if 'Address_LPI' in Matcher.ref_df.columns:
         Matcher.ref_df = Matcher.ref_df.rename(columns={
@@ -475,9 +480,9 @@ def check_ref_data_exists(Matcher:MatcherClass, ref_data_state:PandasDataFrame,
     Matcher.ref_df = Matcher.ref_df.reset_index() #.drop(["index","level_0"], axis = 1, errors="ignore").reset_index().drop(["index","level_0"], axis = 1, errors="ignore")
     Matcher.ref_df.index.name = 'index'
 
-    return Matcher
+    return Matcher, final_api_output_file_name
 
-def
+def load_match_data_and_filter(Matcher:MatcherClass, data_state:PandasDataFrame, results_data_state:PandasDataFrame, in_file:List[str], in_text:str, in_colnames:List[str], in_joincol:List[str], in_existing:List[str], in_api:List[str]):
     '''
     Check if data to be matched exists. Filter it according to which records are relevant in the reference dataset
     '''
@@ -654,6 +659,8 @@ def load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state,
     '''
     Load in user inputs from the Gradio interface. Convert all input types (single address, or csv input) into standardised data format that can be used downstream for the fuzzy matching.
     '''
+    final_api_output_file_name = ""
+
     today_rev = datetime.now().strftime("%Y%m%d")
 
     # Abort flag for if it's not even possible to attempt the first stage of the match for some reason
@@ -662,16 +669,15 @@ def load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state,
     ### ref_df FILES ###
     # If not an API call, run this first
     if not in_api:
-        Matcher =
+        Matcher, final_api_output_file_name = load_ref_data(Matcher, ref_data_state, in_ref, in_refcol, in_api, in_api_key, query_type=in_api)
 
     ### MATCH/SEARCH FILES ###
     # If doing API calls, we need to know the search data before querying for specific addresses/postcodes
-    Matcher =
+    Matcher = load_match_data_and_filter(Matcher, data_state, results_data_state, in_file, in_text, in_colnames, in_joincol, in_existing, in_api)
 
     # If an API call, ref_df data is loaded after
     if in_api:
-
-        Matcher = check_ref_data_exists(Matcher, ref_data_state, in_ref, in_refcol, in_api, in_api_key, query_type=in_api)
+        Matcher, final_api_output_file_name = load_ref_data(Matcher, ref_data_state, in_ref, in_refcol, in_api, in_api_key, query_type=in_api)
 
     print("Shape of ref_df after filtering is: ", Matcher.ref_df.shape)
     print("Shape of search_df after filtering is: ", Matcher.search_df.shape)
@@ -682,23 +688,31 @@ def load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state,
     Matcher.match_results_output.to_csv(Matcher.match_outputs_name, index = None)
     Matcher.results_on_orig_df.to_csv(Matcher.results_orig_df_name, index = None)
 
-    return Matcher
+    return Matcher, final_api_output_file_name
 
 # Run whole matcher process
 def run_matcher(in_text:str, in_file:str, in_ref:str, data_state:PandasDataFrame, results_data_state:PandasDataFrame, ref_data_state:PandasDataFrame, in_colnames:List[str], in_refcol:List[str], in_joincol:List[str], in_existing:List[str], in_api:str, in_api_key:str, InitMatch:MatcherClass = InitMatch, progress=gr.Progress()):
     '''
     Split search and reference data into batches. Loop and run through the match script for each batch of data.
     '''
+    output_files = []
+
+    estimate_total_processing_time = 0.0
 
     overall_tic = time.perf_counter()
 
     # Load in initial data. This will filter to relevant addresses in the search and reference datasets that can potentially be matched, and will pull in API data if asked for.
-    InitMatch = load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state, ref_data_state, in_colnames, in_refcol, in_joincol, in_existing, InitMatch, in_api, in_api_key)
+    InitMatch, final_api_output_file_name = load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state, ref_data_state, in_colnames, in_refcol, in_joincol, in_existing, InitMatch, in_api, in_api_key)
+
+    if final_api_output_file_name:
+        output_files.append(final_api_output_file_name)
 
     if InitMatch.search_df.empty or InitMatch.ref_df.empty:
         out_message = "Nothing to match!"
         print(out_message)
-
+
+        output_files.extend([InitMatch.results_orig_df_name, InitMatch.match_outputs_name])
+        return out_message, output_files, estimate_total_processing_time
 
     # Run initial address preparation and standardisation processes
     # Prepare address format
@@ -801,7 +815,7 @@ def run_matcher(in_text:str, in_file:str, in_ref:str, data_state:PandasDataFrame
                 "Excluded from search":False,
                 "Matched with reference address":False})
         else:
-            summary_of_summaries, BatchMatch_out =
+            summary_of_summaries, BatchMatch_out = run_single_match_batch(BatchMatch, n, number_of_batches)
 
         OutputMatch = combine_two_matches(OutputMatch, BatchMatch_out, "All up to and including batch " + str(n+1))
 
@@ -837,7 +851,13 @@ def run_matcher(in_text:str, in_file:str, in_ref:str, data_state:PandasDataFrame
 
     final_summary = fuzzy_not_std_summary + "\n" + fuzzy_std_summary + "\n" + nnet_std_summary + "\n" + time_out
 
-
+
+
+    estimate_total_processing_time = sum_numbers_before_seconds(time_out)
+    print("Estimated total processing time:", str(estimate_total_processing_time))
+
+    output_files.extend([OutputMatch.results_orig_df_name, OutputMatch.match_outputs_name])
+    return final_summary, output_files, estimate_total_processing_time
 
 # Run a match run for a single batch
 def create_simple_batch_ranges(df:PandasDataFrame, ref_df:PandasDataFrame, batch_size:int, ref_batch_size:int):
@@ -963,7 +983,7 @@ def create_batch_ranges(df:PandasDataFrame, ref_df:PandasDataFrame, batch_size:i
 
     return lengths_df
 
-def
+def run_single_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, progress=gr.Progress()):
     '''
     Over-arching function for running a single batch of data through the full matching process. Calls fuzzy matching, then neural network match functions in order. It outputs a summary of the match, and a MatcherClass with the matched data included.
     '''
@@ -979,7 +999,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
 
     ''' Run fuzzy match on non-standardised dataset '''
 
-    FuzzyNotStdMatch =
+    FuzzyNotStdMatch = orchestrate_single_match_batch(Matcher = copy.copy(InitialMatch), standardise = False, nnet = False, file_stub= "not_std_", df_name = df_name)
 
     if FuzzyNotStdMatch.abort_flag == True:
         message = "Nothing to match! Aborting address check."
@@ -999,7 +1019,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
     progress(.25, desc="Batch " + str(batch_n+1) + " of " + str(total_batches) + ". Fuzzy match - standardised dataset")
     df_name = "Fuzzy standardised"
 
-    FuzzyStdMatch =
+    FuzzyStdMatch = orchestrate_single_match_batch(Matcher = copy.copy(FuzzyNotStdMatch), standardise = True, nnet = False, file_stub= "std_", df_name = df_name)
     FuzzyStdMatch = combine_two_matches(FuzzyNotStdMatch, FuzzyStdMatch, df_name)
 
     ''' Continue if reference file in correct format, and neural net model exists. Also if data not too long '''
@@ -1022,7 +1042,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
     progress(.50, desc="Batch " + str(batch_n+1) + " of " + str(total_batches) + ". Neural net - non-standardised dataset")
     df_name = "Neural net not standardised"
 
-    FuzzyNNetNotStdMatch =
+    FuzzyNNetNotStdMatch = orchestrate_single_match_batch(Matcher = copy.copy(FuzzyStdMatch), standardise = False, nnet = True, file_stub= "nnet_not_std_", df_name = df_name)
     FuzzyNNetNotStdMatch = combine_two_matches(FuzzyStdMatch, FuzzyNNetNotStdMatch, df_name)
 
     if (len(FuzzyNNetNotStdMatch.search_df_not_matched) == 0):
@@ -1035,7 +1055,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
     progress(.75, desc="Batch " + str(batch_n+1) + " of " + str(total_batches) + ". Neural net - standardised dataset")
     df_name = "Neural net standardised"
 
-    FuzzyNNetStdMatch =
+    FuzzyNNetStdMatch = orchestrate_single_match_batch(Matcher = copy.copy(FuzzyNNetNotStdMatch), standardise = True, nnet = True, file_stub= "nnet_std_", df_name = df_name)
     FuzzyNNetStdMatch = combine_two_matches(FuzzyNNetNotStdMatch, FuzzyNNetStdMatch, df_name)
 
     if run_fuzzy_match == False:
@@ -1052,7 +1072,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
     return summary_of_summaries, FuzzyNNetStdMatch
 
 # Overarching functions
-def
+def orchestrate_single_match_batch(Matcher, standardise = False, nnet = False, file_stub= "not_std_", df_name = "Fuzzy not standardised"):
 
     today_rev = datetime.now().strftime("%Y%m%d")
 
@@ -1463,7 +1483,6 @@ def full_nn_match(ref_address_cols:List[str],
 
     return match_results_output_final_three, results_on_orig_df, summary_three, predict_df
 
-
 # Combiner/summary functions
 def combine_dfs_and_remove_dups(orig_df:PandasDataFrame, new_df:PandasDataFrame, index_col:str = "search_orig_address", match_address_series:str = "full_match", keep_only_duplicated:bool = False) -> PandasDataFrame:
 
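The refactor above threads final_api_output_file_name through query_addressbase_api, load_ref_data and load_matcher_data so the API checkpoint parquet can be returned among the output files. Below is a reduced sketch of that checkpoint behaviour, with assumed names and paths and a stub in place of the real API query; it needs pandas with a parquet engine such as pyarrow and is not code from the commit.

import os
import pandas as pd

api_ref_save_loc = "api_output/search_api_2024-01_Postcode_ckpt"  # placeholder stem
final_api_output_file_name = api_ref_save_loc + ".parquet"

def load_or_create_checkpoint(fetch_fn):
    # Re-use the saved parquet checkpoint if present, otherwise query and save one.
    if os.path.isfile(final_api_output_file_name):
        print("API reference Parquet file found")
        return pd.read_parquet(final_api_output_file_name)
    ref_df = fetch_fn()
    os.makedirs(os.path.dirname(final_api_output_file_name), exist_ok=True)
    ref_df.to_parquet(final_api_output_file_name, index=False)
    return ref_df

# Stand-in for the real Addressbase query, returning an empty reference frame.
ref_df = load_or_create_checkpoint(lambda: pd.DataFrame({"ADDRESS": [], "POSTCODE": []}))
print(ref_df.shape)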
tools/preparation.py
CHANGED
@@ -49,49 +49,6 @@ def prepare_search_address_string(
 
     return search_df_out, key_field, address_cols, postcode_col
 
-# def prepare_search_address(
-# search_df: pd.DataFrame,
-# address_cols: list,
-# postcode_col: list,
-# key_col: str
-# ) -> Tuple[pd.DataFrame, str]:
-
-# # Validate inputs
-# if not isinstance(search_df, pd.DataFrame):
-# raise TypeError("search_df must be a Pandas DataFrame")
-
-# if not isinstance(address_cols, list):
-# raise TypeError("address_cols must be a list")
-
-# if not isinstance(postcode_col, list):
-# raise TypeError("postcode_col must be a list")
-
-# if not isinstance(key_col, str):
-# raise TypeError("key_col must be a string")
-
-# # Clean address columns
-# clean_addresses = _clean_columns(search_df, address_cols)
-
-# # Join address columns into one
-# full_addresses = _join_address(clean_addresses, address_cols)
-
-# # Add postcode column
-# full_df = _add_postcode_column(full_addresses, postcode_col)
-
-# # Remove postcode from main address if there was only one column in the input
-# if postcode_col == "full_address_postcode":
-# # Remove postcode from address
-# address_series = remove_postcode(search_df, "full_address")
-# search_df["full_address"] == address_series
-
-# # Ensure index column
-# final_df = _ensure_index(full_df, key_col)
-
-# #print(final_df)
-
-
-# return final_df, key_col
-
 def prepare_search_address(
     search_df: pd.DataFrame,
     address_cols: list,
@@ -145,25 +102,7 @@ def _clean_columns(df, cols):
     df[cols] = df[cols].apply(clean_col)
 
     return df
-
-# def _clean_columns(df, cols):
-# # Cleaning logic
-# #print(df)
-
-# #if isinstance(df, pl.DataFrame):
-# # print("It's a Polars DataFrame")
-
-# def clean_col(col):
-# col = col.str.replace("nan", "")
-# col = col.apply(lambda x: re.sub(r'\s{2,}', ' ', str(x)), skip_nulls=False, return_dtype=str) # replace any spaces greater than one with one
-# return col.str.replace(",", " ").str.strip() # replace commas with a space
-
-# for col in cols:
-# df = df.with_columns(clean_col(df[col]).alias(col))
-
-# return df
-
-
+
 def _join_address(df, cols):
     # Joining logic
     full_address = df[cols].apply(lambda row: ' '.join(row.values.astype(str)), axis=1)
@@ -289,43 +228,6 @@ def prepare_ref_address(ref_df, ref_address_cols, new_join_col = ['UPRN'], stand
 
     return ref_df_cleaned
 
-# def prepare_ref_address(ref_df:pl.DataFrame, ref_address_cols, new_join_col = ['UPRN'], standard_cols = True):
-
-# if ('SaoText' in ref_df.columns) | ("Secondary_Name_LPI" in ref_df.columns):
-# standard_cols = True
-# else:
-# standard_cols = False
-
-# ref_address_cols_uprn = list(ref_address_cols) + new_join_col
-# ref_df_cleaned = ref_df[ref_address_cols_uprn].fill_null("")
-
-# # In on-prem LPI db street has been excluded, so put this back in
-# if ('Street' not in ref_df_cleaned.columns) & ('Address_LPI' in ref_df_cleaned.columns):
-# ref_df_cleaned = ref_df_cleaned.with_column(pl.col('Address_LPI').apply(lambda x: extract_street_name(x)).alias('Street'))
-
-# if ('Organisation' not in ref_df_cleaned.columns) & ('SaoText' in ref_df_cleaned.columns):
-# ref_df_cleaned = ref_df_cleaned.with_column(pl.lit("").alias('Organisation'))
-
-# #ref_df_cleaned['fulladdress'] =
-
-# if standard_cols:
-# pass
-# # I can not write the full address code here as it depends on your extract_street_name and create_full_address function implementations.
-# # However, you might need to convert string types to object type for full address creation which may require more than just a few lines of codes.
-# else:
-# pass
-
-# # I can not write the full address code here as it depends on your extract_street_name and create_full_address function implementations.
-
-# if 'Street' not in ref_df_cleaned.columns:
-# ref_df_cleaned = ref_df_cleaned.with_column(pl.col('fulladdress').apply(extract_street_name).alias("Street"))
-
-# # Add index column
-# ref_df_cleaned = ref_df_cleaned.with_column(pl.lit('').alias('ref_index'))
-
-# return ref_df_cleaned
-
-
 def extract_postcode(df, col:str) -> PandasSeries:
     '''
     Extract a postcode from a string column in a dataframe
@@ -335,7 +237,6 @@ def extract_postcode(df, col:str) -> PandasSeries:
 
     return postcode_series
 
-
 # Remove addresses with no numbers in at all - too high a risk of badly assigning an address
 def check_no_number_addresses(df, in_address_series) -> PandasSeries:
     '''
@@ -353,7 +254,6 @@ def check_no_number_addresses(df, in_address_series) -> PandasSeries:
 
     return df
 
-
 def remove_postcode(df, col:str) -> PandasSeries:
     '''
     Remove a postcode from a string column in a dataframe