Spaces:
Sleeping
Sleeping
| """Backend functions used in the app.""" | |
| import os | |
| import shutil | |
| import gradio as gr | |
| import numpy | |
| import requests | |
| import pickle | |
| import pandas | |
| from itertools import chain | |
| from settings import ( | |
| SERVER_URL, | |
| FHE_KEYS, | |
| CLIENT_FILES, | |
| SERVER_FILES, | |
| DEPLOYMENT_PATH, | |
| PROCESSED_INPUT_SHAPE, | |
| INPUT_INDEXES, | |
| INPUT_SLICES, | |
| PRE_PROCESSOR_APPLICANT_PATH, | |
| PRE_PROCESSOR_BANK_PATH, | |
| PRE_PROCESSOR_CREDIT_BUREAU_PATH, | |
| CLIENT_TYPES, | |
| APPLICANT_COLUMNS, | |
| BANK_COLUMNS, | |
| CREDIT_BUREAU_COLUMNS, | |
| YEARS_EMPLOYED_BINS, | |
| YEARS_EMPLOYED_BIN_NAME_TO_INDEX, | |
| ) | |
| from utils.client_server_interface import MultiInputsFHEModelClient | |
# Messages shown to the user for each possible prediction outcome (the UI text is in Chinese).
# They are returned by the prediction functions of this module.
APPROVED_MESSAGE = "信用卡申请可以批准 ✅"  # "The credit card application can be approved"
DENIED_MESSAGE = "信用卡申请可能被拒绝 ❌"  # "The credit card application may be denied"
# Load the pre-processor instances, one per client type (applicant, bank, credit bureau).
# All three files are opened before any of them is unpickled.
# NOTE(review): unpickling can execute arbitrary code; these paths are assumed to point to
# trusted files shipped with the deployment — confirm they can never be user-supplied.
with PRE_PROCESSOR_APPLICANT_PATH.open('rb') as file_applicant:
    with PRE_PROCESSOR_BANK_PATH.open('rb') as file_bank:
        with PRE_PROCESSOR_CREDIT_BUREAU_PATH.open('rb') as file_credit_bureau:
            PRE_PROCESSOR_APPLICANT = pickle.load(file_applicant)
            PRE_PROCESSOR_BANK = pickle.load(file_bank)
            PRE_PROCESSOR_CREDIT_BUREAU = pickle.load(file_credit_bureau)
def shorten_bytes_object(bytes_object, limit=500):
    """Return a short hexadecimal preview of a (large) bytes object.

    Encrypted data is far too large to be displayed in the browser through Gradio, so only a
    window of `limit` bytes, starting 100 bytes into the object, is rendered.

    Args:
        bytes_object (bytes): The object to preview.
        limit (int): Number of bytes in the preview window. Default to 500.

    Returns:
        str: Hexadecimal string of the selected window (empty if the object holds 100 bytes or
            fewer).
    """
    # Start the preview 100 bytes in, for a better-looking display
    offset = 100
    window = bytes_object[offset : offset + limit]
    return window.hex()
def _remove_path(path):
    """Delete `path`, whether it is a directory tree or a single file."""
    if path.is_dir():
        shutil.rmtree(path)
    else:
        path.unlink()


def clean_temporary_files(n_keys=20):
    """Clean older keys and encrypted files.

    At most n_keys key entries (and their associated temporary files) are kept. Once this limit
    is exceeded, the oldest key entries (by modification time) are deleted, together with every
    client/server entry whose name contains one of the deleted client IDs.

    Args:
        n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
    """
    # Sort the key entries from oldest to newest modification time
    key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)

    n_keys_to_delete = len(key_dirs) - n_keys
    if n_keys_to_delete <= 0:
        # Nothing to delete, so no related client/server files either
        return

    # Remove the oldest keys; their entry names are the client IDs
    deleted_client_ids = []
    for key_dir in key_dirs[:n_keys_to_delete]:
        deleted_client_ids.append(key_dir.name)
        _remove_path(key_dir)

    # Delete all files related to the IDs whose keys were deleted. Each entry is removed at most
    # once, even when several deleted IDs match its name (removing it twice would raise).
    # NOTE(review): matching is by substring, so a short ID could also match another client's
    # entry name; an exact-name match would be stricter if the server also names its entries by
    # the bare client ID — confirm before tightening.
    for entry in chain(CLIENT_FILES.iterdir(), SERVER_FILES.iterdir()):
        if any(client_id in entry.name for client_id in deleted_client_ids):
            _remove_path(entry)
def _get_client(client_id):
    """Instantiate the FHE client associated with a given client ID.

    Args:
        client_id (int): The client ID to consider.

    Returns:
        MultiInputsFHEModelClient: A client built from the deployment files, using the key
            directory named after the client ID and one input per client type.
    """
    return MultiInputsFHEModelClient(
        DEPLOYMENT_PATH,
        key_dir=FHE_KEYS / f"{client_id}",
        nb_inputs=len(CLIENT_TYPES),
    )
def _get_client_file_path(name, client_id, client_type=None):
    """Build the path of a client-side file, creating the client's directory if needed.

    Files live in a per-client directory under CLIENT_FILES, named after the client ID.

    Args:
        name (str): The base file name (either 'evaluation_key', 'encrypted_inputs' or
            'encrypted_outputs').
        client_id (int): The client ID to consider.
        client_type (Optional[str]): The type of client ('applicant', 'bank' or
            'credit_bureau'), appended to the name as a '_<client_type>' suffix. Default to
            None (no suffix), which is used for the evaluation key and output.

    Returns:
        pathlib.Path: The file path (the file itself is not created here).
    """
    suffix = "" if client_type is None else f"_{client_type}"

    client_dir = CLIENT_FILES / f"{client_id}"
    # Only the last directory level is created; CLIENT_FILES itself is expected to exist
    client_dir.mkdir(exist_ok=True)

    return client_dir / f"{name}{suffix}"
def _send_to_server(client_id, client_type, file_name):
    """Send the encrypted inputs or the evaluation key to the server.

    Args:
        client_id (int): The client ID to consider.
        client_type (Optional[str]): The type of client to consider (either 'applicant', 'bank',
            'credit_bureau' or None).
        file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs').

    Returns:
        bool: The response's `ok` flag, i.e. True if the server answered with a status code
            below 400, False otherwise.
    """
    # Get the path to the file previously written on the client side
    file_path = _get_client_file_path(file_name, client_id, client_type)

    # Form fields telling the server which file is uploaded, and for which client
    data = {
        "client_id": client_id,
        "client_type": client_type,
        "file_name": file_name,
    }

    url = SERVER_URL + "send_file"

    # Keep the file open only for the duration of the upload so the handle is always closed,
    # even if the request raises
    with open(file_path, "rb") as file:
        files = [
            ("files", file),
        ]
        with requests.post(
            url=url,
            data=data,
            files=files,
        ) as response:
            return response.ok
def keygen_send():
    """Generate the private and evaluation keys, and send the evaluation key to the server.

    Older keys and temporary files are cleaned first. Then a new random client ID is drawn and
    used to name the key directory and the client files.

    Returns:
        Tuple: The new client ID, a short hexadecimal representation of the serialized
            evaluation key, and a Gradio update for the button's label.

    Raises:
        gr.Error: If the server did not accept the evaluation key.
    """
    # Clean temporary files
    clean_temporary_files()

    # Create an ID for the current client. The dtype is explicit because the default integer
    # type is 32 bits on some platforms (e.g. Windows with NumPy < 2.0), where 2**32 is out of
    # bounds and randint raises a ValueError.
    client_id = numpy.random.randint(0, 2**32, dtype=numpy.int64)

    # Retrieve the client instance
    client = _get_client(client_id)

    # Generate the private and evaluation keys (force=True presumably regenerates them even if
    # keys already exist for this ID)
    client.generate_private_and_evaluation_keys(force=True)

    # Retrieve the serialized evaluation key
    evaluation_key = client.get_serialized_evaluation_keys()
    file_name = "evaluation_key"

    # Save the evaluation key as bytes in a file as it is too large to pass through regular
    # Gradio components (see https://github.com/gradio-app/gradio/issues/1877)
    evaluation_key_path = _get_client_file_path(file_name, client_id)
    with evaluation_key_path.open("wb") as evaluation_key_file:
        evaluation_key_file.write(evaluation_key)

    # Send the evaluation key to the server, and report a failure instead of a false success
    if not _send_to_server(client_id, None, file_name):
        raise gr.Error("计算密钥发送到服务器失败。")

    # Create a truncated version of the evaluation key for display
    evaluation_key_short = shorten_bytes_object(evaluation_key)

    return client_id, evaluation_key_short, gr.update(value="密钥已生成并且计算密钥已发送✅")
def _encrypt_send(client_id, inputs, client_type):
    """Encrypt the given inputs for a specific client and send it to the server.

    Args:
        client_id (str): The current client ID to consider.
        inputs (numpy.ndarray): The inputs to encrypt.
        client_type (str): The type of client to consider (either 'applicant', 'bank' or
            'credit_bureau').

    Returns:
        Tuple: A short hexadecimal representation of the encrypted inputs, and a Gradio update
            for the button's label.

    Raises:
        gr.Error: If no keys were generated yet, or if the server did not accept the encrypted
            inputs.
    """
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    # Retrieve the client instance
    client = _get_client(client_id)

    # Quantize, encrypt and serialize the inputs. Each client type fills its own input index and
    # slice of the processed input shape.
    encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
        inputs,
        input_index=INPUT_INDEXES[client_type],
        processed_input_shape=PROCESSED_INPUT_SHAPE,
        input_slice=INPUT_SLICES[client_type],
    )

    file_name = "encrypted_inputs"

    # Save encrypted_inputs to bytes in a file, since they are too large to pass through regular
    # Gradio components (see https://github.com/gradio-app/gradio/issues/1877)
    encrypted_inputs_path = _get_client_file_path(file_name, client_id, client_type)
    with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
        encrypted_inputs_file.write(encrypted_inputs)

    # Create a truncated version of the encrypted inputs for display
    encrypted_inputs_short = shorten_bytes_object(encrypted_inputs)

    # Report a failed upload instead of displaying a false success message
    if not _send_to_server(client_id, client_type, file_name):
        raise gr.Error("加密输入发送到服务器失败。")

    return encrypted_inputs_short, gr.update(value="输入被加密发送到服务器。 ✅")
def pre_process_encrypt_send_applicant(client_id, *inputs):
    """Pre-process, encrypt and send the applicant inputs for a specific client to the server.

    Args:
        client_id (str): The current client ID to consider.
        *inputs: The raw applicant inputs, in this order: the selected boolean options
            ('Car', 'Property', 'Mobile phone'), number of children, household size, total
            income, age, income type, education type, family status, occupation type and
            housing type.

    Returns:
        Tuple: The result of `_encrypt_send`, i.e. a short hexadecimal representation of the
            encrypted inputs and a Gradio update for the button's label.
    """
    (
        bool_inputs,
        num_children,
        household_size,
        total_income,
        age,
        income_type,
        education_type,
        family_status,
        occupation_type,
        housing_type,
    ) = inputs

    # Each boolean feature is set when its option was selected
    own_car, own_property, mobile_phone = (
        option in bool_inputs for option in ("Car", "Property", "Mobile phone")
    )

    raw_values = {
        "Own_car": own_car,
        "Own_property": own_property,
        "Mobile_phone": mobile_phone,
        "Num_children": num_children,
        "Household_size": household_size,
        "Total_income": total_income,
        "Age": age,
        "Income_type": income_type,
        "Education_type": education_type,
        "Family_status": family_status,
        "Occupation_type": occupation_type,
        "Housing_type": housing_type,
    }

    # Single-row frame, with columns ordered as the pre-processor expects
    single_row = pandas.DataFrame({column: [value] for column, value in raw_values.items()})
    single_row = single_row.reindex(APPLICANT_COLUMNS, axis=1)

    return _encrypt_send(client_id, PRE_PROCESSOR_APPLICANT.transform(single_row), "applicant")
def pre_process_encrypt_send_bank(client_id, *inputs):
    """Pre-process, encrypt and send the bank inputs for a specific client to the server.

    Args:
        client_id (str): The current client ID to consider.
        *inputs: The raw bank inputs; only the first one (the account age) is used.

    Returns:
        Tuple: The result of `_encrypt_send`, i.e. a short hexadecimal representation of the
            encrypted inputs and a Gradio update for the button's label.
    """
    # Single-row frame, with columns ordered as the pre-processor expects
    single_row = pandas.DataFrame({"Account_age": [inputs[0]]}).reindex(BANK_COLUMNS, axis=1)

    return _encrypt_send(client_id, PRE_PROCESSOR_BANK.transform(single_row), "bank")
def pre_process_encrypt_send_credit_bureau(client_id, *inputs):
    """Pre-process, encrypt and send the credit bureau inputs for a specific client to the server.

    Args:
        client_id (str): The current client ID to consider.
        *inputs: The raw credit bureau inputs: the years-of-employment bin name and the
            employment status ('Yes' meaning employed).

    Returns:
        Tuple: The result of `_encrypt_send`, i.e. a short hexadecimal representation of the
            encrypted inputs and a Gradio update for the button's label.
    """
    years_employed_bin, employed = inputs

    # The years of employment are encoded by the index of their bin
    single_row = pandas.DataFrame(
        {
            "Years_employed": [YEARS_EMPLOYED_BIN_NAME_TO_INDEX[years_employed_bin]],
            "Employed": [employed == "Yes"],
        }
    ).reindex(CREDIT_BUREAU_COLUMNS, axis=1)

    return _encrypt_send(
        client_id, PRE_PROCESSOR_CREDIT_BUREAU.transform(single_row), "credit_bureau"
    )
def run_fhe(client_id):
    """Ask the server to run the model in FHE on the encrypted inputs previously sent.

    Args:
        client_id (str): The current client ID to consider.

    Returns:
        Tuple: The server's decoded JSON answer, and a Gradio update for the button's label.

    Raises:
        gr.Error: If no keys were generated yet, or if the server answered with an error status
            (the displayed message asks for the inputs of all three parties to be sent first).
    """
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    with requests.post(url=SERVER_URL + "run_fhe", data={"client_id": client_id}) as response:
        if not response.ok:
            raise gr.Error("三个单位的输入都要先发送到服务器。")
        return response.json(), gr.update(value="FHE计算已完成。 ✅")
def get_output_and_decrypt(client_id):
    """Retrieve the encrypted output from the server and decrypt it into a prediction message.

    Args:
        client_id (str): The current client ID to consider.

    Returns:
        Tuple: APPROVED_MESSAGE if the decrypted prediction is class 1 (DENIED_MESSAGE
            otherwise), a short hexadecimal representation of the encrypted output, and a
            Gradio update for the button's label.

    Raises:
        gr.Error: If no keys were generated yet, or if the server answered with an error status
            (the displayed message asks the user to run the FHE execution and wait for it).
    """
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    with requests.post(url=SERVER_URL + "get_output", data={"client_id": client_id}) as response:
        if not response.ok:
            raise gr.Error("请先运行FHE计算并等待其完成。")

        encrypted_output = response.content

        # Truncated version of the encrypted output, for display only
        encrypted_output_short = shorten_bytes_object(encrypted_output)

        # Deserialize, decrypt and post-process the encrypted output with this client's instance
        client = _get_client(client_id)
        output_proba = client.deserialize_decrypt_dequantize(encrypted_output)

        # Predicted class: argmax over axis 1 (assumes a single sample, which squeeze() reduces
        # to a scalar — TODO confirm)
        predicted_class = numpy.argmax(output_proba, axis=1).squeeze()
        message = APPROVED_MESSAGE if predicted_class == 1 else DENIED_MESSAGE

        return message, encrypted_output_short, gr.update(value="已从服务器接收加密输出。 ✅")
def explain_encrypt_run_decrypt(client_id, prediction_output, *inputs):
    """Suggest how many years of employment could turn a denial into an approval.

    The credit bureau inputs are pre-processed, encrypted and re-sent with each
    years-of-employment bin greater than or equal to the given one, in order. For each bin, the
    model is run in FHE and its output is decrypted. The first bin leading to an approval is
    reported to the applicant.

    Args:
        client_id (str): The current client ID to consider.
        prediction_output (str): The initial prediction output. This parameter is only used to
            throw an error in case the prediction was positive.
        *inputs: The credit bureau inputs: the years-of-employment bin name and the employment
            status.

    Returns:
        Tuple: A message indicating the number of additional years of employment that could be
            required in order to increase the chance of credit card approval, and a Gradio
            update for the button's label.

    Raises:
        gr.Error: If the initial prediction was already an approval.
    """
    # The prediction messages are localized, so look for the approval message itself: checking
    # for the English word "approved" never matches and lets approved predictions through
    if APPROVED_MESSAGE in prediction_output:
        raise gr.Error(
            "仅当信用卡可能被拒绝时才能解释该预测。"
        )

    button_update = gr.update(value="预测已经得到解释。 ✅")

    # Retrieve the credit bureau inputs
    years_employed, employed = inputs

    # Years_employed is divided into several ordered bins. Here, we retrieve the index
    # representing the bin from the input
    bin_index = YEARS_EMPLOYED_BIN_NAME_TO_INDEX[years_employed]

    # In case the applicant tried the "oldest" bin (but still got denied), explain why
    if bin_index == len(YEARS_EMPLOYED_BINS) - 1:
        return (
            DENIED_MESSAGE + " Unfortunately, you already have the maximum amount of years of "
            f"employment ({years_employed} years). Other inputs like the income or the account's age "
            "might have a bigger impact in this particular case."
        ), button_update

    # Run the model in FHE for each bin "older" than or equal to the given bin, in order, and
    # stop at the first one that changes the model's prediction to an approval
    for years_employed_bin in YEARS_EMPLOYED_BINS[bin_index:]:
        # Send the new encrypted input, run the model in FHE and retrieve the new prediction
        pre_process_encrypt_send_credit_bureau(client_id, years_employed_bin, employed)
        run_fhe(client_id)
        output_message = get_output_and_decrypt(client_id)[0]

        if output_message == APPROVED_MESSAGE:
            # The approval was made using the given input: the applicant most likely tried the
            # bin suggested in a previous explainability run, so confirm the approval
            if years_employed_bin == years_employed:
                return APPROVED_MESSAGE, button_update

            # Otherwise the applicant is looking for some explainability: suggest the bin found
            return (
                DENIED_MESSAGE + f" However, having at least {years_employed_bin} years of "
                "employment would increase your chance of having your credit card approved."
            ), button_update

    # In case no bins made the model predict an approval, explain why
    return (
        DENIED_MESSAGE + " Unfortunately, increasing the number of years of employment up to "
        f"{YEARS_EMPLOYED_BINS[-1]} years does not seem to be enough to get an approval based "
        "on the given inputs. Other inputs like the income or the account's age might have "
        "bigger impact in this particular case."
    ), button_update