romanbredehoft-zama's picture
Remove unused imports
615cfe4
raw
history blame
4.95 kB
"""Modified model classes that handle multi-input FHE circuits."""
import numpy
import time
from typing import Optional, Sequence, Union
from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit
from concrete.ml.common.check_inputs import check_array_and_assert
from concrete.ml.common.utils import (
generate_proxy_function,
manage_parameters_for_pbs_errors,
check_there_is_no_p_error_options_in_configuration
)
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
from concrete.ml.sklearn import DecisionTreeClassifier
class MultiInputModel:
    """Mixin that lets a Concrete ML tree-based model compile a multi-input FHE circuit.

    This mixin is meant to be combined with a Concrete ML estimator (see the classes inheriting
    from it). It relies on members provided by that estimator, such as
    ``check_model_is_fitted``, ``input_quantizers``, ``_tree_inference``, ``_is_compiled``,
    ``_is_not_fitted_error_message`` and ``fhe_circuit``.

    Each input array holds a contiguous group of the model's feature columns. Taken in order,
    the columns of all inputs line up with ``input_quantizers``. Before the tree inference
    runs, the arrays are concatenated along axis 1. Each array is still a separate argument of
    the compiled circuit, so each one gets its own encryption status.
    """

    # Attribute set on the wrapped inference function to mark it as already accepting
    # several inputs, so that compiling more than once never stacks wrappers.
    _MULTI_INPUT_WRAPPER_TAG = "_concatenates_multiple_inputs"

    def quantize_input(
        self, *X: numpy.ndarray
    ) -> Union[numpy.ndarray, Sequence[numpy.ndarray]]:
        """Quantize each input array using the quantizers of its columns.

        The columns of all inputs, taken in order, are matched one-to-one with
        ``self.input_quantizers``. Column ``j`` of the ``i``-th input uses the quantizer at
        index ``(total number of columns in inputs 0..i-1) + j``.

        Args:
            *X (numpy.ndarray): One or more arrays indexed as ``[rows, columns]``. Their column
                counts must add up to the number of input quantizers.

        Returns:
            Union[numpy.ndarray, Sequence[numpy.ndarray]]: The int64 quantized array when a
            single input is given. Otherwise, a tuple with one quantized array per input.

        Raises:
            ValueError: If no input array is given.
            AssertionError: If the total number of columns does not match the number of
                input quantizers.
        """
        self.check_model_is_fitted()

        if not X:
            raise ValueError("At least one input array must be provided.")

        assert sum(x.shape[1] for x in X) == len(self.input_quantizers)

        q_inputs = []
        # Index of the first quantizer belonging to the current input
        quantizer_offset = 0
        for i, x in enumerate(X):
            q_x = numpy.zeros_like(x, dtype=numpy.int64)

            # Each column has its own quantizer, hence the per-column loop
            for j in range(x.shape[1]):
                q_x[:, j] = self.input_quantizers[quantizer_offset + j].quant(x[:, j])

            assert q_x.dtype == numpy.int64, f"Inputs {i} were not quantized to int64 values"

            q_inputs.append(q_x)
            quantizer_offset += x.shape[1]

        return tuple(q_inputs) if len(q_inputs) > 1 else q_inputs[0]

    def compile(
        self,
        *inputs,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        show_mlir: bool = False,
        p_error: Optional[float] = None,
        global_p_error: Optional[float] = None,
        verbose: bool = False,
        inputs_encryption_status: Optional[Sequence[str]] = None,
    ) -> Circuit:
        """Compile the model into an FHE circuit that takes one argument per input.

        Args:
            *inputs: One or more input-sets (numpy arrays, pandas, lists or torch objects)
                used to build the compilation input-set. Taken in order, their columns must
                match the model's input quantizers.
            configuration (Optional[Configuration]): Concrete configuration forwarded to the
                compiler. It must not itself set p_error or global_p_error.
            artifacts (Optional[DebugArtifacts]): Forwarded to the compiler.
            show_mlir (bool): Forwarded to the compiler.
            p_error (Optional[float]): Resolved together with ``global_p_error`` by
                ``manage_parameters_for_pbs_errors``, then forwarded to the compiler.
            global_p_error (Optional[float]): See ``p_error``.
            verbose (bool): Forwarded to the compiler.
            inputs_encryption_status (Optional[Sequence[str]]): One status per input (e.g.
                "encrypted" or "clear"). If None, every input is encrypted.

        Returns:
            Circuit: The compiled FHE circuit, also stored as ``self.fhe_circuit_``.

        Raises:
            ValueError: If ``inputs_encryption_status`` does not have exactly one entry per
                input.
        """
        # Check that the model is correctly fitted
        self.check_model_is_fitted()

        # Cast pandas, list or torch to numpy
        inputs_as_array = tuple(check_array_and_assert(x) for x in inputs)

        if inputs_encryption_status is None:
            # Previously None reached `len(None)` in _get_module_to_compile and crashed
            inputs_encryption_status = ("encrypted",) * len(inputs_as_array)
        elif len(inputs_encryption_status) != len(inputs_as_array):
            # The circuit gets one argument per status, so a mismatch would otherwise only
            # surface later as an obscure failure inside the compiler
            raise ValueError(
                f"Got {len(inputs_encryption_status)} encryption statuses for "
                f"{len(inputs_as_array)} inputs. Exactly one status per input is expected."
            )

        # p_error or global_p_error should not be set in both the configuration and direct
        # arguments
        check_there_is_no_p_error_options_in_configuration(configuration)

        # Find the right way to set parameters for compiler, depending on the way we want to
        # default
        p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

        # Quantize the inputs
        quantized_inputs = self.quantize_input(*inputs_as_array)

        # Generate the compilation input-set with proper dimensions
        inputset = _get_inputset_generator(quantized_inputs)

        # Reset for double compile
        self._is_compiled = False

        # Retrieve the compiler instance
        module_to_compile = self._get_module_to_compile(inputs_encryption_status)

        # Compiling using a QuantizedModule requires different steps and should not be done here
        assert isinstance(module_to_compile, Compiler), (
            "Wrong module to compile. Expected to be of type `Compiler` but got "
            f"{type(module_to_compile)}."
        )

        # Jit compiler is now deprecated and will soon be removed, it is thus forced to False
        # by default
        self.fhe_circuit_ = module_to_compile.compile(
            inputset,
            configuration=configuration,
            artifacts=artifacts,
            show_mlir=show_mlir,
            p_error=p_error,
            global_p_error=global_p_error,
            verbose=verbose,
            single_precision=False,
            fhe_simulation=False,
            fhe_execution=True,
            jit=False,
        )

        self._is_compiled = True

        # For mypy
        assert isinstance(self.fhe_circuit, Circuit)
        return self.fhe_circuit

    def _get_module_to_compile(self, inputs_encryption_status) -> Union[Compiler, QuantizedModule]:
        """Build the Concrete compiler for a circuit with one argument per input.

        As a side effect, ``self._tree_inference`` is replaced (once) by a wrapper that
        accepts several arrays and concatenates them along axis 1.

        Args:
            inputs_encryption_status (Sequence[str]): One encryption status per circuit
                argument.

        Returns:
            Union[Compiler, QuantizedModule]: Always a Concrete ``Compiler``. Its arguments
            are named from ``input_{i}_encrypted``, one per status.
        """
        assert self._tree_inference is not None, self._is_not_fitted_error_message()

        # Let the tree inference accept one array per input by concatenating them along
        # axis 1 into the single array it receives. Checking `_is_compiled` cannot prevent
        # re-wrapping, because `compile` resets it just before calling this method. Instead,
        # the wrapper is tagged. A refit that installs a new inference function (without the
        # tag) gets wrapped again, as it should.
        tree_inference = self._tree_inference
        if not getattr(tree_inference, self._MULTI_INPUT_WRAPPER_TAG, False):

            def multi_input_tree_inference(*args):
                return tree_inference(numpy.concatenate(args, axis=1))

            setattr(multi_input_tree_inference, self._MULTI_INPUT_WRAPPER_TAG, True)
            self._tree_inference = multi_input_tree_inference

        input_names = [f"input_{i}_encrypted" for i in range(len(inputs_encryption_status))]

        # Generate the proxy function to compile
        _tree_inference_proxy, function_arg_names = generate_proxy_function(
            self._tree_inference, input_names
        )

        # Pair each argument name of the proxy function with the status of its input
        inputs_encryption_statuses = dict(
            zip(function_arg_names.values(), inputs_encryption_status)
        )

        # Create the compiler instance
        return Compiler(_tree_inference_proxy, inputs_encryption_statuses)
class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier):
    """Concrete ML ``DecisionTreeClassifier`` whose FHE circuit can take several inputs.

    All behavior comes from the two base classes. ``MultiInputModel`` is listed first, so its
    ``quantize_input``, ``compile`` and ``_get_module_to_compile`` take precedence over any
    implementation of those methods in ``DecisionTreeClassifier``.
    """