"""Filter definitions, with pre-processing, post-processing and compilation methods."""

import json

import numpy as np
import torch
from common import AVAILABLE_FILTERS, INPUT_SHAPE
from concrete.numpy.compilation.compiler import Compiler
from torch import nn

from concrete.ml.common.debugging.custom_assert import assert_true
from concrete.ml.common.utils import generate_proxy_function
from concrete.ml.onnx.convert import get_equivalent_numpy_forward
from concrete.ml.torch.numpy_module import NumpyModule
from concrete.ml.version import __version__ as CML_VERSION


class _TorchIdentity(nn.Module):
    """Torch identity model."""

    def forward(self, x):
        """Identity forward pass.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            x (torch.Tensor): The input image.
        """
        return x


class _TorchInverted(nn.Module):
    """Torch inverted model."""

    def forward(self, x):
        """Forward pass for inverting an image's colors.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The (color) inverted image.
        """
        return 255 - x


class _TorchRotate(nn.Module):
    """Torch rotated model."""

    def forward(self, x):
        """Forward pass for rotating an image.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The rotated image.
        """
        return x.transpose(2, 3)


class _TorchConv2D(nn.Module):
    """Torch model for applying a single 2D convolution operator on images."""

    def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
        """Initialize the filter.

        Args:
            kernel (np.ndarray): The convolution kernel to consider.
            n_in_channels (int): The number of input channels to consider. Default to 3.
            n_out_channels (int): The number of output channels to consider. Default to 3.
            groups (int): The number of blocked connections from input to output channels,
                as used by torch's conv2d operator. Default to 1.
            threshold (Optional[int]): An optional value to subtract from the convolution's
                output. Default to None.
        """
        super().__init__()
        # The kernel is stored as an integer tensor since FHE computations require integers
        self.kernel = torch.tensor(kernel, dtype=torch.int64)
        self.n_out_channels = n_out_channels
        self.n_in_channels = n_in_channels
        self.groups = groups
        self.threshold = threshold

    def forward(self, x):
        """Forward pass for filtering the image using a 2D kernel.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The filtered image.

        Raises:
            ValueError: If the kernel is neither 1D nor 2D.
        """
        # Define the convolution parameters
        stride = 1
        kernel_shape = self.kernel.shape

        # Ensure the kernel has a proper shape
        # If the kernel has a 1D shape, a (1, 1) kernel is used for each in_channels
        if len(kernel_shape) == 1:
            kernel = self.kernel.reshape(
                self.n_out_channels,
                self.n_in_channels // self.groups,
                1,
                1,
            )

        # Else, if the kernel has a 2D shape, a single (Kw, Kh) kernel is used on all in_channels
        elif len(kernel_shape) == 2:
            kernel = self.kernel.expand(
                self.n_out_channels,
                self.n_in_channels // self.groups,
                kernel_shape[0],
                kernel_shape[1],
            )
        else:
            raise ValueError(
                "Wrong kernel shape, only 1D or 2D kernels are accepted. Got kernel of shape "
                f"{kernel_shape}"
            )

        # Apply the convolution
        x = nn.functional.conv2d(x, kernel, stride=stride, groups=self.groups)

        # Subtract a given threshold if given
        if self.threshold is not None:
            x -= self.threshold

        return x


class Filter:
    """Filter class used in the app."""

    def __init__(self, image_filter="inverted"):
        """Initializing the filter class using a given filter.

        Most filters can be found at https://en.wikipedia.org/wiki/Kernel_(image_processing).

        Args:
            image_filter (str): The filter to consider. Default to "inverted".
        """
        assert_true(
            image_filter in AVAILABLE_FILTERS,
            f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
            f"but got {image_filter}",
        )

        self.filter = image_filter
        self.onnx_model = None
        self.fhe_circuit = None
        self.divide = None
        self.repeat_out_channels = False

        if image_filter == "identity":
            self.torch_model = _TorchIdentity()

        elif image_filter == "inverted":
            self.torch_model = _TorchInverted()

        elif image_filter == "rotate":
            self.torch_model = _TorchRotate()

        elif image_filter == "black and white":
            # Define the grayscale weights (RGB order)
            # These weights were used in PAL and NTSC video systems and can be found at
            # https://en.wikipedia.org/wiki/Grayscale
            # There are initially supposed to be float weights (0.299, 0.587, 0.114), with
            # 0.299 + 0.587 + 0.114 = 1
            # However, since FHE computations require weights to be integers, we first multiply
            # these by a factor of 1000. The output image's values are then divided by 1000 in
            # post-processing in order to retrieve the correct result
            kernel = [299, 587, 114]
            self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)

            # Define the value used when dividing the output values in post-processing
            self.divide = 1000

            # Indicate that the out_channels will need to be repeated, as Gradio requires all
            # images to have a RGB format, even for grayscaled ones
            self.repeat_out_channels = True

        elif image_filter == "blur":
            kernel = np.ones((3, 3))
            self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)

            # Define the value used when dividing the output values in post-processing
            self.divide = 9

        elif image_filter == "sharpen":
            kernel = [
                [0, -1, 0],
                [-1, 5, -1],
                [0, -1, 0],
            ]
            self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)

        elif image_filter == "ridge detection":
            kernel = [
                [-1, -1, -1],
                [-1, 9, -1],
                [-1, -1, -1],
            ]

            # Additionally to the convolution operator, the filter will subtract a given threshold
            # value to the result in order to better display the ridges
            self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1, threshold=900)

            # Indicate that the out_channels will need to be repeated, as Gradio requires all
            # images to have a RGB format, even for grayscaled ones. Ridge detection images are
            # usually displayed as such
            self.repeat_out_channels = True

    def compile(self, onnx_model=None):
        """Compile the model on a representative inputset.

        Args:
            onnx_model (onnx.ModelProto): The loaded onnx model to consider. If None, it will
                be generated automatically using a NumpyModule. Default to None.

        Returns:
            concrete.numpy.Circuit: The compiled FHE circuit.
        """
        # Generate a random representative set of images used for compilation, following Torch's
        # shape format (batch, in_channels, image_height, image_width)
        # NOTE(review): np.random.randint's upper bound is exclusive, so 255 itself is never
        # sampled here — confirm this is intended for the compilation inputset
        np.random.seed(42)
        inputset = tuple(
            np.random.randint(0, 255, size=((1, 3) + INPUT_SHAPE), dtype=np.int64)
            for _ in range(10)
        )

        # If no onnx model was given, generate a new one.
        if onnx_model is None:
            numpy_module = NumpyModule(
                self.torch_model,
                dummy_input=torch.from_numpy(inputset[0]),
            )

            onnx_model = numpy_module.onnx_model

        # Get the proxy function and parameter mappings for initializing the compiler
        self.onnx_model = onnx_model
        numpy_filter = get_equivalent_numpy_forward(onnx_model)

        numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"])

        compiler = Compiler(
            numpy_filter_proxy,
            {parameters_mapping["inputs"]: "encrypted"},
        )

        # Compile the filter
        self.fhe_circuit = compiler.compile(inputset)

        return self.fhe_circuit

    def pre_processing(self, input_image):
        """Apply pre-processing to the encrypted input images.

        Args:
            input_image (np.ndarray): The image to pre-process.

        Returns:
            input_image (np.ndarray): The pre-processed image.
        """
        # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow
        # the same shape conventions.
        # Additionally, make sure the input images are made of integers only
        input_image = np.expand_dims(input_image.transpose(2, 0, 1), axis=0).astype(np.int64)

        return input_image

    def post_processing(self, output_image):
        """Apply post-processing to the encrypted output images.

        Args:
            output_image (np.ndarray): The decrypted image to post-process.

        Returns:
            output_image (np.ndarray): The post-processed image.
        """
        # Divide all values if needed
        if self.divide is not None:
            output_image //= self.divide

        # Clip the image's values to proper RGB standards as filters don't handle such constraints
        output_image = output_image.clip(0, 255)

        # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow
        # the same shape conventions.
        output_image = output_image.transpose(0, 2, 3, 1).squeeze(0)

        # Gradio requires all images to follow a RGB format
        if self.repeat_out_channels:
            output_image = output_image.repeat(3, axis=2)

        return output_image

    @classmethod
    def from_json(cls, json_path):
        """Instantiate a filter using a json file.

        Args:
            json_path (Union[str, pathlib.Path]): Path to the json file.

        Returns:
            model (Filter): The instantiated filter class.
        """
        # Load the parameters from the json file
        with open(json_path, "r", encoding="utf-8") as f:
            serialized_processing = json.load(f)

        # Make sure the version in serialized_model is the same as CML_VERSION
        assert_true(
            serialized_processing["cml_version"] == CML_VERSION,
            f"The version of Concrete ML library ({CML_VERSION}) is different "
            f"from the one used to save the model ({serialized_processing['cml_version']}). "
            "Please update to the proper Concrete ML version.",
        )

        # Initialize the model
        model = cls(image_filter=serialized_processing["model_filter"])

        return model

    def to_json(self, path_dir, file_name="serialized_processing"):
        """Export the parameters to a json file.

        Args:
            path_dir (Union[str, pathlib.Path]): The path to consider when saving the file.
            file_name (str): The file name. Default to "serialized_processing".
        """
        # Serialize the parameters
        serialized_processing = {
            "model_filter": self.filter,
        }
        serialized_processing = self._clean_dict_types_for_json(serialized_processing)

        # Add the version of the current CML library
        serialized_processing["cml_version"] = CML_VERSION

        # Save the json file
        with open(path_dir / f"{file_name}.json", "w", encoding="utf-8") as f:
            json.dump(serialized_processing, f)

    def _clean_dict_types_for_json(self, d: dict) -> dict:
        """Clean all values in the dict to be json serializable.

        Args:
            d (Dict): The dict to clean

        Returns:
            Dict: The cleaned dict
        """
        # Recursively convert numpy values to plain Python types so that json.dump accepts them
        for key, value in d.items():
            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict):
                d[key] = [self._clean_dict_types_for_json(v) for v in value]
            elif isinstance(value, dict):
                d[key] = self._clean_dict_types_for_json(value)
            elif isinstance(value, (np.generic, np.ndarray)):
                d[key] = d[key].tolist()

        return d