seanpedrickcase committed
Commit 115b61f
1 parent: a3c7fb0

Added AWS auth, logging, allowed for API call saves
app.py CHANGED
@@ -3,11 +3,13 @@ from datetime import datetime
from pathlib import Path
import gradio as gr
import pandas as pd
+ import socket

from tools.matcher_funcs import run_matcher
- from tools.gradio import initial_data_load, ensure_output_folder_exists
- from tools.aws_functions import load_data_from_aws
+ from tools.helper_functions import initial_data_load, ensure_output_folder_exists, get_connection_params, get_or_create_env_var, reveal_feedback_buttons
+ from tools.aws_functions import load_data_from_aws, upload_file_to_s3
from tools.constants import output_folder
+ from tools.auth import authenticate_user

import warnings
# Remove warnings from print statements
@@ -25,6 +27,15 @@ base_folder = Path(os.getcwd())

ensure_output_folder_exists(output_folder)

+ host_name = socket.gethostname()
+
+ feedback_logs_folder = 'feedback/' + today_rev + '/' + host_name + '/'
+ access_logs_folder = 'logs/' + today_rev + '/' + host_name + '/'
+ usage_logs_folder = 'usage/' + today_rev + '/' + host_name + '/'
+
+ # Retrieve the Addressbase API key from the environment if available
+ ADDRESSBASE_API_KEY = get_or_create_env_var('ADDRESSBASE_API_KEY', '')
+
# Create the gradio interface
block = gr.Blocks(theme = gr.themes.Base())

@@ -35,6 +46,17 @@ with block:
results_data_state = gr.State(pd.DataFrame())
ref_results_data_state =gr.State(pd.DataFrame())

+ session_hash_state = gr.State()
+ s3_output_folder_state = gr.State()
+
+ # Logging state
+ feedback_logs_state = gr.State(feedback_logs_folder + 'log.csv')
+ feedback_s3_logs_loc_state = gr.State(feedback_logs_folder)
+ access_logs_state = gr.State(access_logs_folder + 'log.csv')
+ access_s3_logs_loc_state = gr.State(access_logs_folder)
+ usage_logs_state = gr.State(usage_logs_folder + 'log.csv')
+ usage_s3_logs_loc_state = gr.State(usage_logs_folder)
+
gr.Markdown(
"""
# Address matcher
@@ -66,7 +88,7 @@ with block:

with gr.Accordion("Use Addressbase API (instead of reference file)", open = True):
in_api = gr.Dropdown(label="Choose API type", multiselect=False, value=None, choices=["Postcode"])#["Postcode", "UPRN"]) #choices=["Address", "Postcode", "UPRN"])
- in_api_key = gr.Textbox(label="Addressbase API key", type='password')
+ in_api_key = gr.Textbox(label="Addressbase API key", type='password', value = ADDRESSBASE_API_KEY)

with gr.Accordion("Match against reference file of addresses", open = False):
in_ref = gr.File(label="Input reference addresses from file", file_count= "multiple")
@@ -81,6 +103,18 @@
output_summary = gr.Textbox(label="Output summary")
output_file = gr.File(label="Output file")

+ feedback_title = gr.Markdown(value="## Please give feedback", visible=False)
+ feedback_radio = gr.Radio(choices=["The results were good", "The results were not good"], visible=False)
+ further_details_text = gr.Textbox(label="Please give more detailed feedback about the results:", visible=False)
+ submit_feedback_btn = gr.Button(value="Submit feedback", visible=False)
+
+ with gr.Row():
+ s3_logs_output_textbox = gr.Textbox(label="Feedback submission logs", visible=False)
+ # This keeps track of the time taken to match files for logging purposes.
+ estimated_time_taken_number = gr.Number(value=0.0, precision=1, visible=False)
+ # Invisible text box to hold the session hash/username just for logging purposes
+ session_hash_textbox = gr.Textbox(value="", visible=False)
+
with gr.Tab(label="Advanced options"):
with gr.Accordion(label = "AWS data access", open = False):
aws_password_box = gr.Textbox(label="Password for AWS data access (ask the Data team if you don't have this)")
@@ -90,35 +124,46 @@

aws_log_box = gr.Textbox(label="AWS data load status")

-
### Loading AWS data ###
load_aws_data_button.click(fn=load_data_from_aws, inputs=[in_aws_file, aws_password_box], outputs=[in_ref, aws_log_box])
-

# Updates to components
in_file.change(fn = initial_data_load, inputs=[in_file], outputs=[output_summary, in_colnames, in_existing, data_state, results_data_state])
in_ref.change(fn = initial_data_load, inputs=[in_ref], outputs=[output_summary, in_refcol, in_joincol, ref_data_state, ref_results_data_state])

match_btn.click(fn = run_matcher, inputs=[in_text, in_file, in_ref, data_state, results_data_state, ref_data_state, in_colnames, in_refcol, in_joincol, in_existing, in_api, in_api_key],
- outputs=[output_summary, output_file], api_name="address")
+ outputs=[output_summary, output_file, estimated_time_taken_number], api_name="address").\
+ then(fn = reveal_feedback_buttons, outputs=[feedback_radio, further_details_text, submit_feedback_btn, feedback_title])


- # Run app
- # If GRADIO_OUTPUT_FOLDER exists and is set to /tmp/ it means that the app is running on AWS Lambda and the queue should not be enabled.
-
- if 'GRADIO_OUTPUT_FOLDER' in os.environ:
- if os.environ['GRADIO_OUTPUT_FOLDER'] == '/tmp/':
- block.launch(ssl_verify=False)
+ # Get connection details on app load
+ block.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state, session_hash_textbox])
+
+ # Log usernames and times of access to file (to know who is using the app when running on AWS)
+ access_callback = gr.CSVLogger()
+ access_callback.setup([session_hash_textbox], access_logs_folder)
+ session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
+ then(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
+
+ # User-submitted feedback on match results
+ feedback_callback = gr.CSVLogger()
+ feedback_callback.setup([feedback_radio, further_details_text, in_file], feedback_logs_folder)
+ submit_feedback_btn.click(lambda *args: feedback_callback.flag(list(args)), [feedback_radio, further_details_text, in_file], None, preprocess=False).\
+ then(fn = upload_file_to_s3, inputs=[feedback_logs_state, feedback_s3_logs_loc_state], outputs=[further_details_text])
+
+ # Log processing time when making a query
+ usage_callback = gr.CSVLogger()
+ usage_callback.setup([session_hash_textbox, in_file, estimated_time_taken_number], usage_logs_folder)
+ estimated_time_taken_number.change(lambda *args: usage_callback.flag(list(args)), [session_hash_textbox, in_file, estimated_time_taken_number], None, preprocess=False).\
+ then(fn = upload_file_to_s3, inputs=[usage_logs_state, usage_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
+
+ # Launch the Gradio app
+ COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
+ print(f'The value of COGNITO_AUTH is {COGNITO_AUTH}')
+
+ if __name__ == "__main__":
+ if os.environ['COGNITO_AUTH'] == "1":
+ block.queue().launch(show_error=True, auth=authenticate_user, max_file_size='50mb')
else:
- block.queue().launch(ssl_verify=False)
-
- block.queue().launch(ssl_verify=False)
-
- # Download OpenSSL from here:
- # Running on local server with https: https://discuss.huggingface.co/t/how-to-run-gradio-with-0-0-0-0-and-https/38003 or https://dev.to/rajshirolkar/fastapi-over-https-for-development-on-windows-2p7d
- #block.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
- # ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
-
- # Running on local server without https
- #block.queue().launch(server_name="0.0.0.0", server_port=7861, ssl_verify=False)
+ block.queue().launch(show_error=True, inbrowser=True, max_file_size='50mb')

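
Note on the logging wiring above: the access, feedback and usage logs all follow the same pattern, in which a hidden component is flagged to a local CSV via `gr.CSVLogger` and a chained `.then()` step pushes that CSV onwards. Below is a minimal, self-contained sketch of that pattern; `upload_log` is a hypothetical stand-in for `tools.aws_functions.upload_file_to_s3`, and the folder names are illustrative only.

```python
import gradio as gr

access_logs_folder = "logs/"  # illustrative local folder

def upload_log(log_file_path, destination):
    # Stand-in for upload_file_to_s3: just report what would be uploaded.
    return f"Would upload {log_file_path} to {destination}"

access_callback = gr.CSVLogger()

with gr.Blocks() as demo:
    session_hash_textbox = gr.Textbox(value="", visible=False)
    status_box = gr.Textbox(label="Log upload status")

    # Write flagged values to a CSV inside access_logs_folder...
    access_callback.setup([session_hash_textbox], access_logs_folder)

    # ...whenever the hidden textbox changes, then push the CSV onwards.
    session_hash_textbox.change(lambda *args: access_callback.flag(list(args)),
                                [session_hash_textbox], None, preprocess=False).\
        then(fn=upload_log,
             inputs=[gr.State(access_logs_folder + "log.csv"), gr.State(access_logs_folder)],
             outputs=[status_box])

if __name__ == "__main__":
    demo.queue().launch()
```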
 
tools/auth.py ADDED
@@ -0,0 +1,47 @@
+ import boto3
+ from tools.helper_functions import get_or_create_env_var
+
+ client_id = get_or_create_env_var('AWS_CLIENT_ID', '') # This client id is borrowed from async gradio app client
+ print(f'The value of AWS_CLIENT_ID is {client_id}')
+
+ user_pool_id = get_or_create_env_var('AWS_USER_POOL_ID', '')
+ print(f'The value of AWS_USER_POOL_ID is {user_pool_id}')
+
+ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=client_id):
+ """Authenticates a user against an AWS Cognito user pool.
+
+ Args:
+ user_pool_id (str): The ID of the Cognito user pool.
+ client_id (str): The ID of the Cognito user pool client.
+ username (str): The username of the user.
+ password (str): The password of the user.
+
+ Returns:
+ bool: True if the user is authenticated, False otherwise.
+ """
+
+ client = boto3.client('cognito-idp') # Cognito Identity Provider client
+
+ try:
+ response = client.initiate_auth(
+ AuthFlow='USER_PASSWORD_AUTH',
+ AuthParameters={
+ 'USERNAME': username,
+ 'PASSWORD': password,
+ },
+ ClientId=client_id
+ )
+
+ # If successful, you'll receive an AuthenticationResult in the response
+ if response.get('AuthenticationResult'):
+ return True
+ else:
+ return False
+
+ except client.exceptions.NotAuthorizedException:
+ return False
+ except client.exceptions.UserNotFoundException:
+ return False
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return False
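
The Cognito check above is wired into `block.queue().launch(auth=authenticate_user, ...)` in app.py whenever COGNITO_AUTH is set to "1". For a quick standalone check of the flow, a sketch like the one below should work, assuming AWS_CLIENT_ID and AWS_USER_POOL_ID are set and the default boto3 credentials can reach the user pool; the region and the test credentials shown here are assumptions, not part of this commit.

```python
import os

# Assumed region; boto3 otherwise falls back to its default configuration.
os.environ.setdefault("AWS_DEFAULT_REGION", "eu-west-2")

from tools.auth import authenticate_user

if __name__ == "__main__":
    # Hypothetical test credentials for a user in the configured pool.
    ok = authenticate_user("test-user", "correct-horse-battery-staple")
    print("Authenticated" if ok else "Not authenticated")
```
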
tools/aws_functions.py CHANGED
@@ -1,37 +1,46 @@
- from typing import Type
+ from typing import Type, List
import pandas as pd
import boto3
import tempfile
import os
+ from tools.helper_functions import get_or_create_env_var

PandasDataFrame = Type[pd.DataFrame]

- try:
- session = boto3.Session()
- bucket_name = os.environ['ADDRESS_MATCHER_BUCKET']
- except Exception as e:
- bucket_name = ''
- print(e)
-
- def get_assumed_role_info():
- sts = boto3.client('sts', region_name='eu-west-2', endpoint_url='https://sts.eu-west-2.amazonaws.com')
- response = sts.get_caller_identity()
-
- # Extract ARN of the assumed role
- assumed_role_arn = response['Arn']
-
- # Extract the name of the assumed role from the ARN
- assumed_role_name = assumed_role_arn.split('/')[-1]
-
- return assumed_role_arn, assumed_role_name
-
- try:
- assumed_role_arn, assumed_role_name = get_assumed_role_info()
-
- print("Assumed Role ARN:", assumed_role_arn)
- print("Assumed Role Name:", assumed_role_name)
- except Exception as e:
- print(e)
+ # Get AWS credentials if required
+ bucket_name=""
+ aws_var = "RUN_AWS_FUNCTIONS"
+ aws_var_default = "0"
+ aws_var_val = get_or_create_env_var(aws_var, aws_var_default)
+ print(f'The value of {aws_var} is {aws_var_val}')
+
+ if aws_var_val == "1":
+ try:
+ session = boto3.Session()
+ bucket_name = os.environ['ADDRESS_MATCHER_BUCKET']
+ except Exception as e:
+ bucket_name = ''
+ print(e)
+
+ def get_assumed_role_info():
+ sts = boto3.client('sts', region_name='eu-west-2', endpoint_url='https://sts.eu-west-2.amazonaws.com')
+ response = sts.get_caller_identity()
+
+ # Extract ARN of the assumed role
+ assumed_role_arn = response['Arn']
+
+ # Extract the name of the assumed role from the ARN
+ assumed_role_name = assumed_role_arn.split('/')[-1]
+
+ return assumed_role_arn, assumed_role_name
+
+ try:
+ assumed_role_arn, assumed_role_name = get_assumed_role_info()
+
+ print("Assumed Role ARN:", assumed_role_arn)
+ print("Assumed Role Name:", assumed_role_name)
+ except Exception as e:
+ print(e)

# Download direct from S3 - requires login credentials
def download_file_from_s3(bucket_name, key, local_file_path):
@@ -101,8 +110,6 @@ def download_files_from_s3(bucket_name, s3_folder, local_folder, filenames):
except Exception as e:
print(f"Error downloading 's3://{bucket_name}/{object_key}':", e)

-
-
def load_data_from_aws(in_aws_keyword_file, aws_password="", bucket_name=bucket_name):

temp_dir = tempfile.mkdtemp()
@@ -154,3 +161,42 @@ def load_data_from_aws(in_aws_keyword_file, aws_password="", bucket_name=bucket_

return files, out_message

+ def upload_file_to_s3(local_file_paths:List[str], s3_key:str, s3_bucket:str=bucket_name):
+ """
+ Uploads a file from local machine to Amazon S3.
+
+ Args:
+ - local_file_path: Local file path(s) of the file(s) to upload.
+ - s3_key: Key (path) to the file in the S3 bucket.
+ - s3_bucket: Name of the S3 bucket.
+
+ Returns:
+ - Message as variable/printed to console
+ """
+ final_out_message = []
+
+ s3_client = boto3.client('s3')
+
+ if isinstance(local_file_paths, str):
+ local_file_paths = [local_file_paths]
+
+ for file in local_file_paths:
+ try:
+ # Get file name off file path
+ file_name = os.path.basename(file)
+
+ s3_key_full = s3_key + file_name
+ print("S3 key: ", s3_key_full)
+
+ s3_client.upload_file(file, s3_bucket, s3_key_full)
+ out_message = "File " + file_name + " uploaded successfully!"
+ print(out_message)
+
+ except Exception as e:
+ out_message = f"Error uploading file(s): {e}"
+ print(out_message)
+
+ final_out_message.append(out_message)
+ final_out_message_str = '\n'.join(final_out_message)
+
+ return final_out_message_str
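
One detail of `upload_file_to_s3` worth noting: `s3_key` is treated as a prefix and the file name is concatenated onto it directly, so it should normally end with a trailing slash (the log folders defined in app.py do). A hedged usage sketch, with hypothetical bucket and file names:

```python
from tools.aws_functions import upload_file_to_s3

# Assumes boto3 can find credentials; the bucket and file names are hypothetical.
message = upload_file_to_s3(
    local_file_paths=["output/results.csv"],
    s3_key="usage/20240101/my-host/",   # prefix; the file name is appended to this
    s3_bucket="my-address-matcher-bucket",
)
print(message)
```
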
tools/{gradio.py → helper_functions.py} RENAMED
@@ -1,6 +1,25 @@
import gradio as gr
import pandas as pd
import os
+ import re
+
+ def get_or_create_env_var(var_name, default_value):
+ # Get the environment variable if it exists
+ value = os.environ.get(var_name)
+
+ # If it doesn't exist, set it to the default value
+ if value is None:
+ os.environ[var_name] = default_value
+ value = default_value
+
+ return value
+
+ # Retrieving or setting output folder
+ env_var_name = 'GRADIO_OUTPUT_FOLDER'
+ default_value = 'output/'
+
+ output_folder = get_or_create_env_var(env_var_name, default_value)
+ print(f'The value of {env_var_name} is {output_folder}')

def detect_file_type(filename):
"""Detect the file type based on its extension."""
@@ -70,7 +89,94 @@ def dummy_function(in_colnames):
"""
return None

+ # Upon running a process, the feedback buttons are revealed
+ def reveal_feedback_buttons():
+ return gr.Radio(visible=True), gr.Textbox(visible=True), gr.Button(visible=True), gr.Markdown(visible=True)

def clear_inputs(in_file, in_ref, in_text):
return gr.File(value=[]), gr.File(value=[]), gr.Textbox(value='')

+ ## Get final processing time for logs:
+ def sum_numbers_before_seconds(string):
+ """Extracts numbers that precede the word 'seconds' from a string and adds them up.
+
+ Args:
+ string: The input string.
+
+ Returns:
+ The sum of all numbers before 'seconds' in the string.
+ """
+
+ # Extract numbers before 'seconds' using a regular expression (non-capturing group so findall returns the whole number)
+ numbers = re.findall(r'(\d+(?:\.\d+)?)\s*seconds', string)
+
+ # Convert the matches to floats
+ numbers = [float(num) for num in numbers]
+
+ # Sum up the extracted numbers
+ sum_of_numbers = sum(numbers)
+
+ return sum_of_numbers
+
+ async def get_connection_params(request: gr.Request):
+ base_folder = ""
+
+ if request:
+ #print("request user:", request.username)
+
+ #request_data = await request.json() # Parse JSON body
+ #print("All request data:", request_data)
+ #context_value = request_data.get('context')
+ #if 'context' in request_data:
+ # print("Request context dictionary:", request_data['context'])
+
+ # print("Request headers dictionary:", request.headers)
+ # print("All host elements", request.client)
+ # print("IP address:", request.client.host)
+ # print("Query parameters:", dict(request.query_params))
+ # To get the underlying FastAPI items you would need to use await and some fancy @ stuff for a live query: https://fastapi.tiangolo.com/vi/reference/request/
+ #print("Request dictionary to object:", request.request.body())
+ print("Session hash:", request.session_hash)
+
+ # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER
+ CUSTOM_CLOUDFRONT_HEADER_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER', '')
+ #print(f'The value of CUSTOM_CLOUDFRONT_HEADER is {CUSTOM_CLOUDFRONT_HEADER_var}')
+
+ # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER_VALUE
+ CUSTOM_CLOUDFRONT_HEADER_VALUE_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER_VALUE', '')
+ #print(f'The value of CUSTOM_CLOUDFRONT_HEADER_VALUE_var is {CUSTOM_CLOUDFRONT_HEADER_VALUE_var}')
+
+ if CUSTOM_CLOUDFRONT_HEADER_var and CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
+ if CUSTOM_CLOUDFRONT_HEADER_var in request.headers:
+ supplied_cloudfront_custom_value = request.headers[CUSTOM_CLOUDFRONT_HEADER_var]
+ if supplied_cloudfront_custom_value == CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
+ print("Custom Cloudfront header found:", supplied_cloudfront_custom_value)
+ else:
+ raise ValueError("Custom Cloudfront header value does not match expected value.")
+
+ # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
+
+ if request.username:
+ out_session_hash = request.username
+ base_folder = "user-files/"
+ print("Request username found:", out_session_hash)
+
+ elif 'x-cognito-id' in request.headers:
+ out_session_hash = request.headers['x-cognito-id']
+ base_folder = "user-files/"
+ print("Cognito ID found:", out_session_hash)
+
+ else:
+ out_session_hash = request.session_hash
+ base_folder = "temp-files/"
+ # print("Cognito ID not found. Using session hash as save folder:", out_session_hash)
+
+ output_folder = base_folder + out_session_hash + "/"
+ #if bucket_name:
+ # print("S3 output folder is: " + "s3://" + bucket_name + "/" + output_folder)
+
+ return out_session_hash, output_folder, out_session_hash
+ else:
+ print("No session parameters found.")
+ return "", "", ""
+
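
A short sketch of how the two new helpers behave (illustrative values only): `get_or_create_env_var` returns the existing environment value, or sets and returns the default if the variable is unset, while `sum_numbers_before_seconds` adds up every number that appears directly before the word "seconds" in a summary string.

```python
from tools.helper_functions import get_or_create_env_var, sum_numbers_before_seconds

# Returns "0" unless COGNITO_AUTH was already set in the environment.
print(get_or_create_env_var("COGNITO_AUTH", "0"))

# Hypothetical timing summary of the kind produced at the end of a match run.
summary = "Fuzzy match took 12.5 seconds. Neural net match took 3 seconds."
print(sum_numbers_before_seconds(summary))  # 15.5
```
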
tools/matcher_funcs.py CHANGED
@@ -33,7 +33,7 @@ from tools.standardise import standardise_wrapper_func
### Predict function for imported model
from tools.model_predict import full_predict_func, full_predict_torch, post_predict_clean
from tools.recordlinkage_funcs import score_based_match
- from tools.gradio import initial_data_load
+ from tools.helper_functions import initial_data_load, sum_numbers_before_seconds

# API functions
from tools.addressbase_api_funcs import places_api_query
@@ -108,10 +108,13 @@ def filter_not_matched(

return search_df.iloc[np.where(~matched)[0]]

- def run_all_api_calls(in_api_key:str, Matcher:MatcherClass, query_type:str, progress=gr.Progress()):
+ def query_addressbase_api(in_api_key:str, Matcher:MatcherClass, query_type:str, progress=gr.Progress()):
+
+ final_api_output_file_name = ""
+
if in_api_key == "":
print ("No API key provided, please provide one to continue")
- return Matcher
+ return Matcher, final_api_output_file_name
else:
# Call the API
#Matcher.ref_df = pd.DataFrame()
@@ -119,7 +122,7 @@ def run_all_api_calls(in_api_key:str, Matcher:MatcherClass, query_type:str, prog
# Check if the ref_df file already exists
def check_and_create_api_folder():
# Check if the environmental variable is available
- file_path = os.environ.get('ADDRESSBASE_API_OUT') # Replace 'YOUR_ENV_VARIABLE_NAME' with the name of your environmental variable
+ file_path = os.environ.get('ADDRESSBASE_API_OUT')

if file_path is None:
# Environmental variable is not set
@@ -145,11 +148,13 @@ def run_all_api_calls(in_api_key:str, Matcher:MatcherClass, query_type:str, prog
api_ref_save_loc = api_output_folder + search_file_name_without_extension + "_api_" + today_month_rev + "_" + query_type + "_ckpt"
print("API reference save location: ", api_ref_save_loc)

+ final_api_output_file_name = api_ref_save_loc + ".parquet"
+
# Allow for csv, parquet and gzipped csv files
if os.path.isfile(api_ref_save_loc + ".csv"):
print("API reference CSV file found")
Matcher.ref_df = pd.read_csv(api_ref_save_loc + ".csv")
- elif os.path.isfile(api_ref_save_loc + ".parquet"):
+ elif os.path.isfile(final_api_output_file_name):
print("API reference Parquet file found")
Matcher.ref_df = pd.read_parquet(api_ref_save_loc + ".parquet")
elif os.path.isfile(api_ref_save_loc + ".csv.gz"):
@@ -350,21 +355,23 @@ def run_all_api_calls(in_api_key:str, Matcher:MatcherClass, query_type:str, prog
# Matcher.ref_df = Matcher.ref_df.loc[Matcher.ref_df["LOCAL_CUSTODIAN_CODE"] != 7655,:]

if save_file:
+ final_api_output_file_name = output_folder + api_ref_save_loc[:-5] + ".parquet"
print("Saving reference file to: " + api_ref_save_loc[:-5] + ".parquet")
Matcher.ref_df.to_parquet(output_folder + api_ref_save_loc + ".parquet", index=False) # Save checkpoint as well
- Matcher.ref_df.to_parquet(output_folder + api_ref_save_loc[:-5] + ".parquet", index=False)
+ Matcher.ref_df.to_parquet(final_api_output_file_name, index=False)

if Matcher.ref_df.empty:
print ("No reference data found with API")
return Matcher

- return Matcher
+ return Matcher, final_api_output_file_name

- def check_ref_data_exists(Matcher:MatcherClass, ref_data_state:PandasDataFrame, in_ref:List[str], in_refcol:List[str], in_api:List[str], in_api_key:str, query_type:str, progress=gr.Progress()):
+ def load_ref_data(Matcher:MatcherClass, ref_data_state:PandasDataFrame, in_ref:List[str], in_refcol:List[str], in_api:List[str], in_api_key:str, query_type:str, progress=gr.Progress()):
'''
Check for reference address data, do some preprocessing, and load in from the Addressbase API if required.
'''
-
+ final_api_output_file_name = ""
+
# Check if reference data loaded, bring in if already there
if not ref_data_state.empty:
Matcher.ref_df = ref_data_state
@@ -382,10 +389,10 @@ def check_ref_data_exists(Matcher:MatcherClass, ref_data_state:PandasDataFrame,
if not in_ref:
if in_api==False:
print ("No reference file provided, please provide one to continue")
- return Matcher
+ return Matcher, final_api_output_file_name
# Check if api call required and api key is provided
else:
- Matcher = run_all_api_calls(in_api_key, Matcher, query_type)
+ Matcher, final_api_output_file_name = query_addressbase_api(in_api_key, Matcher, query_type)

else:
Matcher.ref_name = get_file_name(in_ref[0].name)
@@ -402,9 +409,7 @@ def check_ref_data_exists(Matcher:MatcherClass, ref_data_state:PandasDataFrame,

Matcher.ref_df = pd.concat([Matcher.ref_df, temp_ref_file])

- # For the neural net model to work, the llpg columns have to be in the LPI format (e.g. with columns SaoText, SaoStartNumber etc. Here we check if we have that format.
-
-
+ # For the neural net model to work, the llpg columns have to be in the LPI format (e.g. with columns SaoText, SaoStartNumber etc. Here we check if we have that format.

if 'Address_LPI' in Matcher.ref_df.columns:
Matcher.ref_df = Matcher.ref_df.rename(columns={
@@ -475,9 +480,9 @@ def check_ref_data_exists(Matcher:MatcherClass, ref_data_state:PandasDataFrame,
Matcher.ref_df = Matcher.ref_df.reset_index() #.drop(["index","level_0"], axis = 1, errors="ignore").reset_index().drop(["index","level_0"], axis = 1, errors="ignore")
Matcher.ref_df.index.name = 'index'

- return Matcher
+ return Matcher, final_api_output_file_name

- def check_match_data_filter(Matcher:MatcherClass, data_state:PandasDataFrame, results_data_state:PandasDataFrame, in_file:List[str], in_text:str, in_colnames:List[str], in_joincol:List[str], in_existing:List[str], in_api:List[str]):
+ def load_match_data_and_filter(Matcher:MatcherClass, data_state:PandasDataFrame, results_data_state:PandasDataFrame, in_file:List[str], in_text:str, in_colnames:List[str], in_joincol:List[str], in_existing:List[str], in_api:List[str]):
'''
Check if data to be matched exists. Filter it according to which records are relevant in the reference dataset
'''
@@ -654,6 +659,8 @@ def load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state,
'''
Load in user inputs from the Gradio interface. Convert all input types (single address, or csv input) into standardised data format that can be used downstream for the fuzzy matching.
'''
+ final_api_output_file_name = ""
+
today_rev = datetime.now().strftime("%Y%m%d")

# Abort flag for if it's not even possible to attempt the first stage of the match for some reason
@@ -662,16 +669,15 @@ def load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state,
### ref_df FILES ###
# If not an API call, run this first
if not in_api:
- Matcher = check_ref_data_exists(Matcher, ref_data_state, in_ref, in_refcol, in_api, in_api_key, query_type=in_api)
+ Matcher, final_api_output_file_name = load_ref_data(Matcher, ref_data_state, in_ref, in_refcol, in_api, in_api_key, query_type=in_api)

### MATCH/SEARCH FILES ###
# If doing API calls, we need to know the search data before querying for specific addresses/postcodes
- Matcher = check_match_data_filter(Matcher, data_state, results_data_state, in_file, in_text, in_colnames, in_joincol, in_existing, in_api)
+ Matcher = load_match_data_and_filter(Matcher, data_state, results_data_state, in_file, in_text, in_colnames, in_joincol, in_existing, in_api)

# If an API call, ref_df data is loaded after
if in_api:
-
- Matcher = check_ref_data_exists(Matcher, ref_data_state, in_ref, in_refcol, in_api, in_api_key, query_type=in_api)
+ Matcher, final_api_output_file_name = load_ref_data(Matcher, ref_data_state, in_ref, in_refcol, in_api, in_api_key, query_type=in_api)

print("Shape of ref_df after filtering is: ", Matcher.ref_df.shape)
print("Shape of search_df after filtering is: ", Matcher.search_df.shape)
@@ -682,23 +688,31 @@ def load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state,
Matcher.match_results_output.to_csv(Matcher.match_outputs_name, index = None)
Matcher.results_on_orig_df.to_csv(Matcher.results_orig_df_name, index = None)

- return Matcher
+ return Matcher, final_api_output_file_name

# Run whole matcher process
def run_matcher(in_text:str, in_file:str, in_ref:str, data_state:PandasDataFrame, results_data_state:PandasDataFrame, ref_data_state:PandasDataFrame, in_colnames:List[str], in_refcol:List[str], in_joincol:List[str], in_existing:List[str], in_api:str, in_api_key:str, InitMatch:MatcherClass = InitMatch, progress=gr.Progress()):
'''
Split search and reference data into batches. Loop and run through the match script for each batch of data.
'''
+ output_files = []
+
+ estimate_total_processing_time = 0.0

overall_tic = time.perf_counter()

# Load in initial data. This will filter to relevant addresses in the search and reference datasets that can potentially be matched, and will pull in API data if asked for.
- InitMatch = load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state, ref_data_state, in_colnames, in_refcol, in_joincol, in_existing, InitMatch, in_api, in_api_key)
+ InitMatch, final_api_output_file_name = load_matcher_data(in_text, in_file, in_ref, data_state, results_data_state, ref_data_state, in_colnames, in_refcol, in_joincol, in_existing, InitMatch, in_api, in_api_key)
+
+ if final_api_output_file_name:
+ output_files.append(final_api_output_file_name)

if InitMatch.search_df.empty or InitMatch.ref_df.empty:
out_message = "Nothing to match!"
print(out_message)
- return out_message, [InitMatch.results_orig_df_name, InitMatch.match_outputs_name]
+
+ output_files.extend([InitMatch.results_orig_df_name, InitMatch.match_outputs_name])
+ return out_message, output_files, estimate_total_processing_time

# Run initial address preparation and standardisation processes
# Prepare address format
@@ -801,7 +815,7 @@ def run_matcher(in_text:str, in_file:str, in_ref:str, data_state:PandasDataFrame
"Excluded from search":False,
"Matched with reference address":False})
else:
- summary_of_summaries, BatchMatch_out = run_match_batch(BatchMatch, n, number_of_batches)
+ summary_of_summaries, BatchMatch_out = run_single_match_batch(BatchMatch, n, number_of_batches)

OutputMatch = combine_two_matches(OutputMatch, BatchMatch_out, "All up to and including batch " + str(n+1))

@@ -837,7 +851,13 @@ def run_matcher(in_text:str, in_file:str, in_ref:str, data_state:PandasDataFrame

final_summary = fuzzy_not_std_summary + "\n" + fuzzy_std_summary + "\n" + nnet_std_summary + "\n" + time_out

- return final_summary, [OutputMatch.results_orig_df_name, OutputMatch.match_outputs_name]
+
+
+ estimate_total_processing_time = sum_numbers_before_seconds(time_out)
+ print("Estimated total processing time:", str(estimate_total_processing_time))
+
+ output_files.extend([OutputMatch.results_orig_df_name, OutputMatch.match_outputs_name])
+ return final_summary, output_files, estimate_total_processing_time

# Run a match run for a single batch
def create_simple_batch_ranges(df:PandasDataFrame, ref_df:PandasDataFrame, batch_size:int, ref_batch_size:int):
@@ -963,7 +983,7 @@ def create_batch_ranges(df:PandasDataFrame, ref_df:PandasDataFrame, batch_size:i

return lengths_df

- def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, progress=gr.Progress()):
+ def run_single_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, progress=gr.Progress()):
'''
Over-arching function for running a single batch of data through the full matching process. Calls fuzzy matching, then neural network match functions in order. It outputs a summary of the match, and a MatcherClass with the matched data included.
'''
@@ -979,7 +999,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p

''' Run fuzzy match on non-standardised dataset '''

- FuzzyNotStdMatch = orchestrate_match_run(Matcher = copy.copy(InitialMatch), standardise = False, nnet = False, file_stub= "not_std_", df_name = df_name)
+ FuzzyNotStdMatch = orchestrate_single_match_batch(Matcher = copy.copy(InitialMatch), standardise = False, nnet = False, file_stub= "not_std_", df_name = df_name)

if FuzzyNotStdMatch.abort_flag == True:
message = "Nothing to match! Aborting address check."
@@ -999,7 +1019,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
progress(.25, desc="Batch " + str(batch_n+1) + " of " + str(total_batches) + ". Fuzzy match - standardised dataset")
df_name = "Fuzzy standardised"

- FuzzyStdMatch = orchestrate_match_run(Matcher = copy.copy(FuzzyNotStdMatch), standardise = True, nnet = False, file_stub= "std_", df_name = df_name)
+ FuzzyStdMatch = orchestrate_single_match_batch(Matcher = copy.copy(FuzzyNotStdMatch), standardise = True, nnet = False, file_stub= "std_", df_name = df_name)
FuzzyStdMatch = combine_two_matches(FuzzyNotStdMatch, FuzzyStdMatch, df_name)

''' Continue if reference file in correct format, and neural net model exists. Also if data not too long '''
@@ -1022,7 +1042,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
progress(.50, desc="Batch " + str(batch_n+1) + " of " + str(total_batches) + ". Neural net - non-standardised dataset")
df_name = "Neural net not standardised"

- FuzzyNNetNotStdMatch = orchestrate_match_run(Matcher = copy.copy(FuzzyStdMatch), standardise = False, nnet = True, file_stub= "nnet_not_std_", df_name = df_name)
+ FuzzyNNetNotStdMatch = orchestrate_single_match_batch(Matcher = copy.copy(FuzzyStdMatch), standardise = False, nnet = True, file_stub= "nnet_not_std_", df_name = df_name)
FuzzyNNetNotStdMatch = combine_two_matches(FuzzyStdMatch, FuzzyNNetNotStdMatch, df_name)

if (len(FuzzyNNetNotStdMatch.search_df_not_matched) == 0):
@@ -1035,7 +1055,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
progress(.75, desc="Batch " + str(batch_n+1) + " of " + str(total_batches) + ". Neural net - standardised dataset")
df_name = "Neural net standardised"

- FuzzyNNetStdMatch = orchestrate_match_run(Matcher = copy.copy(FuzzyNNetNotStdMatch), standardise = True, nnet = True, file_stub= "nnet_std_", df_name = df_name)
+ FuzzyNNetStdMatch = orchestrate_single_match_batch(Matcher = copy.copy(FuzzyNNetNotStdMatch), standardise = True, nnet = True, file_stub= "nnet_std_", df_name = df_name)
FuzzyNNetStdMatch = combine_two_matches(FuzzyNNetNotStdMatch, FuzzyNNetStdMatch, df_name)

if run_fuzzy_match == False:
@@ -1052,7 +1072,7 @@ def run_match_batch(InitialMatch:MatcherClass, batch_n:int, total_batches:int, p
return summary_of_summaries, FuzzyNNetStdMatch

# Overarching functions
- def orchestrate_match_run(Matcher, standardise = False, nnet = False, file_stub= "not_std_", df_name = "Fuzzy not standardised"):
+ def orchestrate_single_match_batch(Matcher, standardise = False, nnet = False, file_stub= "not_std_", df_name = "Fuzzy not standardised"):

today_rev = datetime.now().strftime("%Y%m%d")

@@ -1463,7 +1483,6 @@ def full_nn_match(ref_address_cols:List[str],

return match_results_output_final_three, results_on_orig_df, summary_three, predict_df

-
# Combiner/summary functions
def combine_dfs_and_remove_dups(orig_df:PandasDataFrame, new_df:PandasDataFrame, index_col:str = "search_orig_address", match_address_series:str = "full_match", keep_only_duplicated:bool = False) -> PandasDataFrame:

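With these changes `run_matcher` now returns three values: a summary string, a list of output file paths (including the saved API reference Parquet file when one is produced), and an estimated processing time in seconds, matching the three Gradio outputs wired up in app.py. A small illustrative sketch of consuming that contract (the values shown are hypothetical, not produced by this commit):

```python
from typing import List, Tuple

def handle_matcher_outputs(result: Tuple[str, List[str], float]) -> None:
    # Unpack the (summary, output files, estimated seconds) triple.
    summary, output_files, estimated_seconds = result
    print(summary)
    for path in output_files:
        print("Output file:", path)
    print(f"Estimated processing time: {estimated_seconds} seconds")

if __name__ == "__main__":
    handle_matcher_outputs(("Batch 1 of 1 complete.", ["output/results.csv"], 12.5))
```
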
tools/preparation.py CHANGED
@@ -49,49 +49,6 @@ def prepare_search_address_string(

return search_df_out, key_field, address_cols, postcode_col

- # def prepare_search_address(
- # search_df: pd.DataFrame,
- # address_cols: list,
- # postcode_col: list,
- # key_col: str
- # ) -> Tuple[pd.DataFrame, str]:
-
- # # Validate inputs
- # if not isinstance(search_df, pd.DataFrame):
- # raise TypeError("search_df must be a Pandas DataFrame")
-
- # if not isinstance(address_cols, list):
- # raise TypeError("address_cols must be a list")
-
- # if not isinstance(postcode_col, list):
- # raise TypeError("postcode_col must be a list")
-
- # if not isinstance(key_col, str):
- # raise TypeError("key_col must be a string")
-
- # # Clean address columns
- # clean_addresses = _clean_columns(search_df, address_cols)
-
- # # Join address columns into one
- # full_addresses = _join_address(clean_addresses, address_cols)
-
- # # Add postcode column
- # full_df = _add_postcode_column(full_addresses, postcode_col)
-
- # # Remove postcode from main address if there was only one column in the input
- # if postcode_col == "full_address_postcode":
- # # Remove postcode from address
- # address_series = remove_postcode(search_df, "full_address")
- # search_df["full_address"] == address_series
-
- # # Ensure index column
- # final_df = _ensure_index(full_df, key_col)
-
- # #print(final_df)
-
-
- # return final_df, key_col
-
def prepare_search_address(
search_df: pd.DataFrame,
address_cols: list,
@@ -145,25 +102,7 @@ def _clean_columns(df, cols):
df[cols] = df[cols].apply(clean_col)

return df
-
- # def _clean_columns(df, cols):
- # # Cleaning logic
- # #print(df)
-
- # #if isinstance(df, pl.DataFrame):
- # # print("It's a Polars DataFrame")
-
- # def clean_col(col):
- # col = col.str.replace("nan", "")
- # col = col.apply(lambda x: re.sub(r'\s{2,}', ' ', str(x)), skip_nulls=False, return_dtype=str) # replace any spaces greater than one with one
- # return col.str.replace(",", " ").str.strip() # replace commas with a space
-
- # for col in cols:
- # df = df.with_columns(clean_col(df[col]).alias(col))
-
- # return df
-
-
+
def _join_address(df, cols):
# Joining logic
full_address = df[cols].apply(lambda row: ' '.join(row.values.astype(str)), axis=1)
@@ -289,43 +228,6 @@ def prepare_ref_address(ref_df, ref_address_cols, new_join_col = ['UPRN'], stand

return ref_df_cleaned

- # def prepare_ref_address(ref_df:pl.DataFrame, ref_address_cols, new_join_col = ['UPRN'], standard_cols = True):
-
- # if ('SaoText' in ref_df.columns) | ("Secondary_Name_LPI" in ref_df.columns):
- # standard_cols = True
- # else:
- # standard_cols = False
-
- # ref_address_cols_uprn = list(ref_address_cols) + new_join_col
- # ref_df_cleaned = ref_df[ref_address_cols_uprn].fill_null("")
-
- # # In on-prem LPI db street has been excluded, so put this back in
- # if ('Street' not in ref_df_cleaned.columns) & ('Address_LPI' in ref_df_cleaned.columns):
- # ref_df_cleaned = ref_df_cleaned.with_column(pl.col('Address_LPI').apply(lambda x: extract_street_name(x)).alias('Street'))
-
- # if ('Organisation' not in ref_df_cleaned.columns) & ('SaoText' in ref_df_cleaned.columns):
- # ref_df_cleaned = ref_df_cleaned.with_column(pl.lit("").alias('Organisation'))
-
- # #ref_df_cleaned['fulladdress'] =
-
- # if standard_cols:
- # pass
- # # I can not write the full address code here as it depends on your extract_street_name and create_full_address function implementations.
- # # However, you might need to convert string types to object type for full address creation which may require more than just a few lines of codes.
- # else:
- # pass
-
- # # I can not write the full address code here as it depends on your extract_street_name and create_full_address function implementations.
-
- # if 'Street' not in ref_df_cleaned.columns:
- # ref_df_cleaned = ref_df_cleaned.with_column(pl.col('fulladdress').apply(extract_street_name).alias("Street"))
-
- # # Add index column
- # ref_df_cleaned = ref_df_cleaned.with_column(pl.lit('').alias('ref_index'))
-
- # return ref_df_cleaned
-
-
def extract_postcode(df, col:str) -> PandasSeries:
'''
Extract a postcode from a string column in a dataframe
@@ -335,7 +237,6 @@ def extract_postcode(df, col:str) -> PandasSeries:

return postcode_series

-
# Remove addresses with no numbers in at all - too high a risk of badly assigning an address
def check_no_number_addresses(df, in_address_series) -> PandasSeries:
'''
@@ -353,7 +254,6 @@ def check_no_number_addresses(df, in_address_series) -> PandasSeries:

return df

-
def remove_postcode(df, col:str) -> PandasSeries:
'''
Remove a postcode from a string column in a dataframe