File size: 8,743 Bytes
341de97
 
 
 
 
 
7e8d9b9
341de97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e8d9b9
341de97
 
 
 
7e8d9b9
 
 
 
 
341de97
 
7e8d9b9
 
341de97
 
1f125f1
 
 
341de97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e8d9b9
 
 
341de97
 
 
 
 
 
 
 
 
7e8d9b9
 
 
341de97
 
 
 
 
 
 
 
 
 
 
 
 
 
1f125f1
7e8d9b9
341de97
7e8d9b9
341de97
 
 
 
 
 
 
7e8d9b9
 
 
 
 
 
 
341de97
11c7796
341de97
 
ee83d59
1f125f1
 
 
 
 
 
 
 
 
341de97
 
 
 
 
 
 
 
7e8d9b9
 
341de97
 
 
 
 
 
 
 
 
 
 
 
 
1f125f1
 
 
 
341de97
 
 
 
 
 
ee83d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11c7796
 
ee83d59
11c7796
ee83d59
 
 
 
 
 
 
 
 
 
 
 
 
 
11c7796
ee83d59
 
11c7796
 
 
 
ee83d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11c7796
341de97
 
 
 
 
 
 
 
 
 
7e8d9b9
341de97
 
 
 
ee83d59
341de97
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import os
from typing import Union

import torch
from transformers import LogitsProcessor

from seed_scheme_factory import SeedSchemeFactory
from utils import bytes_to_base, base_to_bytes, get_values_per_byte


class BaseProcessor(object):
    def __init__(
        self,
        msg_base: int,
        vocab: list[int],
        device: torch.device,
        seed_scheme: str,
        window_length: int = 1,
        salt_key: Union[int, None] = None,
        private_key: Union[int, None] = None,
    ):
        """
        Args:
            msg_base: base of the message.
            vocab: vocabulary list.
            device: device to load processor.
            seed_scheme: scheme used to compute the seed.
            window_length: length of window to compute the seed.
            salt_key: salt to add to the seed.
            private_key: private key used to compute the seed.
        """
        # Universal parameters
        self.msg_base = msg_base
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.device = device

        # Seed parameters
        seed_fn = SeedSchemeFactory.get_instance(
            seed_scheme,
            salt_key=salt_key,
            private_key=private_key,
        )
        if seed_fn is None:
            raise ValueError(f'Seed scheme "{seed_scheme}" is invalid')
        else:
            self.seed_fn = seed_fn

        self.window_length = window_length

        # Initialize RNG, always use cpu generator
        self.rng = torch.Generator(device="cpu")

        # Compute the ranges of each value in base
        self.ranges = torch.zeros((self.msg_base + 1), dtype=torch.int64).to(
            self.device
        )
        chunk_size = self.vocab_size / self.msg_base
        r = self.vocab_size % self.msg_base
        self.ranges[1:] = chunk_size
        self.ranges[1 : r + 1] += 1
        self.ranges = torch.cumsum(self.ranges, dim=0)

    def _seed_rng(self, input_ids: torch.Tensor):
        """
        Set the seed for the rng based on the current sequences.

        Args:
            input_ids: id in the input sequence.
        """
        seed = self.seed_fn(input_ids[-self.window_length :])
        self.rng.manual_seed(seed)

    def _get_valid_list_ids(self, input_ids: torch.Tensor, value: int):
        """
        Get ids of tokens in the valid list for the current sequences.
        """
        self._seed_rng(input_ids)
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device="cpu"
        ).to(self.device)
        vocab_list = vocab_perm[self.ranges[value] : self.ranges[value + 1]]

        return vocab_list

    def _get_value(self, input_ids: torch.Tensor):
        """
        Check whether the token is in the valid list.
        """
        self._seed_rng(input_ids[:-1])
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device="cpu"
        ).to(self.device)

        cur_token = input_ids[-1]
        cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
        value = (cur_id < self.ranges).type(torch.int).argmax().item() - 1

        return value


class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
    def __init__(
        self,
        prompt_ids: torch.Tensor,
        msg: bytes,
        gamma: float,
        tokenizer,
        start_pos: int = 0,
        *args,
        **kwargs,
    ):
        """
        Args:
            msg: message to hide in the text.
            gamma: bias add to scores of token in valid list.
        """
        super().__init__(*args, **kwargs)
        if prompt_ids.size(0) != 1:
            raise RuntimeError(
                "EncryptorLogitsProcessor does not support multiple prompts input."
            )

        self.prompt_size = prompt_ids.size(1)
        self.start_pos = start_pos

        self.raw_msg = msg
        self.msg = bytes_to_base(msg, self.msg_base)
        self.gamma = gamma
        self.tokenizer = tokenizer
        special_tokens = [
            tokenizer.bos_token_id,
            tokenizer.eos_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
            tokenizer.cls_token_id,
        ]
        special_tokens = [x for x in special_tokens if x is not None]
        self.special_tokens = torch.tensor(special_tokens, device=self.device)

    def __call__(
        self, input_ids_batch: torch.LongTensor, scores_batch: torch.FloatTensor
    ):
        # If the whole message is hidden already, then just return the raw scores.

        for i, input_ids in enumerate(input_ids_batch):
            cur_pos = input_ids.size(0)
            msg_ptr = cur_pos - (self.prompt_size + self.start_pos)
            if msg_ptr < 0 or msg_ptr >= len(self.msg):
                continue
            scores_batch[i] = self._add_bias_to_valid_list(
                input_ids, scores_batch[i], self.msg[msg_ptr]
            )

        return scores_batch

    def _add_bias_to_valid_list(
        self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
    ):
        """
        Add the bias (gamma) to the valid list tokens
        """
        ids = torch.cat(
            [self._get_valid_list_ids(input_ids, value), self.special_tokens]
        )

        scores[ids] = scores[ids] + self.gamma
        return scores

    def get_message_len(self):
        return len(self.msg)

    def __map_input_ids(self, input_ids: torch.Tensor, base_arr, byte_arr):
        byte_enc_msg = [-1 for _ in range(input_ids.size(0))]
        base_enc_msg = [-1 for _ in range(input_ids.size(0))]
        base_msg = [-1 for _ in range(input_ids.size(0))]
        byte_msg = [-1 for _ in range(input_ids.size(0))]

        values_per_byte = get_values_per_byte(self.msg_base)
        start = self.start_pos % values_per_byte

        for i, b in enumerate(base_arr):
            base_enc_msg[i] = base_arr[i]
            byte_enc_msg[i] = byte_arr[(i - start) // values_per_byte]

        for i, b in enumerate(self.msg):
            base_msg[i + self.start_pos] = b
            byte_msg[i + self.start_pos] = self.raw_msg[i // values_per_byte]

        return base_msg, byte_msg, base_enc_msg, byte_enc_msg

    def validate(self, input_ids_batch: torch.Tensor):
        res = []
        tokens_infos = []
        for input_ids in input_ids_batch:
            # Initialization
            base_arr = []

            # Loop and obtain values of all tokens
            for i in range(0, input_ids.size(0)):
                base_arr.append(self._get_value(input_ids[: i + 1]))

            values_per_byte = get_values_per_byte(self.msg_base)

            # Transform the values to bytes
            start = self.start_pos % values_per_byte
            byte_arr = base_to_bytes(base_arr[start:], self.msg_base)

            # Construct the
            cnt = 0
            enc_msg = byte_arr[self.start_pos // values_per_byte :]
            for i in range(min(len(enc_msg), len(self.raw_msg))):
                if self.raw_msg[i] == enc_msg[i]:
                    cnt += 1
            res.append(cnt / len(self.raw_msg))

            base_msg, byte_msg, base_enc_msg, byte_enc_msg = (
                self.__map_input_ids(input_ids, base_arr, byte_arr)
            )
            tokens = []
            input_strs = [self.tokenizer.decode([input]) for input in input_ids]
            for i in range(len(base_enc_msg)):
                tokens.append(
                    {
                        "token": input_strs[i],
                        "base_enc": base_enc_msg[i],
                        "byte_enc": byte_enc_msg[i],
                        "base_msg": base_msg[i],
                        "byte_msg": byte_msg[i],
                        "byte_id": (i - start) // values_per_byte,
                    }
                )
            tokens_infos.append(tokens)

        return res, tokens_infos


class DecryptorProcessor(BaseProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def decrypt(self, input_ids_batch: torch.Tensor):
        """
        Decrypt the text sequences.
        """
        shift_msg = []
        for shift in range(get_values_per_byte(self.msg_base)):
            msg = []
            bytes_msg = []
            for i, input_ids in enumerate(input_ids_batch):
                msg.append(list())
                for j in range(shift, len(input_ids)):
                    # TODO: this could be slow. Considering reimplement this.
                    value = self._get_value(input_ids[: j + 1])
                    msg[i].append(value)

                bytes_msg.append(base_to_bytes(msg[i], self.msg_base))
            shift_msg.append(bytes_msg)

        return shift_msg