import re
from pathlib import Path

import gensim

from concrete.ml.common.serialization.loaders import load
from concrete.ml.deployment import FHEModelClient, FHEModelServer

base_dir = Path(__file__).parent

class FHEAnonymizer:
    def __init__(self, punctuation_list=".,!?:;"):
        # FastText word embeddings used to turn each word into a feature vector
        self.embeddings_model = gensim.models.FastText.load(str(base_dir / "embedded_model.model"))
        self.punctuation_list = punctuation_list

        # Clear-text Concrete ML model (deserialized from disk) that flags the
        # words to anonymize
        with open(base_dir / "cml_xgboost.model", "r") as model_file:
            self.fhe_ner_detection = load(file=model_file)

        # FHE client/server pair built from the compiled deployment artifacts,
        # plus the keys needed to run encrypted inference
        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        """Quantize and encrypt the input, run the server on ciphertexts, then decrypt the result."""
        enc_x = self.client.quantize_encrypt_serialize(x)
        enc_y = self.server.run(enc_x, self.evaluation_key)
        y = self.client.deserialize_decrypt_dequantize(enc_y)
        return y

    def __call__(self, text: str):
        text = self.preprocess_sentences(text)
        identified_words = []
        new_text = []

        for word in text.split():
            # Embed the word and predict whether it should be anonymized
            x = self.embeddings_model.wv[word][None]
            prediction = self.fhe_ner_detection.predict(x)[0]
            # FHE alternative, running the same model under encryption:
            # prediction = self.fhe_inference(x).argmax(1)[0]

            if prediction == 1:
                identified_words.append(word)
                new_text.append("<REMOVED>")
            else:
                new_text.append(word)

        # Join the (possibly redacted) words back into a single string
        modified_text = " ".join(new_text)

        return modified_text, identified_words

    def preprocess_sentences(self, sentence, verbose=False):
        """Normalize whitespace and punctuation before word-level tokenization."""

        # Collapse newlines and repeated spaces into single spaces
        sentence = re.sub(r"\n+", " ", sentence)
        if verbose:
            print(sentence)

        sentence = re.sub(r" +", " ", sentence)
        if verbose:
            print(sentence)

        # Detach the possessive "'s" so it becomes its own token
        sentence = re.sub(r"'s\b", " s", sentence)
        if verbose:
            print(sentence)

        # Remove spaces that precede punctuation
        sentence = re.sub(r"\s([,.!?;:])", r"\1", sentence)
        if verbose:
            print(sentence)

        # Drop punctuation characters that are not attached to a word
        punctuation = re.escape(self.punctuation_list)
        pattern = r"(?<!\w)[{0}]|[{0}](?!\w)".format(punctuation)
        sentence = re.sub(pattern, "", sentence)
        if verbose:
            print(sentence)

        # Re-apply the space-before-punctuation cleanup after stripping punctuation
        sentence = re.sub(r"\s([,.!?;:])", r"\1", sentence)
        if verbose:
            print(sentence)

        return sentence
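

if __name__ == "__main__":
    # Minimal usage sketch, assuming the artifacts referenced above
    # ("embedded_model.model", "cml_xgboost.model" and the "deployment" folder)
    # sit next to this file. The sample sentence is illustrative only and the
    # exact output depends on the trained models.
    anonymizer = FHEAnonymizer()
    redacted_text, flagged_words = anonymizer("Alice works at Acme in Paris.")
    print(redacted_text)   # words flagged by the model appear as "<REMOVED>"
    print(flagged_words)   # the original words that were anonymized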