Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
"""Modified model class to handles multi-inputs circuit.""" | |
import numpy | |
import time | |
from typing import Optional, Sequence, Union | |
from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit | |
from concrete.ml.common.check_inputs import check_array_and_assert | |
from concrete.ml.common.utils import ( | |
generate_proxy_function, | |
manage_parameters_for_pbs_errors, | |
check_there_is_no_p_error_options_in_configuration | |
) | |
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator | |
from concrete.ml.sklearn import DecisionTreeClassifier | |
class MultiInputModel:
    """Mixin letting a concrete-ml tree estimator quantize and compile several inputs.

    The feature matrix is split column-wise across several 2D arrays: the arrays' columns,
    taken in order, must line up one-to-one with ``self.input_quantizers``. At compile time
    the arrays are concatenated back along axis 1 before the tree inference runs, so the
    resulting circuit takes one argument per input array, each with its own encryption status.

    The attributes and helpers used below (``input_quantizers``, ``_tree_inference``,
    ``_is_compiled``, ``fhe_circuit_`` / ``fhe_circuit``, ``check_model_is_fitted``,
    ``_is_not_fitted_error_message``) are expected to be provided by the estimator this class
    is mixed into (e.g. ``DecisionTreeClassifier``).
    """

    # Attribute set on the multi-input wrapper of `_tree_inference` so it is applied only once.
    _MULTI_INPUT_WRAPPER_FLAG = "_is_multi_input_wrapper"

    def quantize_input(
        self, *X: numpy.ndarray
    ) -> Union[numpy.ndarray, Sequence[numpy.ndarray]]:
        """Quantize each input array with the quantizers of its feature columns.

        Args:
            *X: 2D arrays whose columns, concatenated in order, correspond to the model's
                input quantizers.

        Returns:
            The int64 quantized arrays as a tuple, or the single array itself when only
            one input is given.
        """
        self.check_model_is_fitted()

        n_features = sum(x.shape[1] for x in X)
        assert n_features == len(self.input_quantizers), (
            f"Expected {len(self.input_quantizers)} feature columns across all inputs, "
            f"got {n_features}."
        )

        q_inputs = []
        # Index of the first global feature column covered by the current input array.
        first_column = 0
        for i, x in enumerate(X):
            q_x = numpy.zeros_like(x, dtype=numpy.int64)
            for j in range(x.shape[1]):
                q_x[:, j] = self.input_quantizers[first_column + j].quant(x[:, j])
            assert q_x.dtype == numpy.int64, f"Inputs {i} were not quantized to int64 values"
            q_inputs.append(q_x)
            first_column += x.shape[1]

        return tuple(q_inputs) if len(q_inputs) > 1 else q_inputs[0]

    def compile(
        self,
        *inputs,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        show_mlir: bool = False,
        p_error: Optional[float] = None,
        global_p_error: Optional[float] = None,
        verbose: bool = False,
        inputs_encryption_status: Optional[Sequence[str]] = None,
    ) -> Circuit:
        """Compile the model into an FHE circuit with one argument per input array.

        Args:
            *inputs: representative input-set, split column-wise into several arrays. Each is
                cast to numpy with ``check_array_and_assert`` (pandas, list or torch inputs
                are accepted per the original comment).
            configuration: forwarded to ``Compiler.compile``; must not itself set p_error or
                global_p_error (checked by ``check_there_is_no_p_error_options_in_configuration``).
            artifacts: forwarded to ``Compiler.compile``.
            show_mlir: forwarded to ``Compiler.compile``.
            p_error: resolved with ``global_p_error`` by ``manage_parameters_for_pbs_errors``
                and forwarded to ``Compiler.compile``.
            global_p_error: see ``p_error``.
            verbose: forwarded to ``Compiler.compile``.
            inputs_encryption_status: one Concrete encryption status per input (e.g.
                "encrypted" or "clear"). Defaults to encrypting every input.

        Returns:
            Circuit: the compiled circuit, also stored in ``self.fhe_circuit_``.

        Raises:
            ValueError: if the number of encryption statuses differs from the number of inputs.
        """
        # Check that the model is correctly fitted
        self.check_model_is_fitted()

        # Cast pandas, list or torch to numpy
        inputs_as_array = tuple(check_array_and_assert(x) for x in inputs)

        # A None status used to crash later on `len(None)`; encrypt every input by default.
        if inputs_encryption_status is None:
            inputs_encryption_status = ("encrypted",) * len(inputs_as_array)
        elif len(inputs_encryption_status) != len(inputs_as_array):
            raise ValueError(
                f"Got {len(inputs_encryption_status)} encryption statuses for "
                f"{len(inputs_as_array)} inputs; expected one status per input."
            )

        # p_error or global_p_error should not be set in both the configuration and direct arguments
        check_there_is_no_p_error_options_in_configuration(configuration)

        # Find the right way to set parameters for compiler, depending on the way we want to default
        p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

        # Quantize the inputs and build the compilation input-set with proper dimensions
        quantized_inputs = self.quantize_input(*inputs_as_array)
        inputset = _get_inputset_generator(quantized_inputs)

        # Reset for double compile
        self._is_compiled = False

        module_to_compile = self._get_module_to_compile(inputs_encryption_status)

        # Compiling using a QuantizedModule requires different steps and should not be done here
        assert isinstance(module_to_compile, Compiler), (
            "Wrong module to compile. Expected to be of type `Compiler` but got "
            f"{type(module_to_compile)}."
        )

        # Jit compiler is now deprecated and will soon be removed, it is thus forced to False
        # by default
        self.fhe_circuit_ = module_to_compile.compile(
            inputset,
            configuration=configuration,
            artifacts=artifacts,
            show_mlir=show_mlir,
            p_error=p_error,
            global_p_error=global_p_error,
            verbose=verbose,
            single_precision=False,
            fhe_simulation=False,
            fhe_execution=True,
            jit=False,
        )

        self._is_compiled = True

        # For mypy
        assert isinstance(self.fhe_circuit, Circuit)
        return self.fhe_circuit

    def _get_module_to_compile(
        self, inputs_encryption_status: Optional[Sequence[str]] = None
    ) -> Union[Compiler, QuantizedModule]:
        """Build the Concrete compiler for the multi-input tree inference.

        ``self._tree_inference`` is replaced (once) by a wrapper that concatenates its
        positional arguments along axis 1 and forwards the result to the original
        inference. Later calls reuse that wrapper unless ``_tree_inference`` has since been
        replaced by something else.

        Args:
            inputs_encryption_status: one encryption status per circuit input. Defaults to a
                single encrypted input.

        Returns:
            Union[Compiler, QuantizedModule]: a ``Compiler`` for the proxy of the wrapped
                inference (this override never returns a ``QuantizedModule``).
        """
        assert self._tree_inference is not None, self._is_not_fitted_error_message()

        if inputs_encryption_status is None:
            inputs_encryption_status = ("encrypted",)

        # `compile` resets `_is_compiled` before calling this method, so guarding on it
        # re-wrapped the already-wrapped inference on every compile. Mark the wrapper instead.
        if not getattr(self._tree_inference, self._MULTI_INPUT_WRAPPER_FLAG, False):
            single_input_inference = self._tree_inference

            def multi_input_inference(*args):
                return single_input_inference(numpy.concatenate(args, axis=1))

            setattr(multi_input_inference, self._MULTI_INPUT_WRAPPER_FLAG, True)
            self._tree_inference = multi_input_inference

        input_names = [f"input_{i}_encrypted" for i in range(len(inputs_encryption_status))]

        # Generate the proxy function to compile
        _tree_inference_proxy, function_arg_names = generate_proxy_function(
            self._tree_inference, input_names
        )

        inputs_encryption_statuses = {
            input_name: status
            for input_name, status in zip(function_arg_names.values(), inputs_encryption_status)
        }

        # Create the compiler instance
        compiler = Compiler(
            _tree_inference_proxy,
            inputs_encryption_statuses,
        )
        return compiler
class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier):
    """Concrete-ml ``DecisionTreeClassifier`` that compiles to a multi-input FHE circuit.

    ``MultiInputModel`` is listed first among the bases so that, in the MRO, its
    ``quantize_input``, ``compile`` and ``_get_module_to_compile`` override any same-named
    methods inherited from ``DecisionTreeClassifier``.
    """