Roman
chore: Add comments, clean unused objects and improve ridge detection
3cf0931 unverified
raw
history blame
12.4 kB
"Filter definitions, with pre-processing, post-processing and compilation methods."
import json
import numpy as np
import torch
from common import AVAILABLE_FILTERS, INPUT_SHAPE
from concrete.numpy.compilation.compiler import Compiler
from torch import nn
from concrete.ml.common.debugging.custom_assert import assert_true
from concrete.ml.common.utils import generate_proxy_function
from concrete.ml.onnx.convert import get_equivalent_numpy_forward
from concrete.ml.torch.numpy_module import NumpyModule
from concrete.ml.version import __version__ as CML_VERSION
class _TorchIdentity(nn.Module):
    """Torch module that leaves the image untouched."""

    def forward(self, x):
        """Hand the input back without applying any operation.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The very same tensor that was given as input.
        """
        unchanged = x
        return unchanged
class _TorchInverted(nn.Module):
    """Torch module computing the negative of an image."""

    def forward(self, x):
        """Invert the colors by subtracting every value from 255.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The (color) inverted image, i.e. ``255 - x`` element-wise.
        """
        inverted = 255 - x
        return inverted
class _TorchRotate(nn.Module):
    """Torch module used for the "rotate" transformation."""

    def forward(self, x):
        """Swap dimensions 2 and 3 of the input image.

        Args:
            x (torch.Tensor): The input image. It needs at least 4 dimensions since
                dimensions 2 and 3 are the ones being exchanged.

        Returns:
            torch.Tensor: The image with its dimensions 2 and 3 exchanged.
        """
        swapped = torch.transpose(x, 2, 3)
        return swapped
class _TorchConv2D(nn.Module):
    """Torch module applying one 2D convolution, optionally followed by a constant offset."""

    def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
        """Store the convolution settings.

        Args:
            kernel (array-like): The convolution kernel, either 1D or 2D. It is converted to
                an int64 tensor.
            n_in_channels (int): The number of channels of the input images. Default to 3.
            n_out_channels (int): The number of channels of the output images. Default to 3.
            groups (int): The ``groups`` value forwarded to the convolution. Default to 1.
            threshold (Optional[int]): A value subtracted from the convolution's output.
                Nothing is subtracted if None. Default to None.
        """
        super().__init__()
        self.threshold = threshold
        self.groups = groups
        self.n_in_channels = n_in_channels
        self.n_out_channels = n_out_channels
        self.kernel = torch.tensor(kernel, dtype=torch.int64)

    def _as_conv_weights(self):
        """Shape the stored kernel into 4D convolution weights.

        Returns:
            torch.Tensor: The kernel with shape
                (n_out_channels, n_in_channels // groups, kernel_height, kernel_width).

        Raises:
            ValueError: If the stored kernel is neither 1D nor 2D.
        """
        kernel_shape = self.kernel.shape
        rank = len(kernel_shape)

        if rank not in (1, 2):
            raise ValueError(
                "Wrong kernel shape, only 1D or 2D kernels are accepted. Got kernel of shape "
                f"{kernel_shape}"
            )

        channels_per_group = self.n_in_channels // self.groups

        # A 1D kernel holds one weight per input channel: it becomes a set of (1, 1) kernels
        if rank == 1:
            return self.kernel.reshape(self.n_out_channels, channels_per_group, 1, 1)

        # A 2D kernel is shared: the same (Kw, Kh) kernel is applied on every input channel
        return self.kernel.expand(self.n_out_channels, channels_per_group, *kernel_shape)

    def forward(self, x):
        """Filter the image with the stored kernel.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The filtered image, with the threshold subtracted if one was given.

        Raises:
            ValueError: If the stored kernel is neither 1D nor 2D.
        """
        weights = self._as_conv_weights()

        # The convolution always uses a stride of 1
        filtered = nn.functional.conv2d(x, weights, stride=1, groups=self.groups)

        if self.threshold is None:
            return filtered

        # Shift the freshly computed output by the threshold
        filtered -= self.threshold
        return filtered
class Filter:
    """Filter class used in the app.

    It wraps one of the available Torch filters together with the pre-processing,
    post-processing, compilation and (de)serialization steps that go along with it.
    """

    def __init__(self, image_filter="inverted"):
        """Initializing the filter class using a given filter.

        Most filters can be found at https://en.wikipedia.org/wiki/Kernel_(image_processing).

        Args:
            image_filter (str): The filter to consider. Default to "inverted".
        """
        assert_true(
            image_filter in AVAILABLE_FILTERS,
            f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
            f"but got {image_filter}",
        )

        self.filter = image_filter
        self.onnx_model = None
        self.fhe_circuit = None

        # Value used for dividing the output in post-processing (None means no division)
        self.divide = None

        # Whether the single output channel must be repeated 3 times in post-processing
        self.repeat_out_channels = False

        if image_filter == "identity":
            self.torch_model = _TorchIdentity()

        elif image_filter == "inverted":
            self.torch_model = _TorchInverted()

        elif image_filter == "rotate":
            self.torch_model = _TorchRotate()

        elif image_filter == "black and white":
            # Define the grayscale weights (RGB order)
            # These weights were used in PAL and NTSC video systems and can be found at
            # https://en.wikipedia.org/wiki/Grayscale
            # They are initially supposed to be float weights (0.299, 0.587, 0.114), with
            # 0.299 + 0.587 + 0.114 = 1
            # However, since FHE computations require weights to be integers, we first multiply
            # these by a factor of 1000. The output image's values are then divided by 1000 in
            # post-processing in order to retrieve the correct result
            kernel = [299, 587, 114]

            self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)

            # Define the value used for dividing the output values in post-processing
            self.divide = 1000

            # Indicate that the out_channels will need to be repeated, as Gradio requires all
            # images to have a RGB format, even for grayscaled ones
            self.repeat_out_channels = True

        elif image_filter == "blur":
            kernel = np.ones((3, 3))

            self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)

            # Define the value used for dividing the output values in post-processing
            self.divide = 9

        elif image_filter == "sharpen":
            kernel = [
                [0, -1, 0],
                [-1, 5, -1],
                [0, -1, 0],
            ]

            self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)

        elif image_filter == "ridge detection":
            kernel = [
                [-1, -1, -1],
                [-1, 9, -1],
                [-1, -1, -1],
            ]

            # Additionally to the convolution operator, the filter will subtract a given threshold
            # value to the result in order to better display the ridges
            self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1, threshold=900)

            # Indicate that the out_channels will need to be repeated, as Gradio requires all
            # images to have a RGB format, even for grayscaled ones. Ridge detection images are
            # usually displayed as such
            self.repeat_out_channels = True

    def compile(self, onnx_model=None):
        """Compile the model on a representative inputset.

        Args:
            onnx_model (onnx.ModelProto): The loaded onnx model to consider. If None, it will be
                generated automatically using a NumpyModule. Default to None.

        Returns:
            The compiled FHE circuit, which is also stored in ``self.fhe_circuit``.
        """
        # Generate a random representative set of images used for compilation, following Torch's
        # shape format (batch, in_channels, image_height, image_width)
        # The upper bound of randint is exclusive, so 256 is needed for the inputset to be able
        # to contain the maximum pixel value, 255
        np.random.seed(42)
        inputset = tuple(
            np.random.randint(0, 256, size=((1, 3) + INPUT_SHAPE), dtype=np.int64)
            for _ in range(10)
        )

        # If no onnx model was given, generate a new one.
        if onnx_model is None:
            numpy_module = NumpyModule(
                self.torch_model,
                dummy_input=torch.from_numpy(inputset[0]),
            )

            onnx_model = numpy_module.onnx_model

        # Get the proxy function and parameter mappings for initializing the compiler
        self.onnx_model = onnx_model
        numpy_filter = get_equivalent_numpy_forward(onnx_model)

        numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"])

        compiler = Compiler(
            numpy_filter_proxy,
            {parameters_mapping["inputs"]: "encrypted"},
        )

        # Compile the filter
        self.fhe_circuit = compiler.compile(inputset)

        return self.fhe_circuit

    def pre_processing(self, input_image):
        """Apply pre-processing to the input image.

        Args:
            input_image (np.ndarray): The image to pre-process, in a channel-last
                (height, width, channels) layout.

        Returns:
            input_image (np.ndarray): The pre-processed image, as int64 values in a
                (1, channels, height, width) layout.
        """
        # Move the channels first and add a batch axis. This is done because Torch and Numpy
        # don't follow the same shape conventions.
        channels_first = input_image.transpose(2, 0, 1)

        # Additionally, make sure the input images are made of integers only
        input_image = np.expand_dims(channels_first, axis=0).astype(np.int64)

        return input_image

    def post_processing(self, output_image):
        """Apply post-processing to the output images.

        The given array is left untouched: all operations are applied on new arrays.

        Args:
            output_image (np.ndarray): The decrypted image to post-process, in a
                (1, channels, height, width) layout.

        Returns:
            output_image (np.ndarray): The post-processed image, in a
                (height, width, channels) layout.
        """
        # Divide all values if needed. The division is not done in place in order to avoid
        # modifying the array owned by the caller
        if self.divide is not None:
            output_image = output_image // self.divide

        # Clip the image's values to proper RGB standards as filters don't handle such constraints
        output_image = output_image.clip(0, 255)

        # Move the channels last and remove the batch axis. This is done because Torch and Numpy
        # don't follow the same shape conventions.
        output_image = output_image.transpose(0, 2, 3, 1).squeeze(0)

        # Gradio requires all images to follow a RGB format
        if self.repeat_out_channels:
            output_image = output_image.repeat(3, axis=2)

        return output_image

    @classmethod
    def from_json(cls, json_path):
        """Instantiate a filter using a json file.

        Args:
            json_path (Union[str, pathlib.Path]): Path to the json file.

        Returns:
            model (Filter): The instantiated filter class.
        """
        # Load the parameters from the json file
        with open(json_path, "r", encoding="utf-8") as f:
            serialized_processing = json.load(f)

        # Make sure the version in serialized_model is the same as CML_VERSION
        assert_true(
            serialized_processing["cml_version"] == CML_VERSION,
            f"The version of Concrete ML library ({CML_VERSION}) is different "
            f"from the one used to save the model ({serialized_processing['cml_version']}). "
            "Please update to the proper Concrete ML version.",
        )

        # Initialize the model
        model = cls(image_filter=serialized_processing["model_filter"])

        return model

    def to_json(self, path_dir, file_name="serialized_processing"):
        """Export the parameters to a json file.

        Args:
            path_dir (Union[str, pathlib.Path]): The path to consider when saving the file.
            file_name (str): The file name, without its extension. Default to
                "serialized_processing".
        """
        # Imported locally so that this method does not rely on a module-level import
        from pathlib import Path

        # Serialize the parameters
        serialized_processing = {
            "model_filter": self.filter,
        }
        serialized_processing = self._clean_dict_types_for_json(serialized_processing)

        # Add the version of the current CML library
        serialized_processing["cml_version"] = CML_VERSION

        # Convert the directory to a Path so that plain strings are accepted as well
        json_path = Path(path_dir) / f"{file_name}.json"

        # Save the json file
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(serialized_processing, f)

    def _clean_dict_types_for_json(self, d: dict) -> dict:
        """Clean all values in the dict to be json serializable.

        Numpy scalars and arrays are converted to their Python equivalent. Nested dicts, as well
        as lists whose first element is a dict, are cleaned recursively. The dict is modified in
        place.

        Args:
            d (Dict): The dict to clean

        Returns:
            Dict: The cleaned dict (the same object as the one given)
        """
        # Only existing keys are re-assigned, which is safe to do while iterating
        for key, value in d.items():
            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict):
                d[key] = [self._clean_dict_types_for_json(v) for v in value]
            elif isinstance(value, dict):
                d[key] = self._clean_dict_types_for_json(value)
            elif isinstance(value, (np.generic, np.ndarray)):
                d[key] = value.tolist()
        return d