import json
import re
import time
import uuid
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
# Demo helpers pulled in by the wildcard import: clean_directory, read_json,
# read_pickle, write_pickle, get_batch_text_representation, PUNCTUATION_LIST,
# and MAPPING_UUID_PATH are all expected to come from utils_demo
from utils_demo import *

from concrete.ml.common.serialization.loaders import load
from concrete.ml.deployment import FHEModelClient, FHEModelServer

# Probability threshold above which a token is treated as PII and anonymized
TOLERANCE_PROBA = 0.77

CURRENT_DIR = Path(__file__).parent

DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"


class FHEAnonymizer:
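    """Anonymize a text query with an FHE model.

    Tokens are embedded and encrypted on the client, scored under encryption
    on the server, and replaced by short UUIDs on the client when the
    decrypted PII probability exceeds TOLERANCE_PROBA.
    """
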
    def __init__(self):
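        """Load the embedding model/tokenizer and the FHE client and server artifacts."""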

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

        self.punctuation_list = PUNCTUATION_LIST
        self.uuid_map = read_json(MAPPING_UUID_PATH)

        self.client = FHEModelClient(DEPLOYMENT_DIR, key_dir=KEYS_DIR)
        self.server = FHEModelServer(DEPLOYMENT_DIR)

    def generate_key(self):
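        """Generate the client's private and evaluation keys, and persist the evaluation key."""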

        # Remove keys and artifacts left over from a previous run
        clean_directory()

        # Creates the private and evaluation keys on the client side
        self.client.generate_private_and_evaluation_keys()

        # Get the serialized evaluation keys
        self.evaluation_key = self.client.get_serialized_evaluation_keys()
        assert isinstance(self.evaluation_key, bytes)

        evaluation_key_path = KEYS_DIR / "evaluation_key"

        with evaluation_key_path.open("wb") as f:
            f.write(self.evaluation_key)

    def encrypt_query(self, text: str):
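        """Tokenize the query and encrypt an embedding for each non-whitespace token."""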
        # Pattern to identify words and non-words (including punctuation, spaces, etc.)
        tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
        encrypted_tokens = []

        for token in tokens:
            # Skip whitespace-only tokens: only meaningful tokens are encrypted
            if re.match(r"^\s+$", token):
                continue

            # Embed each token, then quantize and encrypt the embedding client-side
            emb_x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
            encrypted_x = self.client.quantize_encrypt_serialize(emb_x)
            assert isinstance(encrypted_x, bytes)

            encrypted_tokens.append(encrypted_x)

        write_pickle(KEYS_DIR / "encrypted_quantized_query", encrypted_tokens)

    def run_server(self):
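        """Run FHE inference on each encrypted token; requires generate_key() to have been called first."""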

        encrypted_tokens = read_pickle(KEYS_DIR / "encrypted_quantized_query")

        encrypted_output, timing = [], []
        for enc_x in encrypted_tokens:
            start_time = time.time()
            enc_y = self.server.run(enc_x, self.evaluation_key)
            # Per-token FHE inference time, converted from seconds to minutes
            timing.append((time.time() - start_time) / 60.0)
            encrypted_output.append(enc_y)

        write_pickle(KEYS_DIR / "encrypted_output", encrypted_output)
        write_pickle(KEYS_DIR / "encrypted_timing", timing)

        return encrypted_output, timing

    def decrypt_output(self, text):
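        """Decrypt the per-token scores and replace likely-PII tokens with short UUIDs."""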

        encrypted_output = read_pickle(KEYS_DIR / "encrypted_output")

        tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
        decrypted_output, identified_words_with_prob = [], []

        i = 0
        for token in tokens:
            # Skip whitespace-only tokens, mirroring the filtering in encrypt_query
            if re.match(r"^\s+$", token):
                continue

            encrypted_token = encrypted_output[i]
            prediction_proba = self.client.deserialize_decrypt_dequantize(encrypted_token)
            probability = prediction_proba[0][1]
            i += 1

            if probability >= TOLERANCE_PROBA:
                identified_words_with_prob.append((token, probability))

                # Use the existing UUID if available, otherwise generate a new one
                tmp_uuid = self.uuid_map.get(token, str(uuid.uuid4())[:8])
                decrypted_output.append(tmp_uuid)
                self.uuid_map[token] = tmp_uuid
            else:
                decrypted_output.append(token)

        # Persist the updated UUID map once per query, not once per token
        with open(MAPPING_UUID_PATH, "w") as file:
            json.dump(self.uuid_map, file)

        write_pickle(KEYS_DIR / "reconstructed_sentence", " ".join(decrypted_output))
        write_pickle(KEYS_DIR / "identified_words_with_prob", identified_words_with_prob)

    def run_server_and_decrypt_output(self, text):
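        """Convenience wrapper: FHE inference on the server, then decryption on the client."""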
        self.run_server()
        self.decrypt_output(text)
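

# A minimal usage sketch (not part of the deployed app) showing how the class
# is driven end to end: generate keys, encrypt a query client-side, run the
# FHE model server-side, then decrypt and reassemble the anonymized sentence.
# The sample sentence is illustrative only; read_pickle and KEYS_DIR come from
# the imports above.
if __name__ == "__main__":
    anonymizer = FHEAnonymizer()

    # Client side: generate the private and evaluation keys once
    anonymizer.generate_key()

    query = "John Doe will meet Alice at the Boston office on Monday."

    # Client side: embed, quantize, and encrypt every non-whitespace token
    anonymizer.encrypt_query(query)

    # Server-side FHE inference, then client-side decryption and anonymization
    anonymizer.run_server_and_decrypt_output(query)

    # The anonymized sentence is persisted to disk by decrypt_output
    print(read_pickle(KEYS_DIR / "reconstructed_sentence"))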