# encrypted_credit_scoring / development.py
# Page metadata captured with this file: author romanbredehoft-zama,
# commit 615cfe4 ("Remove unused imports"), size 2.91 kB.
"""Train and compile the multi-input FHE model.

Also saves the deployment files and the fitted pre-processors.
"""
import shutil
import numpy
import pandas
import pickle
from settings import (
DEPLOYMENT_PATH,
DATA_PATH,
INPUT_SLICES,
PRE_PROCESSOR_APPLICANT_PATH,
PRE_PROCESSOR_BANK_PATH,
PRE_PROCESSOR_CREDIT_BUREAU_PATH,
APPLICANT_COLUMNS,
BANK_COLUMNS,
CREDIT_BUREAU_COLUMNS,
)
from utils.client_server_interface import MultiInputsFHEModelDev
from utils.model import MultiInputDecisionTreeClassifier
from utils.pre_processing import get_pre_processors
def get_multi_inputs(data):
    """Split the input data into the per-party inputs, using the fixed ``INPUT_SLICES``.

    Each party's selector (taken from ``INPUT_SLICES``) is applied on the second axis,
    i.e. ``data[:, selector]``, in a fixed order: applicant, bank, credit bureau.

    Args:
        data (numpy.ndarray): The input data to consider.

    Returns:
        (Tuple[numpy.ndarray]): The applicant, bank and credit bureau inputs, in that order.
    """
    parties = ("applicant", "bank", "credit_bureau")
    return tuple(data[:, INPUT_SLICES[party]] for party in parties)
print("Load and pre-process the data")

# Read the full dataset, then split it into features (data_x) and the "Target" label
# (data_y, kept as a single-column DataFrame); pop() also removes the label from data_x
data = pandas.read_csv(DATA_PATH, encoding="utf-8")
data_x = data.copy()
data_y = data_x.pop("Target").copy().to_frame()

# Each party (applicant, bank, credit bureau) only gets its own feature columns
data_applicant = data_x[APPLICANT_COLUMNS].copy()
data_bank = data_x[BANK_COLUMNS].copy()
data_credit_bureau = data_x[CREDIT_BUREAU_COLUMNS].copy()

# Fit a separate pre-processor per party, on that party's columns only
(
    pre_processor_applicant,
    pre_processor_bank,
    pre_processor_credit_bureau,
) = get_pre_processors()
preprocessed_data_applicant = pre_processor_applicant.fit_transform(data_applicant)
preprocessed_data_bank = pre_processor_bank.fit_transform(data_bank)
preprocessed_data_credit_bureau = pre_processor_credit_bureau.fit_transform(
    data_credit_bureau
)

# Put the pre-processed party features side by side (column-wise), in applicant /
# bank / credit bureau order. NOTE(review): assumes each fit_transform output is a
# 2-D array with one row per sample, in the same row order - confirm.
preprocessed_data_x = numpy.concatenate(
    (
        preprocessed_data_applicant,
        preprocessed_data_bank,
        preprocessed_data_credit_bureau,
    ),
    axis=1,
)
print("\nTrain and compile the model")

# fit_benchmark returns two values: the fitted model and a second model bound to
# sklearn_model (presumably its scikit-learn counterpart, judging by the name);
# only the first one is used below
model, sklearn_model = MultiInputDecisionTreeClassifier().fit_benchmark(
    preprocessed_data_x, data_y
)

# Compile against the training data split into one input per party, marking
# every party's input as encrypted
multi_inputs_train = get_multi_inputs(preprocessed_data_x)
model.compile(
    *multi_inputs_train,
    inputs_encryption_status=["encrypted"] * len(multi_inputs_train),
)
print("\nSave deployment files")

# Remove any existing deployment folder, with everything inside it, so that no
# file from a previous run is left behind
if DEPLOYMENT_PATH.is_dir():
    shutil.rmtree(DEPLOYMENT_PATH)

# Save the files needed for deployment; via_mlir=True is passed to enable
# cross-platform deployment
fhe_model_dev = MultiInputsFHEModelDev(DEPLOYMENT_PATH, model)
fhe_model_dev.save(via_mlir=True)

# Pickle each fitted pre-processor into its own file. All three files are opened
# (and truncated, mode "wb") before anything is written, then filled in order.
with (
    PRE_PROCESSOR_APPLICANT_PATH.open("wb") as file_applicant,
    PRE_PROCESSOR_BANK_PATH.open("wb") as file_bank,
    PRE_PROCESSOR_CREDIT_BUREAU_PATH.open("wb") as file_credit_bureau,
):
    for pre_processor, pre_processor_file in zip(
        (pre_processor_applicant, pre_processor_bank, pre_processor_credit_bureau),
        (file_applicant, file_bank, file_credit_bureau),
    ):
        pickle.dump(pre_processor, pre_processor_file)

print("\nDone !")