"""Modified model class to handles multi-inputs circuit.""" import numpy import time from typing import Optional, Sequence, Union from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit from concrete.ml.common.check_inputs import check_array_and_assert from concrete.ml.common.utils import ( generate_proxy_function, manage_parameters_for_pbs_errors, check_there_is_no_p_error_options_in_configuration ) from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier class MultiInputXGBClassifier(ConcreteXGBClassifier): def quantize_input(self, *X: numpy.ndarray) -> numpy.ndarray: self.check_model_is_fitted() assert sum(input.shape[1] for input in X) == len(self.input_quantizers) base_j = 0 q_inputs = [] for i, input in enumerate(X): q_input = numpy.zeros_like(input, dtype=numpy.int64) for j in range(input.shape[1]): quantizer_index = base_j + j q_input[:, j] = self.input_quantizers[quantizer_index].quant(input[:, j]) assert q_input.dtype == numpy.int64, f"Inputs {i} were not quantized to int64 values" q_inputs.append(q_input) base_j += input.shape[1] return tuple(q_inputs) if len(q_inputs) > 1 else q_inputs[0] def compile( self, *inputs, configuration: Optional[Configuration] = None, artifacts: Optional[DebugArtifacts] = None, show_mlir: bool = False, p_error: Optional[float] = None, global_p_error: Optional[float] = None, verbose: bool = False, inputs_encryption_status: Optional[Sequence[str]] = None, ) -> Circuit: # Check that the model is correctly fitted self.check_model_is_fitted() # Cast pandas, list or torch to numpy inputs_as_array = [] for input in inputs: input_as_array = check_array_and_assert(input) inputs_as_array.append(input_as_array) inputs_as_array = tuple(inputs_as_array) # p_error or global_p_error should not be set in both the configuration and direct arguments check_there_is_no_p_error_options_in_configuration(configuration) # Find the right way to set parameters for compiler, depending on the way we want to default p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error) # Quantize the inputs quantized_inputs = self.quantize_input(*inputs_as_array) # Generate the compilation input-set with proper dimensions inputset = _get_inputset_generator(quantized_inputs) # Reset for double compile self._is_compiled = False # Retrieve the compiler instance module_to_compile = self._get_module_to_compile(inputs_encryption_status) # Compiling using a QuantizedModule requires different steps and should not be done here assert isinstance(module_to_compile, Compiler), ( "Wrong module to compile. Expected to be of type `Compiler` but got " f"{type(module_to_compile)}." ) # Jit compiler is now deprecated and will soon be removed, it is thus forced to False # by default self.fhe_circuit_ = module_to_compile.compile( inputset, configuration=configuration, artifacts=artifacts, show_mlir=show_mlir, p_error=p_error, global_p_error=global_p_error, verbose=verbose, single_precision=False, fhe_simulation=False, fhe_execution=True, jit=False, ) self._is_compiled = True # For mypy assert isinstance(self.fhe_circuit, Circuit) return self.fhe_circuit def _get_module_to_compile(self, inputs_encryption_status) -> Union[Compiler, QuantizedModule]: assert self._tree_inference is not None, self._is_not_fitted_error_message() if not self._is_compiled: xgb_inference = self._tree_inference self._tree_inference = lambda *args: xgb_inference(numpy.concatenate(args, axis=1)) input_names = [f"input_{i}_encrypted" for i in range(len(inputs_encryption_status))] # Generate the proxy function to compile _tree_inference_proxy, function_arg_names = generate_proxy_function( self._tree_inference, input_names ) inputs_encryption_statuses = {input_name: status for input_name, status in zip(function_arg_names.values(), inputs_encryption_status)} # Create the compiler instance compiler = Compiler( _tree_inference_proxy, inputs_encryption_statuses, ) return compiler def predict_multi_inputs(self, *multi_inputs, simulate=True): """Run the inference with multiple inputs, with simulation or in FHE.""" assert all(isinstance(inputs, numpy.ndarray) for inputs in multi_inputs) if not simulate: self.fhe_circuit.keygen() y_preds = [] execution_times = [] for inputs in zip(*multi_inputs): inputs = tuple(numpy.expand_dims(input, axis=0) for input in inputs) q_inputs = self.quantize_input(*inputs) if simulate: q_y_proba = self.fhe_circuit.simulate(*q_inputs) else: q_inputs_enc = self.fhe_circuit.encrypt(*q_inputs) start = time.time() q_y_proba_enc = self.fhe_circuit.run(*q_inputs_enc) end = time.time() - start execution_times.append(end) q_y_proba = self.fhe_circuit.decrypt(q_y_proba_enc) y_proba = self.dequantize_output(q_y_proba) y_proba = self.post_processing(y_proba) y_pred = numpy.argmax(y_proba, axis=1) y_preds.append(y_pred) if not simulate: print(f"FHE execution time per inference: {numpy.mean(execution_times) :.2}s") return numpy.array(y_preds)