File size: 2,790 Bytes
0908a41 fa707a9 0908a41 fa707a9 0908a41 fa707a9 0908a41 fa707a9 0908a41 fa707a9 0908a41 fa707a9 0908a41 fa707a9 0908a41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import random
import numpy as np
from torch import nn
import torch
from concrete.fhe.compilation.compiler import Compiler
from concrete.ml.common.utils import generate_proxy_function
from concrete.ml.torch.numpy_module import NumpyModule
from common import AVAILABLE_MATCHERS
class TorchRandomGuessing(nn.Module):
    """Torch baseline model that outputs a random 0/1 decision."""

    def __init__(self, classes_=None):
        """Initialize the model.

        Args:
            classes_ (list, optional): Class labels exposed as ``self.classes_``.
                Defaults to a fresh ``[0, 1]`` list per instance (avoids the
                shared mutable-default-argument pitfall).
        """
        super().__init__()
        self.classes_ = [0, 1] if classes_ is None else classes_

    def forward(self, x):
        """Random guessing forward pass.

        Args:
            x (torch.Tensor): concat of query and reference.

        Returns:
            (torch.Tensor): A 1-element tensor holding 0 or 1, chosen at random.
        """
        x = x.sum()
        # NOTE: the draw is always from {0, 1}, independent of ``self.classes_``.
        # ``+ x - x`` leaves the value unchanged but ties the output to the input,
        # presumably so the traced/exported graph depends on ``x``.
        # NOTE(review): once traced, the random draw is likely frozen into a
        # constant — confirm against the compiled circuit.
        return torch.tensor([random.choice([0, 1])]) + x - x
class Matcher:
    """Wrap a torch matcher model and compile it into a Concrete FHE circuit."""

    def __init__(self, matcher_name):
        """Select the matcher implementation.

        Args:
            matcher_name (str): Name of the matcher; must be one of
                ``AVAILABLE_MATCHERS``.

        Raises:
            ValueError: If ``matcher_name`` is not in ``AVAILABLE_MATCHERS``.
        """
        # Explicit check instead of ``assert``: it still runs under ``python -O``,
        # and the message is a real string (the old one was accidentally a tuple).
        if matcher_name not in AVAILABLE_MATCHERS:
            raise ValueError(
                "Unsupported image matcher. Expected one of "
                f"{tuple(AVAILABLE_MATCHERS)}, but got {matcher_name}"
            )

        self.fhe_circuit = None
        self.matcher_name = matcher_name
        # Only "random guessing" has a model here; other names stay None so that
        # compile() can report a clear error instead of an AttributeError.
        self.torch_model = None
        if self.matcher_name == "random guessing":
            self.torch_model = TorchRandomGuessing()

    def compile(self):
        """Convert the torch model to numpy and compile it into an FHE circuit.

        Returns:
            The compiled FHE circuit, also stored in ``self.fhe_circuit``.

        Raises:
            NotImplementedError: If no torch model exists for ``self.matcher_name``.
        """
        if self.torch_model is None:
            raise NotImplementedError(
                f"No torch model is implemented for matcher {self.matcher_name!r}"
            )

        # Sample inputs: the first is the dummy input for the torch -> numpy
        # conversion, and the whole set is passed to Compiler.compile below.
        inputset = (np.array([10]), np.array([5]))

        print("torch module > numpy module ...")
        numpy_module = NumpyModule(
            self.torch_model,
            dummy_input=torch.from_numpy(inputset[0]),
        )

        print("get proxy function ...")
        # Get the proxy function and parameter mappings used for initializing the
        # compiler. This lets modules with arbitrary numbers of encrypted arguments
        # be passed to Concrete's compiler.
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward, ["inputs"]
        )

        print("Compile the filter and retrieve its FHE circuit ...")
        compiler = Compiler(
            numpy_filter_proxy,
            {parameters_mapping["inputs"]: "encrypted"},
        )
        self.fhe_circuit = compiler.compile(inputset)
        return self.fhe_circuit

    def post_processing(self, output_result):
        """Map the decrypted output result to a verdict string.

        Args:
            output_result (np.ndarray): The decrypted result; only its first
                element is inspected.

        Returns:
            str: ``"PASS"`` if the first element equals 1, otherwise ``"FAIL"``.
        """
        print(f"{output_result=}")
        return "PASS" if output_result[0] == 1 else "FAIL"
# matcher = Matcher(matcher_name=AVAILABLE_MATCHERS[0])
# fhe_circuit = matcher.compile()
|