"""Modified classes for the Client-Server interface with multi-input circuits.

These subclasses of the Concrete ML deployment classes let a client encrypt a
single input of a circuit that takes several encrypted inputs. They also let
the server run such a circuit on several separately serialized encrypted values.
"""

import copy

import numpy

from concrete.fhe import EvaluationKeys, Value
from concrete.ml.deployment.fhe_client_server import (
    FHEModelClient,
    FHEModelDev,
    FHEModelServer,
)
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier


class MultiInputsFHEModelDev(FHEModelDev):
    """Development-side helper whose stored model is re-typed as a Concrete ML XGBClassifier."""

    def __init__(self, *arg, **kwargs):
        """Initialize the parent, then replace ``self.model`` with a re-typed shallow copy.

        Args:
            *arg: positional arguments forwarded to ``FHEModelDev.__init__``.
            **kwargs: keyword arguments forwarded to ``FHEModelDev.__init__``.
        """
        super().__init__(*arg, **kwargs)

        # Work on a shallow copy so that the caller's model object keeps its own class.
        # NOTE(review): presumably the class is forced to Concrete ML's XGBClassifier so
        # that the parent's deployment logic treats a custom model subclass as that
        # built-in model -- confirm against FHEModelDev.
        model = copy.copy(self.model)
        model.__class__ = ConcreteXGBClassifier
        self.model = model


class MultiInputsFHEModelClient(FHEModelClient):
    """Client-side helper that encrypts one input of a multi-input circuit at a time."""

    def __init__(self, *args, nb_inputs=1, **kwargs):
        """Store the number of circuit inputs, then initialize the parent.

        Args:
            *args: positional arguments forwarded to ``FHEModelClient.__init__``.
            nb_inputs (int): number of inputs the compiled circuit takes. Defaults to 1.
            **kwargs: keyword arguments forwarded to ``FHEModelClient.__init__``.
        """
        self.nb_inputs = nb_inputs
        super().__init__(*args, **kwargs)

    def quantize_encrypt_serialize_multi_inputs(
        self, x: numpy.ndarray, input_index, processed_input_shape, input_slice
    ) -> bytes:
        """Quantize, encrypt and serialize the values of a single circuit input.

        Args:
            x (numpy.ndarray): clear values for this input. They are written into the
                columns ``input_slice`` of a 2D array, so presumably ``x`` has shape
                (n_samples, number of columns in ``input_slice``). TODO confirm.
            input_index (int): position of this input among the circuit's
                ``nb_inputs`` inputs.
            processed_input_shape (tuple): shape of the full model input, used to
                build the zero-padded array handed to the quantizer.
            input_slice (slice): columns of the full model input that ``x`` occupies.

        Returns:
            bytes: the serialized encrypted value of input ``input_index``.
        """
        # The model quantizer is applied to the full input layout. Embed x into a
        # zero-filled array of that shape, then keep only x's columns afterwards.
        x_padded = numpy.zeros(processed_input_shape)
        x_padded[:, input_slice] = x
        q_x_padded = self.model.quantize_input(x_padded)
        q_x = q_x_padded[:, input_slice]

        # Pass None for every other circuit input so that only this input is encrypted.
        q_x_inputs = [None] * self.nb_inputs
        q_x_inputs[input_index] = q_x

        # Encrypt the values
        q_x_enc = self.client.encrypt(*q_x_inputs)

        # concrete's Client.encrypt returns a bare Value (not a 1-tuple) when given a
        # single argument. Normalize it so that indexing also works for nb_inputs == 1.
        if not isinstance(q_x_enc, tuple):
            q_x_enc = (q_x_enc,)

        # Serialize the encrypted value to be sent to the server
        return q_x_enc[input_index].serialize()


class MultiInputsFHEModelServer(FHEModelServer):
    """Server-side helper that runs a circuit on several separately encrypted inputs."""

    def run(
        self,
        *serialized_encrypted_quantized_data: bytes,
        serialized_evaluation_keys: bytes,
    ) -> bytes:
        """Run the model on the server over encrypted data.

        Args:
            *serialized_encrypted_quantized_data (bytes): one encrypted, quantized and
                serialized value per circuit input. They are passed positionally to
                the circuit, so they must be given in the circuit's input order.
            serialized_evaluation_keys (bytes): the serialized evaluation keys.

        Returns:
            bytes: the serialized encrypted result of the model.

        Raises:
            AssertionError: if the server circuit has not been loaded.
        """
        # NOTE(review): assert checks are skipped under `python -O`.
        assert self.server is not None, "Model has not been loaded."

        deserialized_encrypted_quantized_data = tuple(
            Value.deserialize(data) for data in serialized_encrypted_quantized_data
        )
        deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)

        result = self.server.run(
            *deserialized_encrypted_quantized_data,
            evaluation_keys=deserialized_evaluation_keys,
        )

        # NOTE(review): this assumes server.run returned a single value, i.e. a
        # single-output circuit. A multi-output circuit would return a tuple.
        return result.serialize()