File size: 4,925 Bytes
9a997e4
 
c119738
9a997e4
c119738
 
 
 
 
 
 
 
 
 
 
615cfe4
c119738
b0303a0
c119738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32a7fbc
 
c119738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a997e4
b0303a0
615cfe4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Modified model classes that handle multi-input circuits."""

import numpy
import time
from typing import Optional, Sequence, Union

from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit

from concrete.ml.common.check_inputs import check_array_and_assert
from concrete.ml.common.utils import (
    generate_proxy_function, 
    manage_parameters_for_pbs_errors, 
    check_there_is_no_p_error_options_in_configuration
)
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
from concrete.ml.sklearn import DecisionTreeClassifier

class MultiInputModel:
    """Mixin adding multi-input quantization and compilation to a Concrete ML estimator.

    It is meant to be listed before a Concrete ML estimator in the class bases. It relies on the
    following members of that estimator: ``input_quantizers``, ``_tree_inference``,
    ``_is_compiled``, ``fhe_circuit_``, ``fhe_circuit``, ``check_model_is_fitted`` and
    ``_is_not_fitted_error_message``.

    Each circuit input is a 2D array. Its columns are a contiguous slice of the features the
    model was fitted on. The inputs are concatenated column-wise, in order, before the tree
    inference is run.
    """

    def quantize_input(
        self, *X: numpy.ndarray
    ) -> Union[numpy.ndarray, Sequence[numpy.ndarray]]:
        """Quantize each input with the quantizers of the feature columns it holds.

        Args:
            *X: the inputs, each a 2D array. Their column counts must add up to the number of
                input quantizers of the fitted model.

        Returns:
            The quantized int64 arrays as a tuple, or the single array if only one input is given.
        """
        self.check_model_is_fitted()

        n_features = sum(x.shape[1] for x in X)
        assert n_features == len(self.input_quantizers), (
            f"Expected inputs with {len(self.input_quantizers)} features in total, "
            f"got {n_features}."
        )

        q_inputs = []

        # Index of the first quantizer that applies to the current input's columns
        feature_offset = 0
        for i, x in enumerate(X):
            q_x = numpy.zeros_like(x, dtype=numpy.int64)

            for j in range(x.shape[1]):
                q_x[:, j] = self.input_quantizers[feature_offset + j].quant(x[:, j])

            assert q_x.dtype == numpy.int64, f"Inputs {i} were not quantized to int64 values"

            q_inputs.append(q_x)
            feature_offset += x.shape[1]

        return tuple(q_inputs) if len(q_inputs) > 1 else q_inputs[0]

    def compile(
        self,
        *inputs,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        show_mlir: bool = False,
        p_error: Optional[float] = None,
        global_p_error: Optional[float] = None,
        verbose: bool = False,
        inputs_encryption_status: Optional[Sequence[str]] = None,
    ) -> Circuit:
        """Compile the model into an FHE circuit taking one argument per input.

        Args:
            *inputs: the input-set, one array-like per circuit input. Anything accepted by
                ``check_array_and_assert`` works, e.g. numpy arrays, lists, pandas or torch data.
            configuration: Concrete compilation configuration. It must not set p_error or
                global_p_error itself.
            artifacts: optional debug artifacts to fill during compilation.
            show_mlir: whether to print the generated MLIR.
            p_error: probability of error of a single PBS.
            global_p_error: probability of error of the whole circuit.
            verbose: whether to print compilation details.
            inputs_encryption_status: encryption status of each input, in the same order as
                ``inputs``, as accepted by Concrete's ``Compiler`` (e.g. "encrypted" or "clear").
                Defaults to every input being encrypted.

        Returns:
            Circuit: the compiled circuit, also stored as ``fhe_circuit_``.

        Raises:
            ValueError: if ``inputs_encryption_status`` does not contain exactly one status per
                input.
        """
        # Check that the model is correctly fitted
        self.check_model_is_fitted()

        # Cast pandas, list or torch to numpy
        inputs_as_array = tuple(check_array_and_assert(x) for x in inputs)

        # Encrypt every input unless told otherwise, and refuse mismatched statuses instead of
        # letting `zip` silently drop inputs or statuses later on
        if inputs_encryption_status is None:
            inputs_encryption_status = ["encrypted"] * len(inputs_as_array)
        elif len(inputs_encryption_status) != len(inputs_as_array):
            raise ValueError(
                f"Got {len(inputs_encryption_status)} encryption statuses for "
                f"{len(inputs_as_array)} inputs; exactly one status per input is expected."
            )

        # p_error or global_p_error should not be set in both the configuration and direct arguments
        check_there_is_no_p_error_options_in_configuration(configuration)

        # Find the right way to set parameters for compiler, depending on the way we want to default
        p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

        # Quantize the inputs
        quantized_inputs = self.quantize_input(*inputs_as_array)

        # Generate the compilation input-set with proper dimensions
        inputset = _get_inputset_generator(quantized_inputs)

        # Reset for double compile
        self._is_compiled = False

        # Retrieve the compiler instance
        module_to_compile = self._get_module_to_compile(inputs_encryption_status)

        # Compiling using a QuantizedModule requires different steps and should not be done here
        assert isinstance(module_to_compile, Compiler), (
            "Wrong module to compile. Expected to be of type `Compiler` but got "
            f"{type(module_to_compile)}."
        )

        # Jit compiler is now deprecated and will soon be removed, it is thus forced to False
        # by default
        self.fhe_circuit_ = module_to_compile.compile(
            inputset,
            configuration=configuration,
            artifacts=artifacts,
            show_mlir=show_mlir,
            p_error=p_error,
            global_p_error=global_p_error,
            verbose=verbose,
            single_precision=False,
            fhe_simulation=False,
            fhe_execution=True,
        )

        self._is_compiled = True

        # For mypy
        assert isinstance(self.fhe_circuit, Circuit)

        return self.fhe_circuit

    def _get_module_to_compile(
        self, inputs_encryption_status: Optional[Sequence[str]] = None
    ) -> Union[Compiler, QuantizedModule]:
        """Build the Concrete compiler for the multi-input tree inference.

        As a side effect, ``self._tree_inference`` is replaced, once, by a wrapper. The wrapper
        accepts several arrays and concatenates them column-wise before calling the fitted
        inference.

        Args:
            inputs_encryption_status: encryption status of each circuit input, in order.
                Defaults to a single encrypted input.

        Returns:
            Compiler: a compiler whose function takes one argument per status.
        """
        assert self._tree_inference is not None, self._is_not_fitted_error_message()

        if inputs_encryption_status is None:
            inputs_encryption_status = ("encrypted",)

        # `compile` resets `_is_compiled` before calling this method, so that flag cannot tell
        # whether the inference was already wrapped. Tag the wrapper instead, so that
        # re-compiling does not stack one concatenating wrapper on top of another.
        tree_inference = self._tree_inference
        if not getattr(tree_inference, "_concatenates_inputs", False):

            def _multi_input_tree_inference(*args):
                return tree_inference(numpy.concatenate(args, axis=1))

            _multi_input_tree_inference._concatenates_inputs = True  # type: ignore[attr-defined]
            self._tree_inference = _multi_input_tree_inference

        input_names = [f"input_{i}_encrypted" for i in range(len(inputs_encryption_status))]

        # Generate the proxy function to compile
        _tree_inference_proxy, function_arg_names = generate_proxy_function(
            self._tree_inference, input_names
        )

        # Pair each proxy argument name with its status. This assumes the mapping returned by
        # `generate_proxy_function` preserves the order of `input_names`.
        inputs_encryption_statuses = dict(
            zip(function_arg_names.values(), inputs_encryption_status)
        )

        # Create the compiler instance
        return Compiler(_tree_inference_proxy, inputs_encryption_statuses)

class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier):
    """Concrete ML ``DecisionTreeClassifier`` whose compiled circuit can take several inputs.

    All behavior comes from the two bases. ``MultiInputModel`` comes first in the MRO, so its
    input quantization and compilation take precedence. Everything else is inherited from
    ``DecisionTreeClassifier``.
    """