"""Train and compile the model.""" import shutil import numpy import pandas import pickle from settings import ( APPROVAL_DEPLOYMENT_PATH, EXPLAIN_DEPLOYMENT_PATH, DATA_PATH, APPROVAL_INPUT_SLICES, EXPLAIN_INPUT_SLICES, PRE_PROCESSOR_USER_PATH, PRE_PROCESSOR_BANK_PATH, PRE_PROCESSOR_THIRD_PARTY_PATH, USER_COLUMNS, BANK_COLUMNS, APPROVAL_THIRD_PARTY_COLUMNS, EXPLAIN_THIRD_PARTY_COLUMNS, ) from utils.client_server_interface import MultiInputsFHEModelDev from utils.model import MultiInputDecisionTreeClassifier, MultiInputDecisionTreeRegressor from utils.pre_processing import get_pre_processors def get_multi_inputs(data, is_approval): """Get inputs for all three parties from the input data, using fixed slices. Args: data (numpy.ndarray): The input data to consider. is_approval (bool): If the data should be used for the 'approval' model (else, otherwise for the 'explain' model). Returns: (Tuple[numpy.ndarray]): The inputs for all three parties. """ if is_approval: return ( data[:, APPROVAL_INPUT_SLICES["user"]], data[:, APPROVAL_INPUT_SLICES["bank"]], data[:, APPROVAL_INPUT_SLICES["third_party"]] ) return ( data[:, EXPLAIN_INPUT_SLICES["user"]], data[:, EXPLAIN_INPUT_SLICES["bank"]], data[:, EXPLAIN_INPUT_SLICES["third_party"]] ) print("Load and pre-process the data") # Load the data data = pandas.read_csv(DATA_PATH, encoding="utf-8") # Define input and target data data_x = data.copy() data_y = data_x.pop("Target").copy().to_frame() # Get data from all parties data_user = data_x[USER_COLUMNS].copy() data_bank = data_x[BANK_COLUMNS].copy() data_third_party = data_x[APPROVAL_THIRD_PARTY_COLUMNS].copy() # Feature engineer the data pre_processor_user, pre_processor_bank, pre_processor_third_party = get_pre_processors() preprocessed_data_user = pre_processor_user.fit_transform(data_user) preprocessed_data_bank = pre_processor_bank.fit_transform(data_bank) preprocessed_data_third_party = pre_processor_third_party.fit_transform(data_third_party) preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_data_bank, preprocessed_data_third_party), axis=1) print("\nTrain and compile the model") model_approval = MultiInputDecisionTreeClassifier() model_approval, sklearn_model_approval = model_approval.fit_benchmark(preprocessed_data_x, data_y) multi_inputs_train = get_multi_inputs(preprocessed_data_x, is_approval=True) model_approval.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"]) print("\nSave deployment files") # Delete the deployment folder and its content if it already exists if APPROVAL_DEPLOYMENT_PATH.is_dir(): shutil.rmtree(APPROVAL_DEPLOYMENT_PATH) # Save files needed for deployment (and enable cross-platform deployment) fhe_model_dev_approval = MultiInputsFHEModelDev(APPROVAL_DEPLOYMENT_PATH, model_approval) fhe_model_dev_approval.save(via_mlir=True) # Save pre-processors with ( PRE_PROCESSOR_USER_PATH.open('wb') as file_user, PRE_PROCESSOR_BANK_PATH.open('wb') as file_bank, PRE_PROCESSOR_THIRD_PARTY_PATH.open('wb') as file_third_party, ): pickle.dump(pre_processor_user, file_user) pickle.dump(pre_processor_bank, file_bank) pickle.dump(pre_processor_third_party, file_third_party) print("\nLoad, train, compile and save files for the 'explain' model") # Define input and target data data_x = data.copy() data_y = data_x.pop("Years_employed").copy().to_frame() target_values = data_x.pop("Target").copy() # Get all data points whose target value is True (credit card has been approved) approved_mask = target_values == 1 data_x_approved = data_x[approved_mask] data_y_approved = data_y[approved_mask] # Get data from all parties data_user = data_x_approved[USER_COLUMNS].copy() data_bank = data_x_approved[BANK_COLUMNS].copy() data_third_party = data_x_approved[EXPLAIN_THIRD_PARTY_COLUMNS].copy() preprocessed_data_user = pre_processor_user.transform(data_user) preprocessed_data_bank = pre_processor_bank.transform(data_bank) preprocessed_data_third_party = data_third_party.to_numpy() preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_data_bank, preprocessed_data_third_party), axis=1) model_explain = MultiInputDecisionTreeRegressor() model_explain, sklearn_model_explain = model_explain.fit_benchmark(preprocessed_data_x, data_y_approved) multi_inputs_train = get_multi_inputs(preprocessed_data_x, is_approval=False) model_explain.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"]) # Delete the deployment folder and its content if it already exists if EXPLAIN_DEPLOYMENT_PATH.is_dir(): shutil.rmtree(EXPLAIN_DEPLOYMENT_PATH) # Save files needed for deployment (and enable cross-platform deployment) fhe_model_dev_explain = MultiInputsFHEModelDev(EXPLAIN_DEPLOYMENT_PATH, model_explain) fhe_model_dev_explain.save(via_mlir=True) print("\nDone !")