# encrypted_image_filtering / custom_client_server.py
# Author: Roman
# Last commit: 3cf0931 (unverified) - "chore: Add comments, clean unused objects and improve ridge detection"
# File size: 6.83 kB
"Client-server interface implementation for custom integer models."
from pathlib import Path
from typing import Any
import concrete.numpy as cnp
import numpy as np
from filters import Filter
from concrete.ml.common.debugging.custom_assert import assert_true
class CustomFHEDev:
    """Dev API to save the custom integer model, load and run a FHE circuit."""

    # The model to export; expected to expose `fhe_circuit` and `to_json` when saving.
    model: Any = None

    def __init__(self, path_dir: str, model: Any = None):
        """Initialize the development interface.

        The target directory (and any missing parents) is created immediately.

        Args:
            path_dir (str): The path to the directory where the circuit is saved.
            model (Any): The model to use for the development interface.
        """
        self.path_dir = Path(path_dir)
        self.model = model

        # Create the directory path if it does not exist yet
        self.path_dir.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server.

        Writes the model's serialized processing parameters (through
        `model.to_json(..., file_name="serialized_processing")`), the server
        circuit to `server.zip` and the client specs to `client.zip`, all inside
        `path_dir`.

        Raises:
            Exception: path_dir is not empty.
        """
        # Refuse to overwrite previous artifacts: any entry found anywhere under
        # path_dir counts. `any` stops at the first match instead of listing everything.
        if any(self.path_dir.glob("**/*")):
            raise Exception(
                f"path_dir: {self.path_dir} is not empty. "
                "Please delete it before saving a new model."
            )

        assert_true(
            hasattr(self.model, "fhe_circuit"),
            "The model must be compiled and have a fhe_circuit object",
        )

        # Model must be compiled with jit=False
        # In a jit model, everything is in memory so it is not serializable.
        assert_true(
            not self.model.fhe_circuit.configuration.jit,
            "The model must be compiled with the configuration option jit=False.",
        )

        # Export the parameters
        self.model.to_json(path_dir=self.path_dir, file_name="serialized_processing")

        # Save the circuit for the server
        path_circuit_server = self.path_dir / "server.zip"
        self.model.fhe_circuit.server.save(path_circuit_server)

        # Save the circuit for the client
        path_circuit_client = self.path_dir / "client.zip"
        self.model.fhe_circuit.client.save(path_circuit_client)
class CustomFHEClient:
    """Client API to encrypt and decrypt FHE data."""

    client: cnp.Client

    def __init__(self, path_dir: str, key_dir: str = None):
        """Initialize the client interface.

        Args:
            path_dir (str): The path to the directory where the circuit is saved.
            key_dir (str): The path to the directory where the keys are stored.
                Optional; when None, None is forwarded to `cnp.Client.load`.
        """
        self.path_dir = Path(path_dir)

        # key_dir defaults to None, and Path(None) raises a TypeError, so only
        # wrap it in a Path when a value was actually given.
        self.key_dir = Path(key_dir) if key_dir is not None else None

        # If path_dir does not exist, raise an error
        assert_true(
            self.path_dir.exists(), f"{path_dir} does not exist. Please specify a valid path."
        )

        # Load
        self.load()

    def load(self):  # pylint: disable=no-value-for-parameter
        """Load the parameters along with the FHE specs."""
        # Load the client
        # NOTE(review): relies on cnp.Client.load accepting None for its key
        # directory argument when no key_dir was given - confirm against the
        # installed concrete-numpy version.
        self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir)

        # Load the model
        self.model = Filter.from_json(self.path_dir / "serialized_processing.json")

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force)

    def get_serialized_evaluation_keys(self) -> cnp.EvaluationKeys:
        """Get the serialized evaluation keys.

        Returns:
            cnp.EvaluationKeys: The evaluation keys, in the serialized form
                returned by `EvaluationKeys.serialize()`.
        """
        return self.client.evaluation_keys.serialize()

    def pre_process_encrypt_serialize(self, x: np.ndarray) -> cnp.PublicArguments:
        """Pre-process, encrypt and serialize the values.

        Args:
            x (numpy.ndarray): The values to encrypt and serialize.

        Returns:
            cnp.PublicArguments: The encrypted values, in the serialized form
                returned by `client.specs.serialize_public_args`.
        """
        # Pre-process the values
        x = self.model.pre_processing(x)

        # Encrypt the values
        enc_x = self.client.encrypt(x)

        # Serialize the encrypted values to be sent to the server
        serialized_enc_x = self.client.specs.serialize_public_args(enc_x)
        return serialized_enc_x

    def deserialize_decrypt_post_process(
        self, serialized_encrypted_output: cnp.PublicArguments
    ) -> np.ndarray:
        """Deserialize, decrypt and post-process the values.

        Args:
            serialized_encrypted_output (cnp.PublicArguments): The serialized and encrypted output.

        Returns:
            numpy.ndarray: The decrypted values after the model's post-processing.
        """
        # Deserialize the encrypted values
        deserialized_encrypted_output = self.client.specs.unserialize_public_result(
            serialized_encrypted_output
        )

        # Decrypt the values
        deserialized_decrypted_output = self.client.decrypt(deserialized_encrypted_output)

        # Apply the model post processing
        deserialized_decrypted_output = self.model.post_processing(deserialized_decrypted_output)
        return deserialized_decrypted_output
class CustomFHEServer:
    """Server-side wrapper that loads a saved FHE circuit and evaluates it on encrypted data."""

    server: cnp.Server

    def __init__(self, path_dir: str):
        """Set up the server interface and load the circuit right away.

        Args:
            path_dir (str): Directory holding the saved `server.zip` circuit.
        """
        self.path_dir = Path(path_dir)
        self.load()

    def load(self):
        """Load the circuit from `<path_dir>/server.zip`."""
        circuit_archive = self.path_dir / "server.zip"
        self.server = cnp.Server.load(circuit_archive)

    def run(
        self,
        serialized_encrypted_data: cnp.PublicArguments,
        serialized_evaluation_keys: cnp.EvaluationKeys,
    ) -> cnp.PublicResult:
        """Evaluate the loaded circuit over encrypted inputs.

        The inputs and evaluation keys are deserialized, the circuit is run,
        and the encrypted result is serialized back for the client.

        Args:
            serialized_encrypted_data (cnp.PublicArguments): The encrypted and serialized data.
            serialized_evaluation_keys (cnp.EvaluationKeys): The serialized evaluation keys.

        Returns:
            cnp.PublicResult: The serialized encrypted result of the circuit.
        """
        assert_true(self.server is not None, "Model has not been loaded.")

        specs = self.server.client_specs
        encrypted_args = specs.unserialize_public_args(serialized_encrypted_data)
        evaluation_keys = cnp.EvaluationKeys.unserialize(serialized_evaluation_keys)

        return specs.serialize_public_result(self.server.run(encrypted_args, evaluation_keys))