Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
"""Backend functions used in the app.""" | |
import os | |
import shutil | |
import gradio as gr | |
import numpy | |
import requests | |
import pickle | |
import pandas | |
from itertools import chain | |
from settings import ( | |
SERVER_URL, | |
FHE_KEYS, | |
CLIENT_FILES, | |
SERVER_FILES, | |
DEPLOYMENT_PATH, | |
INITIAL_INPUT_SHAPE, | |
INPUT_INDEXES, | |
INPUT_SLICES, | |
PRE_PROCESSOR_USER_PATH, | |
PRE_PROCESSOR_THIRD_PARTY_PATH, | |
CLIENT_TYPES, | |
USER_COLUMNS, | |
THIRD_PARTY_COLUMNS, | |
) | |
from utils.client_server_interface import MultiInputsFHEModelClient | |
# Load the fitted pre-processor instances used to transform raw inputs before encryption.
# NOTE(review): pickle.load can execute arbitrary code. These paths come from `settings` and are
# expected to point at trusted artifacts shipped with the app; never point them at user files.
def _load_pickled_object(path):
    """Load and return the object pickled in the file at `path` (a pathlib.Path)."""
    with path.open("rb") as pickled_file:
        return pickle.load(pickled_file)


PRE_PROCESSOR_USER = _load_pickled_object(PRE_PROCESSOR_USER_PATH)
PRE_PROCESSOR_THIRD_PARTY = _load_pickled_object(PRE_PROCESSOR_THIRD_PARTY_PATH)
def shorten_bytes_object(bytes_object, limit=500, shift=100):
    """Shorten the input bytes object to a given length.

    Encrypted data is too large for displaying it in the browser using Gradio. This function
    provides a shortened hexadecimal representation of it.

    Args:
        bytes_object (bytes): The input to shorten.
        limit (int): The number of bytes to keep. Default to 500.
        shift (int): The number of leading bytes to skip before keeping `limit` bytes, which
            gives a better-looking display. Default to 100.

    Returns:
        str: Hexadecimal representation of `bytes_object[shift : shift + limit]` (two characters
            per byte). It is empty when the input holds no more than `shift` bytes.
    """
    return bytes_object[shift : shift + limit].hex()
def clean_temporary_files(n_keys=20):
    """Clean older keys and encrypted files.

    A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
    limit is reached, the oldest key directories (by modification time) are deleted, together
    with the client and server directories of the same client IDs.

    Args:
        n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
    """
    # Sort the key directories from the oldest to the most recently modified
    key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)

    n_keys_to_delete = len(key_dirs) - n_keys
    if n_keys_to_delete <= 0:
        return

    # Remove the oldest keys and remember the client IDs they belonged to (key directories are
    # named after the client ID, see `_get_client`)
    deleted_client_ids = set()
    for key_dir in key_dirs[:n_keys_to_delete]:
        deleted_client_ids.add(key_dir.name)
        shutil.rmtree(key_dir)

    # Delete all directories related to the IDs whose keys were deleted. Client directories are
    # named after the bare client ID (see `_get_client_file_path`), so names are matched exactly:
    # a substring test would also hit other clients whose ID merely contains a deleted one
    # (e.g. "123" in "41234"), and could try to remove the same directory twice.
    # NOTE(review): assumes server-side directories use the same naming; confirm on the server.
    for directory in chain(CLIENT_FILES.iterdir(), SERVER_FILES.iterdir()):
        if directory.name in deleted_client_ids:
            shutil.rmtree(directory)
def _get_client(client_id):
    """Build the FHE client associated with a given client ID.

    Args:
        client_id (int): The client ID to consider.

    Returns:
        MultiInputsFHEModelClient: A client built from DEPLOYMENT_PATH, whose key directory is
            FHE_KEYS/<client_id> and which expects one input per entry of CLIENT_TYPES.
    """
    client_key_dir = FHE_KEYS / f"{client_id}"
    input_count = len(CLIENT_TYPES)
    return MultiInputsFHEModelClient(
        DEPLOYMENT_PATH,
        key_dir=client_key_dir,
        nb_inputs=input_count,
    )
def _get_client_file_path(name, client_id, client_type=None):
    """Return the path of a client-side file, creating the client's directory if needed.

    Files live in CLIENT_FILES/<client_id>/ and are named `name`, followed by `_<client_type>`
    when a client type is given.

    Args:
        name (str): The desired file name (either 'evaluation_key', 'encrypted_inputs' or
            'encrypted_output').
        client_id (int): The client ID to consider.
        client_type (Optional[str]): The type of user to consider (either 'user', 'bank',
            'third_party' or None). Default to None, which is used for the evaluation key and
            the output.

    Returns:
        pathlib.Path: The file path.
    """
    file_name = name if client_type is None else f"{name}_{client_type}"

    client_dir = CLIENT_FILES / f"{client_id}"
    client_dir.mkdir(exist_ok=True)

    return client_dir / file_name
def _send_to_server(client_id, client_type, file_name):
    """Send the encrypted inputs or the evaluation key to the server.

    Args:
        client_id (int): The client ID to consider.
        client_type (Optional[str]): The type of client to consider (either 'user', 'bank',
            'third_party' or None).
        file_name (str): File name to send (either 'evaluation_key' or 'encrypted_inputs').

    Returns:
        bool: Whether the server answered with a successful status (`response.ok`).
    """
    # Get the path to the file to send
    encrypted_file_path = _get_client_file_path(file_name, client_id, client_type)

    # Define the form data to post alongside the file
    data = {
        "client_id": client_id,
        "client_type": client_type,
        "file_name": file_name,
    }

    # Send the encrypted inputs or evaluation key to the server
    url = SERVER_URL + "send_file"

    # Open the file in a context manager so its handle is closed once the upload is done,
    # instead of leaking one file descriptor per call
    with encrypted_file_path.open("rb") as encrypted_file:
        files = [("files", encrypted_file)]
        with requests.post(url=url, data=data, files=files) as response:
            return response.ok
def keygen_send():
    """Generate the private and evaluation keys, and send the evaluation key to the server.

    Returns:
        Tuple[int, str, dict]: The new client ID, a short hexadecimal representation of the
            evaluation key for display, and a Gradio update carrying a confirmation message.

    Raises:
        gr.Error: If the server did not accept the evaluation key.
    """
    # Clean temporary files
    clean_temporary_files()

    # Create an ID for the current client to consider. The dtype is explicit because the default
    # integer type is 32 bits on some platforms (e.g. Windows with NumPy < 2), where 2**32 would
    # be out of bounds; converting back to a Python int keeps the ID a plain int.
    client_id = int(numpy.random.randint(0, 2**32, dtype=numpy.int64))

    # Retrieve the client instance
    client = _get_client(client_id)

    # Generate the private and evaluation keys
    client.generate_private_and_evaluation_keys(force=True)

    # Retrieve the serialized evaluation key
    evaluation_key = client.get_serialized_evaluation_keys()
    file_name = "evaluation_key"

    # Save evaluation key as bytes in a file as it is too large to pass through regular Gradio
    # buttons (see https://github.com/gradio-app/gradio/issues/1877)
    evaluation_key_path = _get_client_file_path(file_name, client_id)
    with evaluation_key_path.open("wb") as evaluation_key_file:
        evaluation_key_file.write(evaluation_key)

    # Send the evaluation key to the server, and fail loudly rather than reporting a success that
    # did not happen
    if not _send_to_server(client_id, None, file_name):
        raise gr.Error("The evaluation key could not be sent to the server. Please try again.")

    # Create a truncated version of the evaluation key for display
    evaluation_key_short = shorten_bytes_object(evaluation_key)

    return (
        client_id,
        evaluation_key_short,
        gr.update(value="Keys are generated and evaluation key is sent ✅"),
    )
def _encrypt_send(client_id, inputs, client_type):
    """Encrypt the given inputs for a specific client and send them to the server.

    Args:
        client_id (str): The current client ID to consider.
        inputs (numpy.ndarray): The inputs to encrypt.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').

    Returns:
        str: A short hexadecimal representation of the encrypted inputs, for display.

    Raises:
        gr.Error: If the keys were not generated yet, or if the server did not accept the
            encrypted inputs.
    """
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    # Retrieve the client instance
    client = _get_client(client_id)

    # Quantize, encrypt and serialize the inputs; the input index and slice used depend on the
    # client type
    encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
        inputs,
        input_index=INPUT_INDEXES[client_type],
        initial_input_shape=INITIAL_INPUT_SHAPE,
        input_slice=INPUT_SLICES[client_type],
    )

    file_name = "encrypted_inputs"

    # Save encrypted_inputs to bytes in a file, since too large to pass through regular Gradio
    # buttons, https://github.com/gradio-app/gradio/issues/1877
    encrypted_inputs_path = _get_client_file_path(file_name, client_id, client_type)
    with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
        encrypted_inputs_file.write(encrypted_inputs)

    # Create a truncated version of the encrypted inputs for display
    encrypted_inputs_short = shorten_bytes_object(encrypted_inputs)

    # Fail loudly instead of displaying inputs that never reached the server
    if not _send_to_server(client_id, client_type, file_name):
        raise gr.Error("The encrypted inputs could not be sent to the server. Please try again.")

    return encrypted_inputs_short
def pre_process_encrypt_send_user(client_id, *inputs):
    """Pre-process the user's inputs, then encrypt them and send them to the server.

    Args:
        client_id (str): The current client ID to consider.
        *inputs: The user's values, in this order: the selected boolean options (a collection
            that may contain "Car", "Property", "Work phone", "Phone" and "Email"), the number of
            children, the household size, the total income, the age, the income type, the
            education type, the family status, the occupation type and the housing type.

    Returns:
        str: A short hexadecimal representation of the encrypted inputs, for display.
    """
    (
        bool_inputs,
        num_children,
        household_size,
        total_income,
        age,
        income_type,
        education_type,
        family_status,
        occupation_type,
        housing_type,
    ) = inputs

    # Each boolean column is True when its label was selected by the user
    boolean_labels = {
        "Own_car": "Car",
        "Own_property": "Property",
        "Work_phone": "Work phone",
        "Phone": "Phone",
        "Email": "Email",
    }
    row = {column: [label in bool_inputs] for column, label in boolean_labels.items()}
    row.update(
        {
            "Num_children": [num_children],
            "Household_size": [household_size],
            "Total_income": [total_income],
            "Age": [age],
            "Income_type": [income_type],
            "Education_type": [education_type],
            "Family_status": [family_status],
            "Occupation_type": [occupation_type],
            "Housing_type": [housing_type],
        }
    )

    # Reorder the columns to match USER_COLUMNS (any column missing from the row becomes NaN)
    user_frame = pandas.DataFrame(row).reindex(USER_COLUMNS, axis=1)
    transformed = PRE_PROCESSOR_USER.transform(user_frame)

    return _encrypt_send(client_id, transformed, "user")
def pre_process_encrypt_send_bank(client_id, *inputs):
    """Encrypt the bank's input and send it to the server.

    The bank provides a single value, the account length, which is encrypted as-is without any
    pre-processing.

    Args:
        client_id (str): The current client ID to consider.
        *inputs: The bank's values; only the first one (the account length) is used.

    Returns:
        str: A short hexadecimal representation of the encrypted input, for display.
    """
    return _encrypt_send(client_id, inputs[0], "bank")
def pre_process_encrypt_send_third_party(client_id, *inputs):
    """Pre-process the third party's inputs, then encrypt them and send them to the server.

    Args:
        client_id (str): The current client ID to consider.
        *inputs: Exactly two values: whether the applicant is salaried ("Yes" means True, any
            other value means False) and the number of years employed.

    Returns:
        str: A short hexadecimal representation of the encrypted inputs, for display.
    """
    salaried, years_salaried = inputs

    # Reorder the columns to match THIRD_PARTY_COLUMNS (any missing column becomes NaN)
    third_party_frame = pandas.DataFrame(
        {
            "Salaried": [salaried == "Yes"],
            "Years_employed": [years_salaried],
        }
    ).reindex(THIRD_PARTY_COLUMNS, axis=1)

    return _encrypt_send(
        client_id,
        PRE_PROCESSOR_THIRD_PARTY.transform(third_party_frame),
        "third_party",
    )
def run_fhe(client_id):
    """Ask the server to run the model with FHE on the encrypted inputs previously sent.

    Args:
        client_id (str): The current client ID to consider.

    Returns:
        The JSON-decoded body of the server's response.

    Raises:
        gr.Error: If the keys were not generated yet, or if the server answered with an error
            status.
    """
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    # Trigger the FHE execution on the encrypted inputs previously sent
    endpoint = SERVER_URL + "run_fhe"
    with requests.post(url=endpoint, data={"client_id": client_id}) as response:
        if not response.ok:
            raise gr.Error("Please send the inputs from all three parties to the server first.")
        return response.json()
def get_output(client_id):
    """Retrieve the encrypted output from the server and save it in the client's directory.

    Args:
        client_id (str): The current client ID to consider.

    Returns:
        str: A short hexadecimal representation of the encrypted output, for display.

    Raises:
        gr.Error: If the keys were not generated yet, or if the server answered with an error
            status.
    """
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    # Retrieve the encrypted output
    endpoint = SERVER_URL + "get_output"
    with requests.post(url=endpoint, data={"client_id": client_id}) as response:
        if not response.ok:
            raise gr.Error("Please run the FHE execution first and wait for it to be completed.")
        encrypted_output = response.content

    # Save the encrypted output to a file as it is too large to pass through regular Gradio
    # buttons (see https://github.com/gradio-app/gradio/issues/1877)
    output_path = _get_client_file_path("encrypted_output", client_id)
    output_path.write_bytes(encrypted_output)

    # Return a truncated version of the encrypted output for display
    return shorten_bytes_object(encrypted_output)
def decrypt_output(client_id):
    """Decrypt the output previously retrieved from the server and turn it into a decision.

    Args:
        client_id (str): The current client ID to consider.

    Returns:
        str: A message stating whether the credit card has been approved or rejected.

    Raises:
        gr.Error: If the keys were not generated yet, or if the encrypted output was not
            retrieved from the server first.
    """
    if client_id == "":
        raise gr.Error("Please generate the keys first.")

    # Get the encrypted output path
    encrypted_output_path = _get_client_file_path("encrypted_output", client_id)

    if not encrypted_output_path.is_file():
        raise gr.Error("Please receive the outputs from the server first.")

    # Load the encrypted output as bytes
    with encrypted_output_path.open("rb") as encrypted_output_file:
        encrypted_output_proba = encrypted_output_file.read()

    # Retrieve the client API
    client = _get_client(client_id)

    # Deserialize, decrypt and post-process the encrypted output
    output_proba = client.deserialize_decrypt_dequantize(encrypted_output_proba)

    # Determine the predicted class (index of the highest value along axis 1)
    output = numpy.argmax(output_proba, axis=1).squeeze()

    # A "0" output means approving the credit card has low risk, while "1" is high risk
    if output == 0:
        return "Credit card has been approved ✅"
    return "Credit card has been rejected ❌"