Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 5,059 Bytes
9a997e4 c119738 9a997e4 c119738 74c0c8e c119738 b0303a0 c119738 9a997e4 b0303a0 74c0c8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
"""Modified model classes that handle multi-input circuits."""
import numpy
import time
from typing import Optional, Sequence, Union
from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit
from concrete.ml.common.check_inputs import check_array_and_assert
from concrete.ml.common.utils import (
generate_proxy_function,
manage_parameters_for_pbs_errors,
check_there_is_no_p_error_options_in_configuration
)
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
from concrete.ml.sklearn import DecisionTreeClassifier, DecisionTreeRegressor
class MultiInputModel:
    """Mixin letting a Concrete ML tree model compile to a circuit with several inputs.

    The feature matrix is split column-wise across several arrays. Each array is
    quantized with the input quantizers matching its columns. The compiled FHE
    circuit takes one argument per array, and each argument gets its own
    encryption status.

    This mixin relies on attributes and methods provided by the Concrete ML
    estimator it is combined with: ``input_quantizers``, ``_tree_inference``,
    ``_is_compiled``, ``check_model_is_fitted``, ``_is_not_fitted_error_message``
    and ``fhe_circuit``.
    """

    def quantize_input(self, *X: numpy.ndarray) -> Union[numpy.ndarray, tuple]:
        """Quantize each input array with the quantizers of its own columns.

        The input quantizers are indexed by global column position. The columns of
        ``X[0]`` use the first ``X[0].shape[1]`` quantizers, the columns of ``X[1]``
        use the next ones, and so on. The columns of all inputs, concatenated in
        order, must therefore line up with ``self.input_quantizers``.

        Args:
            *X (numpy.ndarray): The 2D input arrays, one per circuit input.

        Returns:
            The quantized int64 array when a single input is given, otherwise a
            tuple holding one quantized int64 array per input.
        """
        self.check_model_is_fitted()

        n_features = sum(x.shape[1] for x in X)
        assert n_features == len(self.input_quantizers), (
            f"Expected inputs with {len(self.input_quantizers)} features in total, "
            f"got {n_features}."
        )

        q_inputs = []
        column_offset = 0  # Global index of the first column of the current input
        for i, x in enumerate(X):
            q_x = numpy.zeros_like(x, dtype=numpy.int64)
            for j in range(x.shape[1]):
                q_x[:, j] = self.input_quantizers[column_offset + j].quant(x[:, j])
            assert q_x.dtype == numpy.int64, f"Inputs {i} were not quantized to int64 values"
            q_inputs.append(q_x)
            column_offset += x.shape[1]

        return tuple(q_inputs) if len(q_inputs) > 1 else q_inputs[0]

    def compile(
        self,
        *inputs,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        show_mlir: bool = False,
        p_error: Optional[float] = None,
        global_p_error: Optional[float] = None,
        verbose: bool = False,
        inputs_encryption_status: Optional[Sequence[str]] = None,
    ) -> Circuit:
        """Compile the model into an FHE circuit taking one argument per input array.

        Args:
            *inputs: The compilation input-set, given as one array-like per circuit
                input. The columns of all inputs, concatenated in order, must cover
                every feature the model was fitted on.
            configuration: Optional Concrete configuration. It must not set
                ``p_error``/``global_p_error`` (checked below); pass those directly.
            artifacts: Forwarded to ``Compiler.compile``.
            show_mlir: Forwarded to ``Compiler.compile``.
            p_error: Forwarded (after defaulting) to ``Compiler.compile``.
            global_p_error: Forwarded (after defaulting) to ``Compiler.compile``.
            verbose: Forwarded to ``Compiler.compile``.
            inputs_encryption_status: One encryption status per input array (for
                example "encrypted" or "clear"). Defaults to "encrypted" for every
                input.

        Returns:
            Circuit: The compiled circuit, also stored in ``self.fhe_circuit_``.

        Raises:
            ValueError: If ``inputs_encryption_status`` does not contain exactly one
                status per input.
        """
        # Check that the model is correctly fitted
        self.check_model_is_fitted()

        # Cast pandas, list or torch to numpy
        inputs_as_array = tuple(check_array_and_assert(x) for x in inputs)

        # Every circuit argument needs a status. Without this default,
        # `_get_module_to_compile` would call `len(None)`.
        if inputs_encryption_status is None:
            inputs_encryption_status = ["encrypted"] * len(inputs_as_array)
        elif len(inputs_encryption_status) != len(inputs_as_array):
            raise ValueError(
                f"Got {len(inputs_encryption_status)} encryption statuses for "
                f"{len(inputs_as_array)} inputs; expected exactly one status per input."
            )

        # p_error or global_p_error should not be set in both the configuration and direct arguments
        check_there_is_no_p_error_options_in_configuration(configuration)

        # Find the right way to set parameters for compiler, depending on the way we want to default
        p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

        # Quantize the inputs
        quantized_inputs = self.quantize_input(*inputs_as_array)

        # Generate the compilation input-set with proper dimensions
        inputset = _get_inputset_generator(quantized_inputs)

        # Reset for double compile
        self._is_compiled = False

        # Retrieve the compiler instance
        module_to_compile = self._get_module_to_compile(inputs_encryption_status)

        # Compiling using a QuantizedModule requires different steps and should not be done here
        assert isinstance(module_to_compile, Compiler), (
            "Wrong module to compile. Expected to be of type `Compiler` but got "
            f"{type(module_to_compile)}."
        )

        # Jit compiler is now deprecated and will soon be removed, it is thus forced to False
        # by default
        self.fhe_circuit_ = module_to_compile.compile(
            inputset,
            configuration=configuration,
            artifacts=artifacts,
            show_mlir=show_mlir,
            p_error=p_error,
            global_p_error=global_p_error,
            verbose=verbose,
            single_precision=False,
            fhe_simulation=False,
            fhe_execution=True,
            jit=False,
        )

        self._is_compiled = True

        # For mypy
        assert isinstance(self.fhe_circuit, Circuit)

        return self.fhe_circuit

    def _get_module_to_compile(self, inputs_encryption_status) -> Union[Compiler, QuantizedModule]:
        """Build the Concrete compiler for the multi-input tree inference.

        Args:
            inputs_encryption_status: One encryption status per circuit input.

        Returns:
            Compiler: A compiler whose traced function takes one argument per input
                array and concatenates them column-wise before running the tree.
        """
        assert self._tree_inference is not None, self._is_not_fitted_error_message()

        # The tree inference takes a single 2D array. Wrap it so it accepts the split
        # inputs and concatenates them column-wise. `compile` resets `_is_compiled`
        # before calling this method, so the wrapper is tagged to avoid stacking
        # another one on each recompilation of the same fitted model.
        already_wrapped = getattr(self._tree_inference, "_multi_input_wrapped", False)
        if not self._is_compiled and not already_wrapped:
            single_input_inference = self._tree_inference

            def multi_input_inference(*args):
                return single_input_inference(numpy.concatenate(args, axis=1))

            multi_input_inference._multi_input_wrapped = True  # type: ignore[attr-defined]
            self._tree_inference = multi_input_inference

        input_names = [f"input_{i}_encrypted" for i in range(len(inputs_encryption_status))]

        # Generate the proxy function to compile
        _tree_inference_proxy, function_arg_names = generate_proxy_function(
            self._tree_inference, input_names
        )

        # Map each proxy argument name to the encryption status of its input
        inputs_encryption_statuses = dict(
            zip(function_arg_names.values(), inputs_encryption_status)
        )

        # Create the compiler instance
        return Compiler(_tree_inference_proxy, inputs_encryption_statuses)
class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier):
    """Concrete ML ``DecisionTreeClassifier`` using the ``MultiInputModel`` compile/quantize overrides."""
class MultiInputDecisionTreeRegressor(MultiInputModel, DecisionTreeRegressor):
    """Concrete ML ``DecisionTreeRegressor`` using the ``MultiInputModel`` compile/quantize overrides."""
|