|
import os |
|
import pandas as pd |
|
import pickle |
|
import torch |
|
import zipfile |
|
from typing import List, Union, Type, Dict |
|
from pydantic import BaseModel |
|
|
|
from .pytorch_models import * |
|
|
|
PandasDataFrame = Type[pd.DataFrame] |
|
PandasSeries = Type[pd.Series] |
|
|
|
def get_or_create_env_var(var_name, default_value):
    """Return environment variable ``var_name``, creating it when it is absent.

    If ``var_name`` is not present in ``os.environ`` it is set to
    ``default_value`` (which must therefore be a ``str``, as ``os.environ``
    only stores strings) and that value is returned. An existing value,
    including an empty string, is returned unchanged.
    """
    existing = os.environ.get(var_name)
    if existing is not None:
        return existing
    os.environ[var_name] = default_value
    return default_value
|
|
|
|
|
# Output folder, read from GRADIO_OUTPUT_FOLDER. When the variable is unset it
# is created with the default 'output/' and that default is used.
env_var_name = 'GRADIO_OUTPUT_FOLDER'
default_value = 'output/'
output_folder = get_or_create_env_var(
    var_name=env_var_name,
    default_value=default_value,
)
print(f'The value of {env_var_name} is {output_folder}')
|
|
|
|
|
''' Fuzzywuzzy/Rapidfuzz scorer to use. Options are: ratio, partial_ratio, token_sort_ratio, partial_token_sort_ratio, |
|
token_set_ratio, partial_token_set_ratio, QRatio, UQRatio, WRatio (default), UWRatio |
|
details here: https://stackoverflow.com/questions/31806695/when-to-use-which-fuzz-function-to-compare-2-strings''' |
|
|
|
# Fuzzy-matching settings. fuzzy_scorer_used names one of the scorers listed above.
fuzzy_scorer_used = "token_set_ratio"

# NOTE(review): presumably the minimum fuzzy score accepted and the number of
# candidate addresses searched per record — confirm against the matching code.
fuzzy_match_limit = 85
fuzzy_search_addr_limit = 20
filter_to_lambeth_pcodes = True
standardise = False

# Suffix fragment recording whether standardisation is enabled. A single
# conditional expression guarantees ``std`` is always bound. Separate
# ``== True`` / ``== False`` checks would leave it undefined for any
# non-bool value of ``standardise``.
std = "_std" if standardise else "_not_std"

dataset_name = "data" + std

# e.g. "data_not_std_token_set_ratio"
suffix_used = dataset_name + "_" + fuzzy_scorer_used
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Neural-network model selection. model_type chooses which architecture is
# built when the model archive is loaded. model_stub and model_version locate
# the archive under ROOT_DIR/nnet_model.
model_type = "lstm"
model_stub = "pytorch/lstm"
model_version = ""
file_step_suffix = ""

# Training-run identifiers. Together with model_type they form the name of the
# weights file that gets loaded.
# NOTE(review): assumed to match the run that produced the .pth — confirm.
data_sample_size = 476887
N_EPOCHS = 10
max_predict_len = 12000  # NOTE(review): not used in this section — confirm its consumer

# Empty placeholders; the model-loading code may overwrite them.
word_to_index = {}
cat_to_idx = {}
vocab = []
device = "cpu"

# Names of the model's output labels; stays empty unless a PyTorch model is loaded.
labels_list = []

# Project root: the parent of the directory that contains this file.
ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '..'))
|
|
|
|
|
# Directory the bundled model archive is extracted into. This is the project
# root when the default output folder is in use, otherwise the configured folder.
if output_folder == "output/":
    out_model_dir = ROOT_DIR
else:
    # Strip trailing path separators rather than blindly dropping the last
    # character, which truncated folders given without a trailing separator
    # (e.g. "/tmp/out" became "/tmp/ou"). If the value consists only of
    # separators (e.g. "/"), keep it as-is.
    out_model_dir = output_folder.rstrip(os.sep + (os.altsep or "")) or output_folder
print(out_model_dir)

model_dir_name = os.path.join(ROOT_DIR, "nnet_model", model_stub, model_version)

model_path = os.path.join(model_dir_name, "saved_model.zip")
print("Model zip path: ", model_path)
|
|
|
# Default to "no model" unless a PyTorch model is successfully loaded below.
# Binding the name up front guarantees it exists on every path. Previously, a
# model zip whose model_stub lacked "pytorch" left exported_model undefined.
exported_model = []

if os.path.exists(model_path):

    # Hide all CUDA devices so the model runs on the CPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    device = "cpu"

    # Load the pre-trained model: first unpack the archive into out_model_dir.
    with zipfile.ZipFile(model_path, "r") as zip_ref:
        zip_ref.extractall(out_model_dir)

    if "pytorch" in model_stub:

        # Names of the model's output labels.
        # NOTE(review): presumably must match the labels the model was trained on — confirm.
        labels_list = [
            'SaoText',
            'SaoStartNumber',
            'SaoStartSuffix',
            'SaoEndNumber',
            'SaoEndSuffix',
            'PaoText',
            'PaoStartNumber',
            'PaoStartSuffix',
            'PaoEndNumber',
            'PaoEndSuffix',
            'Street',
            'PostTown',
            'AdministrativeArea',
            'Postcode',
            'IGNORE'
        ]

        if model_type in ("transformer", "gru", "lstm"):

            # Vocabulary and lookup tables extracted from the model archive.
            # NOTE(review): eval() executes whatever these files contain. That is
            # only tolerable because they come from the bundled zip. If they can
            # ever be user-supplied, switch to ast.literal_eval.
            with open(os.path.join(out_model_dir, "vocab.txt"), "r") as f:
                vocab = eval(f.read())
            with open(os.path.join(out_model_dir, "word_to_index.txt"), "r") as f:
                word_to_index = eval(f.read())
            with open(os.path.join(out_model_dir, "cat_to_idx.txt"), "r") as f:
                cat_to_idx = eval(f.read())

            # Hyperparameters shared by all three architectures.
            VOCAB_SIZE = len(word_to_index)
            OUTPUT_DIM = len(cat_to_idx) + 1  # NOTE(review): +1 presumably reserves an extra class index — confirm
            EMBEDDING_DIM = 48
            DROPOUT = 0.1
            PAD_TOKEN = 0

            if model_type == "transformer":
                NHEAD = 4
                NUM_ENCODER_LAYERS = 1
                exported_model = TransformerClassifier(VOCAB_SIZE, EMBEDDING_DIM, NHEAD, NUM_ENCODER_LAYERS, OUTPUT_DIM, DROPOUT, PAD_TOKEN)

            elif model_type == "gru":
                N_LAYERS = 3
                HIDDEN_DIM = 128
                exported_model = TextClassifier(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, DROPOUT, PAD_TOKEN)

            elif model_type == "lstm":
                N_LAYERS = 3
                HIDDEN_DIM = 128
                exported_model = LSTMTextClassifier(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, DROPOUT, PAD_TOKEN)

            # Weights file name encodes the sample size, epoch count and architecture.
            out_model_file_name = "output_model_" + str(data_sample_size) + \
                "_" + str(N_EPOCHS) + "_" + model_type + ".pth"

            out_model_path = os.path.join(out_model_dir, out_model_file_name)
            print("Model location: ", out_model_path)

            # NOTE(review): the file holds a state_dict, so weights_only=True
            # should also work and would avoid unpickling arbitrary objects.
            # Verify against the saved checkpoint before switching.
            exported_model.load_state_dict(torch.load(out_model_path, map_location=torch.device('cpu'), weights_only=False))
            exported_model.eval()
            exported_model.to(device)
|
|
|
|
|
|
|
# Batch sizes for the search and reference data.
# NOTE(review): presumably rows per batch — confirm against the callers.
batch_size, ref_batch_size = 10000, 20000
|
|
|
|
|
|
|
''' https://recordlinkage.readthedocs.io/en/latest/ref_df-compare.html#recordlinkage.compare.String |
|
The Python Record Linkage Toolkit uses the jellyfish package for the Jaro, Jaro-Winkler, Levenshtein and Damerau-Levenshtein algorithms.

Options are ['jaro', 'jarowinkler', 'levenshtein', 'damerau_levenshtein', 'qgram', 'cosine', 'smith_waterman', 'lcs']
|
|
|
Comparison of some of the Jellyfish string comparison methods: https://manpages.debian.org/testing/python-jellyfish-doc/jellyfish.3.en.html ''' |
|
|
|
# Name of the string-similarity method (one of the options listed above).
fuzzy_method = "jarowinkler"

# Score thresholds used during matching.
# NOTE(review): exact semantics (inclusive/exclusive, score scale) are defined
# by the matching code, not here — confirm there.
score_cut_off, score_cut_off_nnet_street, no_number_fuzzy_match_limit = 98.7, 99.5, 100
|
|
|
|
|
# Reference-address columns compared during matching.
ref_address_cols = [
    "Organisation", "SaoStartNumber", "SaoStartSuffix", "SaoEndNumber", "SaoEndSuffix",
    "SaoText", "PaoStartNumber", "PaoStartSuffix", "PaoEndNumber",
    "PaoEndSuffix", "PaoText", "Street", "PostTown", "Postcode",
]

# NOTE: matching_variables is the same list object as ref_address_cols (an alias, not a copy).
matching_variables = ref_address_cols
text_columns = ["Organisation", "PaoText", "Street", "PostTown", "Postcode"]

# Per-column comparison weights that override the default of 1.
Organisation_weight = 0.1
PaoStartNumber_weight = 2
SaoStartNumber_weight = 2
Street_weight = 2
PostTown_weight = 0
Postcode_weight = 0.5
# NOTE(review): not applied below, since "AdministrativeArea" is not in ref_address_cols.
AdministrativeArea_weight = 0

# Every column starts with weight 1, in ref_address_cols order.
weight_vals = [1] * len(ref_address_cols)
weight_keys = ref_address_cols
weights = dict(zip(weight_keys, weight_vals))

# Apply the overrides. All keys already exist, so column order is unchanged.
weights.update({
    "Organisation": Organisation_weight,
    "SaoStartNumber": SaoStartNumber_weight,
    "PaoStartNumber": PaoStartNumber_weight,
    "Street": Street_weight,
    "PostTown": PostTown_weight,
    "Postcode": Postcode_weight,
})
|
|
|
|
|
|
|
|
|
class MatcherClass(BaseModel):
    """Pydantic model bundling the matcher's configuration values and the
    data-frame / result placeholders used during a matching run.

    Attributes that are not declared below are also accepted and kept
    (``extra="allow"``).
    """

    # Fuzzy-matching settings
    fuzzy_scorer_used: str
    fuzzy_match_limit: int
    fuzzy_search_addr_limit: int
    filter_to_lambeth_pcodes: bool
    standardise: bool
    suffix_used: str

    # Matching columns and neural-network model settings
    matching_variables: List[str]
    model_dir_name: str
    file_step_suffix: str
    exported_model: List  # the loaded model, wrapped in a list (or an empty list)

    # Record-linkage comparison settings
    fuzzy_method: str
    score_cut_off: float
    text_columns: List[str]
    weights: dict
    model_type: str
    labels_list: List[str]

    # Model lookup tables and device
    word_to_index: dict
    cat_to_idx: dict
    device: str
    vocab: List[str]

    # Input data and column configuration
    file_name: str
    ref_name: str
    search_df: pd.DataFrame
    excluded_df: pd.DataFrame
    pre_filter_search_df: pd.DataFrame
    search_address_cols: List[str]
    search_postcode_col: List[str]
    search_df_key_field: str
    ref_df: pd.DataFrame
    ref_pre_filter: pd.DataFrame
    ref_address_cols: List[str]
    new_join_col: List[str]
    existing_match_cols: List[str]
    standard_llpg_format: List[str]

    # Match outputs
    match_results_output: pd.DataFrame
    predict_df_nnet: pd.DataFrame

    # Diagnostics
    compare_all_candidates: List[str]
    diag_shortlist: List[str]
    diag_best_match: List[str]

    results_on_orig_df: pd.DataFrame

    summary: str
    output_summary: str
    match_outputs_name: str
    results_orig_df_name: str

    # Data after standardisation
    search_df_after_stand: pd.DataFrame
    ref_df_after_stand: pd.DataFrame
    search_df_after_full_stand: pd.DataFrame
    ref_df_after_full_stand: pd.DataFrame

    search_df_after_stand_series: pd.Series
    ref_df_after_stand_series: pd.Series
    search_df_after_stand_series_full_stand: pd.Series
    ref_df_after_stand_series_full_stand: pd.Series

    abort_flag: bool

    # Pydantic v2 model configuration. It replaces the class-based ``Config``,
    # which pydantic v2 deprecates in favour of ``model_config``.
    model_config = {
        # Allow non-pydantic field types such as pd.DataFrame and pd.Series.
        "arbitrary_types_allowed": True,
        # Accept and keep attributes that are not declared as fields above.
        "extra": "allow",
        # Allow fields whose names start with "model_" (model_dir_name, model_type)
        # without pydantic's protected-namespace warning.
        "protected_namespaces": (),
    }
|
|
|
|
|
|
|
|
|
# Initial matcher state: the module-level configuration gathered above plus
# empty placeholders for the input data, intermediate frames and results.
# df_name, search_df_not_matched, search_df_cleaned and ref_df_cleaned are not
# declared fields of MatcherClass. They are kept because its config allows extras.
InitMatch = MatcherClass(
    # Fuzzy-matching settings
    fuzzy_scorer_used=fuzzy_scorer_used,
    fuzzy_match_limit=fuzzy_match_limit,
    fuzzy_search_addr_limit=fuzzy_search_addr_limit,
    filter_to_lambeth_pcodes=filter_to_lambeth_pcodes,
    standardise=standardise,
    suffix_used=suffix_used,

    # Matching columns and model settings
    matching_variables=matching_variables,
    model_dir_name=model_dir_name,
    file_step_suffix=file_step_suffix,
    exported_model=[exported_model],

    # Record-linkage comparison settings
    fuzzy_method=fuzzy_method,
    score_cut_off=score_cut_off,
    text_columns=text_columns,
    weights=weights,
    model_type=model_type,
    labels_list=labels_list,

    # Model lookup tables and device
    word_to_index=word_to_index,
    cat_to_idx=cat_to_idx,
    device=device,
    vocab=vocab,

    # Input data placeholders
    file_name='',
    ref_name='',
    df_name='',
    search_df=pd.DataFrame(),
    excluded_df=pd.DataFrame(),
    pre_filter_search_df=pd.DataFrame(),
    search_df_not_matched=pd.DataFrame(),
    search_df_cleaned=pd.DataFrame(),
    search_address_cols=[],
    search_postcode_col=[],
    search_df_key_field='index',
    ref_df=pd.DataFrame(),
    ref_df_cleaned=pd.DataFrame(),
    ref_pre_filter=pd.DataFrame(),
    ref_address_cols=[],
    new_join_col=[],
    existing_match_cols=[],
    standard_llpg_format=[],

    # Match output placeholders
    match_results_output=pd.DataFrame(),
    predict_df_nnet=pd.DataFrame(),

    # Diagnostics placeholders
    compare_all_candidates=[],
    diag_shortlist=[],
    diag_best_match=[],
    results_on_orig_df=pd.DataFrame(),
    summary="",
    output_summary="",
    match_outputs_name="",
    results_orig_df_name="",

    # Standardisation placeholders
    search_df_after_stand=pd.DataFrame(),
    ref_df_after_stand=pd.DataFrame(),
    search_df_after_stand_series=pd.Series(),
    ref_df_after_stand_series=pd.Series(),
    search_df_after_full_stand=pd.DataFrame(),
    ref_df_after_full_stand=pd.DataFrame(),
    search_df_after_stand_series_full_stand=pd.Series(),
    ref_df_after_stand_series_full_stand=pd.Series(),

    abort_flag=False,
)