encrypted-anonymization / fhe_anonymizer.py
jfrery-zama's picture
initial commit
646bd9e
raw
history blame
No virus
2.56 kB
import gensim
import re
from concrete.ml.deployment import FHEModelClient, FHEModelServer
from pathlib import Path
from concrete.ml.common.serialization.loaders import load
# Folder holding this module; every model artifact below is resolved against it.
base_dir = Path(__file__).parents[0]
class FHEAnonymizer:
    """Anonymize text by replacing words that a Concrete ML classifier flags.

    Each word is embedded with a FastText model and classified on its own.
    Classification runs either in the clear (default) or through an
    encrypted client/server round-trip (see :meth:`fhe_inference`).
    """

    def __init__(self, punctuation_list: str = ".,!?:;"):
        """Load the models and generate a fresh key pair for this instance.

        Args:
            punctuation_list: Characters treated as punctuation by
                :meth:`preprocess_sentences`.
        """
        # FastText word-embedding model used to vectorize each word.
        self.embeddings_model = gensim.models.FastText.load(
            str(base_dir / "embedded_model.model")
        )
        self.punctuation_list = punctuation_list
        # Clear-text Concrete ML model; used by __call__ unless use_fhe=True.
        with open(base_dir / "cml_xgboost.model", "r", encoding="utf-8") as model_file:
            self.fhe_ner_detection = load(file=model_file)
        # Deployment artifacts: the client encrypts inputs and decrypts outputs,
        # the server evaluates on the encrypted data.
        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        # Keys are generated per instance; only the evaluation key is passed
        # to the server in fhe_inference.
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        """Run one encrypted inference round-trip and return the decrypted output.

        *x* is quantized, encrypted and serialized by the client, evaluated by
        the server with the evaluation key, then deserialized, decrypted and
        dequantized by the client.
        """
        encrypted_input = self.client.quantize_encrypt_serialize(x)
        encrypted_output = self.server.run(encrypted_input, self.evaluation_key)
        return self.client.deserialize_decrypt_dequantize(encrypted_output)

    def _predict_label(self, word, use_fhe=False):
        """Return the integer class predicted for a single *word* (1 = flagged)."""
        # Add a leading axis so the model receives a batch of one sample.
        x = self.embeddings_model.wv[word][None]
        if use_fhe:
            # Output is presumably per-class scores (arg-max over axis 1).
            return int(self.fhe_inference(x).argmax(1)[0])
        return int(self.fhe_ner_detection.predict(x)[0])

    def __call__(self, text: str, *, use_fhe: bool = False):
        """Replace every word the classifier flags with ``"<REMOVED>"``.

        The text is cleaned with :meth:`preprocess_sentences`, split on
        whitespace, and each word is classified independently.

        Args:
            text: Input text to anonymize.
            use_fhe: If True, classify through the encrypted round-trip
                (:meth:`fhe_inference`). Otherwise use the clear-text
                model's ``predict``; this is the default and the previous
                behavior.

        Returns:
            A ``(modified_text, identified_words)`` tuple. ``modified_text``
            is the words re-joined with single spaces, flagged ones replaced
            by ``"<REMOVED>"``. ``identified_words`` lists the flagged words
            in order of appearance.
        """
        identified_words = []
        new_text = []
        for word in self.preprocess_sentences(text).split():
            if self._predict_label(word, use_fhe) == 1:
                identified_words.append(word)
                new_text.append("<REMOVED>")
            else:
                new_text.append(word)
        return " ".join(new_text), identified_words

    def preprocess_sentences(self, sentence, verbose=False):
        """Normalize whitespace and strip standalone punctuation from *sentence*.

        Steps, in order:

        1. Collapse runs of newlines into one space.
        2. Collapse runs of spaces into one space.
        3. Rewrite a possessive ``'s`` as the separate token ``" s"``.
        4. Drop whitespace directly before a character of ``punctuation_list``.
        5. Remove punctuation characters not enclosed by word characters on
           both sides, so an inner dot such as in ``3.14`` is kept.
        6. Repeat step 4 as a final cleanup.

        Steps 4-6 use ``self.punctuation_list`` and are skipped when it is
        empty, because an empty ``[]`` class would be mis-parsed by ``re``.

        Args:
            sentence: Raw input text.
            verbose: If True, print the intermediate result after each step.

        Returns:
            The preprocessed string.
        """
        steps = [
            (r"\n+", " "),
            (r" +", " "),
            (r"'s\b", " s"),
        ]
        if self.punctuation_list:
            chars = re.escape(self.punctuation_list)
            space_before_punct = (r"\s([" + chars + r"])", r"\1")
            steps += [
                space_before_punct,
                (r"(?<!\w)[" + chars + r"]|[" + chars + r"](?!\w)", ""),
                space_before_punct,
            ]
        for pattern, replacement in steps:
            sentence = re.sub(pattern, replacement, sentence)
            if verbose:
                print(sentence)
        return sentence