"""Backend functions used in the app."""
import os
import pickle
import shutil
from itertools import chain
import gradio as gr
import numpy
import pandas
import requests
from settings import (
SERVER_URL,
FHE_KEYS,
CLIENT_FILES,
SERVER_FILES,
APPROVAL_DEPLOYMENT_PATH,
EXPLAIN_DEPLOYMENT_PATH,
APPROVAL_PROCESSED_INPUT_SHAPE,
EXPLAIN_PROCESSED_INPUT_SHAPE,
INPUT_INDEXES,
APPROVAL_INPUT_SLICES,
EXPLAIN_INPUT_SLICES,
PRE_PROCESSOR_USER_PATH,
PRE_PROCESSOR_BANK_PATH,
PRE_PROCESSOR_THIRD_PARTY_PATH,
CLIENT_TYPES,
USER_COLUMNS,
BANK_COLUMNS,
APPROVAL_THIRD_PARTY_COLUMNS,
)
from utils.client_server_interface import MultiInputsFHEModelClient, MultiInputsFHEModelServer
# Load the server used for explaining the prediction
EXPLAIN_FHE_SERVER = MultiInputsFHEModelServer(EXPLAIN_DEPLOYMENT_PATH)
# Load pre-processor instances
with (
PRE_PROCESSOR_USER_PATH.open('rb') as file_user,
PRE_PROCESSOR_BANK_PATH.open('rb') as file_bank,
PRE_PROCESSOR_THIRD_PARTY_PATH.open('rb') as file_third_party,
):
PRE_PROCESSOR_USER = pickle.load(file_user)
PRE_PROCESSOR_BANK = pickle.load(file_bank)
PRE_PROCESSOR_THIRD_PARTY = pickle.load(file_third_party)
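# Each pre-processor is a pickled transformer exposing a `transform` method. It is applied to the
# raw inputs of its party before they are quantized and encrypted.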
def shorten_bytes_object(bytes_object, limit=500):
"""Shorten the input bytes object to a given length.
Encrypted data is too large to be displayed in the browser using Gradio. This function
provides a shortened representation of it.
Args:
bytes_object (bytes): The input to shorten.
limit (int): The length to consider. Defaults to 500.
Returns:
str: A shortened hexadecimal string representation of the input bytes object.
"""
# Define a shift for better display
shift = 100
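# With the default values, this returns the hexadecimal representation of bytes 100 to 599 of the
# input (a 1000-character string), assuming the input is long enough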
return bytes_object[shift : limit + shift].hex()
def clean_temporary_files(n_keys=20):
"""Clean older keys and encrypted files.
A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
limit is reached, the oldest files are deleted.
Args:
n_keys (int): The maximum number of keys and associated files to be stored. Defaults to 20.
"""
# Get the oldest key files in the key directory
key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)
# If more than n_keys keys are found, remove the oldest
client_ids = []
if len(key_dirs) > n_keys:
n_keys_to_delete = len(key_dirs) - n_keys
for key_dir in key_dirs[:n_keys_to_delete]:
client_ids.append(key_dir.name)
shutil.rmtree(key_dir)
# Delete all files related to the IDs whose keys were deleted
for directory in chain(CLIENT_FILES.iterdir(), SERVER_FILES.iterdir()):
for client_id in client_ids:
if client_id in directory.name:
shutil.rmtree(directory)
def _get_client(client_id, is_approval=True):
"""Get the client instance.
Args:
client_id (int): The client ID to consider.
is_approval (bool): Whether the client represents the 'approval' model (else, it represents
the 'explain' model). Defaults to True.
Returns:
MultiInputsFHEModelClient: The client instance.
"""
key_suffix = "approval" if is_approval else "explain"
key_dir = FHE_KEYS / f"{client_id}_{key_suffix}"
client_dir = APPROVAL_DEPLOYMENT_PATH if is_approval else EXPLAIN_DEPLOYMENT_PATH
return MultiInputsFHEModelClient(client_dir, key_dir=key_dir, nb_inputs=len(CLIENT_TYPES))
def _get_client_file_path(name, client_id, client_type=None):
"""Get the file path for the client.
Args:
name (str): The desired file name (either 'evaluation_key', 'encrypted_inputs' or
'encrypted_outputs').
client_id (int): The client ID to consider.
client_type (Optional[str]): The type of client to consider (either 'user', 'bank',
'third_party' or None). Defaults to None, which is used for the evaluation key and the output.
Returns:
pathlib.Path: The file path.
"""
client_type_suffix = ""
if client_type is not None:
client_type_suffix = f"_{client_type}"
dir_path = CLIENT_FILES / f"{client_id}"
dir_path.mkdir(exist_ok=True)
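# For instance, the encrypted inputs of the 'user' party for client 42 are stored at
# CLIENT_FILES / "42" / "encrypted_inputs_user"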
return dir_path / f"{name}{client_type_suffix}"
def _send_to_server(client_id, client_type, file_name):
"""Send the encrypted inputs or the evaluation key to the server.
Args:
client_id (int): The client ID to consider.
client_type (Optional[str]): The type of client to consider (either 'user', 'bank',
'third_party' or None).
file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs').
"""
# Get the paths to the encrypted inputs
encrypted_file_path = _get_client_file_path(file_name, client_id, client_type)
# Define the data and files to post
data = {
"client_id": client_id,
"client_type": client_type,
"file_name": file_name,
}
# Send the encrypted inputs or evaluation key to the server, making sure the file handle is
# closed once the request completes
url = SERVER_URL + "send_file"
with encrypted_file_path.open("rb") as encrypted_file:
files = [
("files", encrypted_file),
]
with requests.post(
url=url,
data=data,
files=files,
) as response:
return response.ok
def keygen_send():
"""Generate the private and evaluation key, and send the evaluation key to the server.
Returns:
client_id (str): The current client ID to consider.
"""
# Clean temporary files
clean_temporary_files()
# Create an ID for the current client to consider
client_id = numpy.random.randint(0, 2**32)
# Retrieve the client instance
client = _get_client(client_id)
# Generate the private and evaluation keys
client.generate_private_and_evaluation_keys(force=True)
# Retrieve the serialized evaluation key
evaluation_key = client.get_serialized_evaluation_keys()
file_name = "evaluation_key"
# Save evaluation key as bytes in a file as it is too large to pass through regular Gradio
# buttons (see https://github.com/gradio-app/gradio/issues/1877)
evaluation_key_path = _get_client_file_path(file_name, client_id)
with evaluation_key_path.open("wb") as evaluation_key_file:
evaluation_key_file.write(evaluation_key)
# Send the evaluation key to the server
_send_to_server(client_id, None, file_name)
# Create a truncated version of the evaluation key for display
evaluation_key_short = shorten_bytes_object(evaluation_key)
return client_id, evaluation_key_short, gr.update(value="Keys are generated and evaluation key is sent βœ…")
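# Typical flow in the app: `keygen_send` first creates a client ID and shares the evaluation key,
# each party then calls its `pre_process_encrypt_send_*` function, `run_fhe` triggers the
# server-side execution and `get_output_and_decrypt` retrieves and decrypts the result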
def _encrypt_send(client_id, inputs, client_type):
"""Encrypt the given inputs for a specific client and send them to the server.
Args:
client_id (str): The current client ID to consider.
inputs (numpy.ndarray): The inputs to encrypt.
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
Returns:
encrypted_inputs_short (str): A short representation of the encrypted input to send in hex.
"""
if client_id == "":
raise gr.Error("Please generate the keys first.")
# Retrieve the client instance
client = _get_client(client_id)
# Quantize, encrypt and serialize the inputs
encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
inputs,
input_index=INPUT_INDEXES[client_type],
processed_input_shape=APPROVAL_PROCESSED_INPUT_SHAPE,
input_slice=APPROVAL_INPUT_SLICES[client_type],
)
file_name = "encrypted_inputs"
# Save the encrypted inputs as bytes in a file, as they are too large to pass through regular
# Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
encrypted_inputs_path = _get_client_file_path(file_name, client_id, client_type)
with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
encrypted_inputs_file.write(encrypted_inputs)
# Create a truncated version of the encrypted inputs for display
encrypted_inputs_short = shorten_bytes_object(encrypted_inputs)
_send_to_server(client_id, client_type, file_name)
return encrypted_inputs_short
def _pre_process_user(*inputs):
"""Pre-process the user inputs.
Args:
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
Returns:
(numpy.ndarray): The pre-processed inputs.
"""
bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \
family_status, occupation_type, housing_type = inputs
# Retrieve boolean values
own_car = "Car" in bool_inputs
own_property = "Property" in bool_inputs
mobile_phone = "Mobile phone" in bool_inputs
user_inputs = pandas.DataFrame({
"Own_car": [own_car],
"Own_property": [own_property],
"Mobile_phone": [mobile_phone],
"Num_children": [num_children],
"Household_size": [household_size],
"Total_income": [total_income],
"Age": [age],
"Income_type": [income_type],
"Education_type": [education_type],
"Family_status": [family_status],
"Occupation_type": [occupation_type],
"Housing_type": [housing_type],
})
user_inputs = user_inputs.reindex(USER_COLUMNS, axis=1)
preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs)
return preprocessed_user_inputs
def pre_process_encrypt_send_user(client_id, *inputs):
"""Pre-process, encrypt and send the user inputs for a specific client to the server.
Args:
client_id (str): The current client ID to consider.
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
Returns:
(str): A short representation of the encrypted input to send in hex.
"""
preprocessed_user_inputs = _pre_process_user(*inputs)
return _encrypt_send(client_id, preprocessed_user_inputs, "user")
def _pre_process_bank(*inputs):
"""Pre-process the bank inputs.
Args:
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
Returns:
(numpy.ndarray): The pre-processed inputs.
"""
account_age = inputs[0]
bank_inputs = pandas.DataFrame({
"Account_age": [account_age],
})
bank_inputs = bank_inputs.reindex(BANK_COLUMNS, axis=1)
preprocessed_bank_inputs = PRE_PROCESSOR_BANK.transform(bank_inputs)
return preprocessed_bank_inputs
def pre_process_encrypt_send_bank(client_id, *inputs):
"""Pre-process, encrypt and send the bank inputs for a specific client to the server.
Args:
client_id (str): The current client ID to consider.
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
Returns:
(str): A short representation of the encrypted input to send in hex.
"""
preprocessed_bank_inputs = _pre_process_bank(*inputs)
return _encrypt_send(client_id, preprocessed_bank_inputs, "bank")
def _pre_process_third_party(*inputs):
"""Pre-process the third party inputs.
Args:
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
Returns:
(numpy.ndarray): The pre-processed inputs.
"""
third_party_data = {}
if len(inputs) == 1:
employed = inputs[0]
else:
employed, years_employed = inputs
third_party_data["Years_employed"] = [years_employed]
is_employed = employed == "Yes"
third_party_data["Employed"] = [is_employed]
third_party_inputs = pandas.DataFrame(third_party_data)
if len(inputs) == 1:
preprocessed_third_party_inputs = third_party_inputs.to_numpy()
else:
third_party_inputs = third_party_inputs.reindex(APPROVAL_THIRD_PARTY_COLUMNS, axis=1)
preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs)
return preprocessed_third_party_inputs
def pre_process_encrypt_send_third_party(client_id, *inputs):
"""Pre-process, encrypt and send the third party inputs for a specific client to the server.
Args:
client_id (str): The current client ID to consider.
*inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
Returns:
(str): A short representation of the encrypted input to send in hex.
"""
preprocessed_third_party_inputs = _pre_process_third_party(*inputs)
return _encrypt_send(client_id, preprocessed_third_party_inputs, "third_party")
def run_fhe(client_id):
"""Run the model on the encrypted inputs previously sent using FHE.
Args:
client_id (str): The current client ID to consider.
"""
if client_id == "":
raise gr.Error("Please generate the keys first.")
data = {
"client_id": client_id,
}
# Trigger the FHE execution on the encrypted inputs previously sent
url = SERVER_URL + "run_fhe"
with requests.post(
url=url,
data=data,
) as response:
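# The server's response (presumably the FHE execution time) is displayed as is in the app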
if response.ok:
return response.json()
else:
raise gr.Error("Please send the inputs from all three parties to the server first.")
def get_output_and_decrypt(client_id):
"""Retrieve the encrypted output.
Args:
client_id (str): The current client ID to consider.
Returns:
(Tuple[str, str]): The output message based on the decrypted prediction as well as a short
hexadecimal representation of the encrypted output.
"""
if client_id == "":
raise gr.Error("Please generate the keys first.")
data = {
"client_id": client_id,
}
# Retrieve the encrypted output
url = SERVER_URL + "get_output"
with requests.post(
url=url,
data=data,
) as response:
if response.ok:
encrypted_output_proba = response.content
# Create a truncated version of the encrypted output for display
encrypted_output_short = shorten_bytes_object(encrypted_output_proba)
# Retrieve the client instance
client = _get_client(client_id)
# Deserialize, decrypt and post-process the encrypted output
output_proba = client.deserialize_decrypt_dequantize(encrypted_output_proba)
# Determine the predicted class
output = numpy.argmax(output_proba, axis=1).squeeze()
return (
"Credit card is likely to be approved βœ…" if output == 1
else "Credit card is likely to be denied ❌",
encrypted_output_short,
)
else:
raise gr.Error("Please run the FHE execution first and wait for it to be completed.")
def years_employed_encrypt_run_decrypt(client_id, prediction_output, *inputs):
"""Pre-process and encrypt the inputs, run the prediction in FHE and decrypt the output.
Args:
client_id (str): The current client ID to consider.
prediction_output (str): The initial prediction output. This parameter is only used to
throw an error in case the prediction was positive.
*inputs (Tuple[numpy.ndarray]): The inputs to consider.
Returns:
(str): A message indicating the number of additional years of employment that could be
required in order to increase the chance of credit card approval.
"""
if "approved" in prediction_output:
raise gr.Error(
"Explaining the prediction can only be done if the credit card is likely to be denied."
)
# Retrieve the client instance
client = _get_client(client_id, is_approval=False)
# Generate the private and evaluation keys
client.generate_private_and_evaluation_keys(force=False)
# Retrieve the serialized evaluation key
evaluation_key = client.get_serialized_evaluation_keys()
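# Unpack and pre-process the inputs from all three parties, this time for the 'explain' model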
bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \
family_status, occupation_type, housing_type, account_age, employed, years_employed = inputs
preprocessed_user_inputs = _pre_process_user(
bool_inputs, num_children, household_size, total_income, age, income_type, education_type,
family_status, occupation_type, housing_type,
)
preprocessed_bank_inputs = _pre_process_bank(account_age)
preprocessed_third_party_inputs = _pre_process_third_party(employed)
preprocessed_inputs = [
preprocessed_user_inputs,
preprocessed_bank_inputs,
preprocessed_third_party_inputs
]
# Quantize, encrypt and serialize the inputs
encrypted_inputs = []
for client_type, preprocessed_input in zip(CLIENT_TYPES, preprocessed_inputs):
encrypted_input = client.quantize_encrypt_serialize_multi_inputs(
preprocessed_input,
input_index=INPUT_INDEXES[client_type],
processed_input_shape=EXPLAIN_PROCESSED_INPUT_SHAPE,
input_slice=EXPLAIN_INPUT_SLICES[client_type],
)
encrypted_inputs.append(encrypted_input)
# Run the FHE computation
encrypted_output = EXPLAIN_FHE_SERVER.run(
*encrypted_inputs,
serialized_evaluation_keys=evaluation_key
)
# Decrypt the output
output_prediction = client.deserialize_decrypt_dequantize(encrypted_output)
# Get the difference with the initial 'years of employment' input
years_employed_diff = int(numpy.ceil(output_prediction.squeeze() - years_employed))
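# For example, if the model predicts that about 7.3 years of employment are needed while the user
# reported 5, the message suggests ceil(7.3 - 5) = 3 additional years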
if years_employed_diff > 0:
return (
f"Having at least {years_employed_diff} more years of employment would increase "
"your chance of having your credit card approved."
)
return (
"The number of years of employment you provided seems to be enough. The negative prediction "
"might come from other inputs."
)