tnk2908 commited on
Commit
52c67ef
·
1 Parent(s): e4f5e06

Add statistical analysis

Browse files
Files changed (8) hide show
  1. analyse.py +408 -0
  2. api.py +3 -0
  3. config.ini +1 -0
  4. model_factory.py +42 -1
  5. requirements.txt +2 -0
  6. schemes.py +5 -0
  7. stegno.py +22 -11
  8. utils.py +1 -0
analyse.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ from argparse import ArgumentParser
5
+
6
+ import numpy as np
7
+ from matplotlib import pyplot as plt
8
+
9
+ import torch
10
+ from datasets import load_dataset
11
+ from model_factory import ModelFactory
12
+ from stegno import generate
13
+
14
+ rng = torch.Generator(device="cpu")
15
+ rng.manual_seed(0)
16
+
17
+
18
+ def load_msgs(msg_lens: list[int], file: str | None = None):
19
+ msgs = None
20
+ if file is not None and os.path.isfile(file):
21
+ with open(file, "r") as f:
22
+ msgs = json.load(f)
23
+ if "readable" not in msgs and "random" not in msgs:
24
+ msgs = None
25
+ else:
26
+ return msgs
27
+
28
+ msgs = {
29
+ "readable": [],
30
+ "random": [],
31
+ }
32
+
33
+ c4_en = load_dataset("allenai/c4", "en", split="validation", streaming=True)
34
+ iterator = iter(c4_en)
35
+
36
+ for length in msg_lens:
37
+ random_msg = torch.randint(256, (length,), generator=rng)
38
+ base64_msg = base64.b64encode(bytes(random_msg.tolist())).decode(
39
+ "ascii"
40
+ )
41
+ msgs["random"].append(base64_msg)
42
+
43
+ readable_msg = next(iterator)["text"]
44
+ while len(readable_msg) < length:
45
+ readable_msg = next(iterator)["text"]
46
+ msgs["readable"].append(readable_msg[:length])
47
+
48
+ return msgs
49
+
50
+
51
+ def load_prompts(n: int, min_length: int, file: str | None = None):
52
+ prompts = None
53
+ if file is not None and os.path.isfile(file):
54
+ with open(file, "r") as f:
55
+ prompts = json.load(f)
56
+ return prompts
57
+
58
+ prompts = []
59
+
60
+ c4_en = load_dataset("allenai/c4", "en", split="train", streaming=True)
61
+ iterator = iter(c4_en)
62
+
63
+ while len(prompts) < n:
64
+ text = next(iterator)["text"]
65
+ if len(text) < min_length:
66
+ continue
67
+ prompts.append(text)
68
+
69
+ return prompts
70
+
71
+
72
+ def create_args():
73
+ parser = ArgumentParser()
74
+
75
+ # messages
76
+ parser.add_argument(
77
+ "--msgs-file", type=str, default=None, help="Where messages are stored"
78
+ )
79
+ parser.add_argument(
80
+ "--msgs-lengths",
81
+ nargs=3,
82
+ type=int,
83
+ help="Range of messages' lengths. This is parsed in form: <start> <end> <step>",
84
+ )
85
+ parser.add_argument(
86
+ "--msgs-per-length",
87
+ type=int,
88
+ default=5,
89
+ help="Number of messages per length",
90
+ )
91
+ # prompts
92
+ parser.add_argument(
93
+ "--prompts-file",
94
+ type=str,
95
+ default=None,
96
+ help="Where prompts are stored",
97
+ )
98
+ parser.add_argument(
99
+ "--num-prompts",
100
+ type=int,
101
+ default=500,
102
+ help="Number of prompts",
103
+ )
104
+ parser.add_argument(
105
+ "--prompt-size",
106
+ type=int,
107
+ default=50,
108
+ help="Size of prompts",
109
+ )
110
+ parser.add_argument(
111
+ "--prompts-min-length",
112
+ type=int,
113
+ default=100,
114
+ help="Min length of prompts",
115
+ )
116
+ # Others
117
+ parser.add_argument(
118
+ "--overwrite",
119
+ action="store_true",
120
+ help="Whether to overwrite prompts and messages files",
121
+ )
122
+
123
+ # Hyperparameters
124
+ parser.add_argument(
125
+ "--gen-model",
126
+ type=str,
127
+ default="gpt2",
128
+ help="Model used to generate",
129
+ )
130
+ parser.add_argument(
131
+ "--deltas",
132
+ nargs=3,
133
+ type=float,
134
+ help="Range of delta. This is parsed in form: <start> <end> <step>",
135
+ )
136
+ parser.add_argument(
137
+ "--bases",
138
+ nargs=3,
139
+ type=int,
140
+ help="Range of base. This is parsed in form: <start> <end> <step>",
141
+ )
142
+ parser.add_argument(
143
+ "--judge-model",
144
+ type=str,
145
+ default="gpt2",
146
+ help="Model used to compute score perplexity of generated text",
147
+ )
148
+ # Results
149
+ parser.add_argument(
150
+ "--repeat",
151
+ type=int,
152
+ default=1,
153
+ help="How many times to repeat for each set of parameters, prompts and messages",
154
+ )
155
+ parser.add_argument(
156
+ "--results-load-file",
157
+ type=str,
158
+ default=None,
159
+ help="Where to load results",
160
+ )
161
+ parser.add_argument(
162
+ "--results-save-file",
163
+ type=str,
164
+ default=None,
165
+ help="Where to save results",
166
+ )
167
+ parser.add_argument(
168
+ "--figs-dir",
169
+ type=str,
170
+ default=None,
171
+ help="Where to save figures",
172
+ )
173
+
174
+ return parser.parse_args()
175
+
176
+
177
+ def get_results(args, prompts, msgs):
178
+ model, tokenizer = ModelFactory.load_model(args.gen_model)
179
+ results = []
180
+
181
+ for prompt in prompts[:1]:
182
+ for delta in np.arange(
183
+ args.deltas[0], args.deltas[1] + args.deltas[2], args.deltas[2]
184
+ ):
185
+ for base in np.arange(
186
+ args.bases[0],
187
+ args.bases[1] + args.bases[2],
188
+ args.bases[2],
189
+ dtype=np.int32,
190
+ ):
191
+ for k in msgs:
192
+ msg_type = k
193
+ for msg in msgs[k]:
194
+ msg_bytes = (
195
+ msg.encode("ascii")
196
+ if k == "readable"
197
+ else base64.b64decode(msg)
198
+ )
199
+ for _ in range(args.repeat):
200
+ text, msg_rate, tokens_info = generate(
201
+ tokenizer=tokenizer,
202
+ model=model,
203
+ prompt=prompt,
204
+ msg=msg_bytes,
205
+ start_pos_p=[0],
206
+ delta=delta,
207
+ msg_base=base,
208
+ seed_scheme="sha_left_hash",
209
+ window_length=1,
210
+ private_key=0,
211
+ min_new_tokens_ratio=1,
212
+ max_new_tokens_ratio=2,
213
+ num_beams=4,
214
+ repetition_penalty=1.5,
215
+ prompt_size=args.prompt_size,
216
+ )
217
+ results.append(
218
+ {
219
+ "msg_type": msg_type,
220
+ "delta": delta.item(),
221
+ "base": base.item(),
222
+ "perplexity": ModelFactory.compute_perplexity(
223
+ args.judge_model, text
224
+ ),
225
+ "msg_rate": msg_rate,
226
+ }
227
+ )
228
+ return results
229
+
230
+
231
+ def process_results(results, save_dir):
232
+ data = {
233
+ "perplexities": {
234
+ "random": {},
235
+ "readable": {},
236
+ },
237
+ "msg_rates": {
238
+ "random": {},
239
+ "readable": {},
240
+ },
241
+ }
242
+ for r in results:
243
+ msg_type = r["msg_type"]
244
+ base = r["base"]
245
+ delta = r["delta"]
246
+ msg_rate = r["msg_rate"]
247
+ perplexity = r["perplexity"]
248
+
249
+ if (base, delta) not in data["msg_rates"][msg_type]:
250
+ data["msg_rates"][msg_type][(base, delta)] = []
251
+ data["msg_rates"][msg_type][(base, delta)].append(msg_rate)
252
+
253
+ if (base, delta) not in data["perplexities"][msg_type]:
254
+ data["perplexities"][msg_type][(base, delta)] = []
255
+ data["perplexities"][msg_type][(base, delta)].append(perplexity)
256
+
257
+ bases = {
258
+ "perplexities": {
259
+ "random": [],
260
+ "readable": [],
261
+ },
262
+ "msg_rates": {
263
+ "random": [],
264
+ "readable": [],
265
+ },
266
+ }
267
+ deltas = {
268
+ "perplexities": {
269
+ "random": [],
270
+ "readable": [],
271
+ },
272
+ "msg_rates": {
273
+ "random": [],
274
+ "readable": [],
275
+ },
276
+ }
277
+ values = {
278
+ "perplexities": {
279
+ "random": [],
280
+ "readable": [],
281
+ },
282
+ "msg_rates": {
283
+ "random": [],
284
+ "readable": [],
285
+ },
286
+ }
287
+ base_set = set()
288
+ delta_set = set()
289
+ for metric in data:
290
+ for msg_type in data[metric]:
291
+ for k in data[metric][msg_type]:
292
+ s = sum(data[metric][msg_type][k])
293
+ cnt = len(data[metric][msg_type][k])
294
+ data[metric][msg_type][k] = s / cnt
295
+
296
+ bases[metric][msg_type].append(k[0])
297
+ deltas[metric][msg_type].append(k[1])
298
+ values[metric][msg_type].append(s / cnt)
299
+ base_set.add(k[0])
300
+ delta_set.add(k[1])
301
+ for metric in data:
302
+ for msg_type in data[metric]:
303
+ bases[metric][msg_type] = np.array(bases[metric][msg_type], dtype=np.int32)
304
+ deltas[metric][msg_type] = np.array(deltas[metric][msg_type], dtype=np.int32)
305
+ values[metric][msg_type] = np.array(values[metric][msg_type], dtype=np.float32)
306
+
307
+ os.makedirs(save_dir, exist_ok=True)
308
+ for metric in data:
309
+ for msg_type in data[metric]:
310
+ fig = plt.figure(dpi=300)
311
+ s = lambda x: 3.0 + x * (3 if metric == "msg_rates" else 0.1)
312
+ plt.scatter(
313
+ bases[metric][msg_type],
314
+ deltas[metric][msg_type],
315
+ s(values[metric][msg_type]),
316
+ )
317
+ plt.savefig(
318
+ os.path.join(save_dir, f"{metric}_{msg_type}_scatter.pdf"),
319
+ bbox_inches="tight",
320
+ )
321
+
322
+ os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
323
+ for metric in data:
324
+ for msg_type in data[metric]:
325
+ for base_value in base_set:
326
+ mask = bases[metric][msg_type] == base_value
327
+ fig = plt.figure(dpi=300)
328
+ s = lambda x: x / (1.0 if metric == "msg_rates" else 10.0)
329
+ plt.plot(
330
+ deltas[metric][msg_type][mask],
331
+ values[metric][msg_type][mask],
332
+ )
333
+ plt.savefig(
334
+ os.path.join(save_dir, f"delta_effect/{metric}_{msg_type}_base{base_value}.pdf"),
335
+ bbox_inches="tight",
336
+ )
337
+ os.makedirs(os.path.join(save_dir, "base_effect"), exist_ok=True)
338
+ for metric in data:
339
+ for msg_type in data[metric]:
340
+ for delta_value in delta_set:
341
+ mask = deltas[metric][msg_type] == delta_value
342
+ fig = plt.figure(dpi=300)
343
+ s = lambda x: x / (1.0 if metric == "msg_rates" else 10.0)
344
+ plt.plot(
345
+ bases[metric][msg_type][mask],
346
+ values[metric][msg_type][mask],
347
+ )
348
+ plt.savefig(
349
+ os.path.join(save_dir, f"base_effect/{metric}_{msg_type}_delta{delta_value}.pdf"),
350
+ bbox_inches="tight",
351
+ )
352
+
353
+
354
+ def main(args):
355
+ prompts = load_prompts(
356
+ args.num_prompts,
357
+ args.prompts_min_length,
358
+ args.prompts_file if not args.overwrite else None,
359
+ )
360
+
361
+ msgs_lens = []
362
+ for i in np.arange(
363
+ args.msgs_lengths[0],
364
+ args.msgs_lengths[1] + args.msgs_lengths[2],
365
+ args.msgs_lengths[2],
366
+ dtype=np.int32,
367
+ ):
368
+ for _ in range(args.msgs_per_length):
369
+ msgs_lens.append(i)
370
+
371
+ msgs = load_msgs(
372
+ msgs_lens,
373
+ args.msgs_file if not args.overwrite else None,
374
+ )
375
+
376
+ if args.msgs_file:
377
+ if not os.path.isfile(args.msgs_file) or args.overwrite:
378
+ os.makedirs(os.path.dirname(args.msgs_file), exist_ok=True)
379
+ with open(args.msgs_file, "w") as f:
380
+ json.dump(msgs, f)
381
+ print(f"Saved messages to {args.msgs_file}")
382
+ if args.prompts_file:
383
+ if not os.path.isfile(args.prompts_file) or args.overwrite:
384
+ os.makedirs(os.path.dirname(args.prompts_file), exist_ok=True)
385
+ with open(args.prompts_file, "w") as f:
386
+ json.dump(prompts, f)
387
+ print(f"Saved prompts to {args.prompts_file}")
388
+
389
+ if args.results_load_file:
390
+ with open(args.results_load_file, "r") as f:
391
+ results = json.load(f)
392
+ else:
393
+ results = get_results(args, prompts, msgs)
394
+
395
+ if args.results_save_file:
396
+ os.makedirs(os.path.dirname(args.results_save_file), exist_ok=True)
397
+ with open(args.results_save_file, "w") as f:
398
+ json.dump(results, f)
399
+ print(f"Saved results to {args.results_save_file}")
400
+
401
+ if args.figs_dir:
402
+ process_results(results, args.figs_dir)
403
+
404
+
405
+
406
+ if __name__ == "__main__":
407
+ args = create_args()
408
+ main(args)
api.py CHANGED
@@ -108,6 +108,9 @@ async def default_config():
108
  "private_key": GlobalConfig.get(
109
  "encrypt.default", "private_key"
110
  ),
 
 
 
111
  "max_new_tokens_ratio": GlobalConfig.get(
112
  "encrypt.default", "max_new_tokens_ratio"
113
  ),
 
108
  "private_key": GlobalConfig.get(
109
  "encrypt.default", "private_key"
110
  ),
111
+ "min_new_tokens_ratio": GlobalConfig.get(
112
+ "encrypt.default", "min_new_tokens_ratio"
113
+ ),
114
  "max_new_tokens_ratio": GlobalConfig.get(
115
  "encrypt.default", "max_new_tokens_ratio"
116
  ),
config.ini CHANGED
@@ -32,6 +32,7 @@ msg_base = int:2
32
  seed_scheme = str:sha_left_hash
33
  window_length = int:1
34
  private_key = int:0
 
35
  max_new_tokens_ratio = float:2.0
36
  num_beams = int:4
37
  repetition_penalty = float:1.0
 
32
  seed_scheme = str:sha_left_hash
33
  window_length = int:1
34
  private_key = int:0
35
+ min_new_tokens_ratio = float:1.0
36
  max_new_tokens_ratio = float:2.0
37
  num_beams = int:4
38
  repetition_penalty = float:1.0
model_factory.py CHANGED
@@ -63,7 +63,8 @@ class ModelFactory:
63
  @classmethod
64
  def load_model(cls, name):
65
  if name not in cls.models:
66
- cls.__load_model(name)
 
67
 
68
  if name != cls.run_model and cls.run_model is not None:
69
  cls.models[cls.run_model].to(cls.load_device)
@@ -83,3 +84,43 @@ class ModelFactory:
83
  return cls.tokenizers[name].model_max_length
84
  else:
85
  return 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  @classmethod
64
  def load_model(cls, name):
65
  if name not in cls.models:
66
+ if cls.__load_model(name) is None:
67
+ return None, None
68
 
69
  if name != cls.run_model and cls.run_model is not None:
70
  cls.models[cls.run_model].to(cls.load_device)
 
84
  return cls.tokenizers[name].model_max_length
85
  else:
86
  return 0
87
+
88
+ @classmethod
89
+ def compute_perplexity(cls, model_name, text):
90
+ # This code is copied from https://huggingface.co/docs/transformers/perplexity
91
+ model, tokenizer = cls.load_model(model_name)
92
+ if model is None or tokenizer is None:
93
+ return 0
94
+ device = model.device
95
+ encodings = tokenizer(text, return_tensors="pt").to(device)
96
+
97
+ max_length = model.config.n_positions
98
+ stride = max_length//2
99
+ seq_len = encodings.input_ids.size(1)
100
+
101
+ nlls = []
102
+ prev_end_loc = 0
103
+ for begin_loc in range(0, seq_len, stride):
104
+ end_loc = min(begin_loc + max_length, seq_len)
105
+ trg_len = end_loc - prev_end_loc # may be different from stride on last loop
106
+ input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
107
+ target_ids = input_ids.clone()
108
+ target_ids[:, :-trg_len] = -100
109
+
110
+ with torch.no_grad():
111
+ outputs = model(input_ids, labels=target_ids)
112
+
113
+ # loss is calculated using CrossEntropyLoss which averages over valid labels
114
+ # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
115
+ # to the left by 1.
116
+ neg_log_likelihood = outputs.loss
117
+
118
+ nlls.append(neg_log_likelihood)
119
+
120
+ prev_end_loc = end_loc
121
+ if end_loc == seq_len:
122
+ break
123
+
124
+ ppl = torch.exp(torch.stack(nlls).mean()).item()
125
+ return ppl
126
+
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  numpy==1.26.4
2
  tqdm==4.66.4
3
  transformers==4.41.2
 
4
  PyYAML==6.0.1
5
  scikit-learn==1.5.0
6
  torch==2.3.0
@@ -8,3 +9,4 @@ cryptography==42.0.8
8
  fastapi
9
  gradio
10
  uvicorn
 
 
1
  numpy==1.26.4
2
  tqdm==4.66.4
3
  transformers==4.41.2
4
+ datasets==2.20.0
5
  PyYAML==6.0.1
6
  scikit-learn==1.5.0
7
  torch==2.3.0
 
9
  fastapi
10
  gradio
11
  uvicorn
12
+ matplotlib==3.9.1
schemes.py CHANGED
@@ -49,6 +49,11 @@ class EncryptionBody(BaseModel):
49
  title="Private key used to compute the seed for PRF",
50
  ge=0,
51
  )
 
 
 
 
 
52
  max_new_tokens_ratio: float = Field(
53
  default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
54
  title="Max length of generated text compared to the minimum length required to hide the message",
 
49
  title="Private key used to compute the seed for PRF",
50
  ge=0,
51
  )
52
+ max_new_tokens_ratio: float = Field(
53
+ default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
54
+ title="Min length of generated text compared to the minimum length required to hide the message",
55
+ ge=1,
56
+ )
57
  max_new_tokens_ratio: float = Field(
58
  default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
59
  title="Max length of generated text compared to the minimum length required to hide the message",
stegno.py CHANGED
@@ -18,9 +18,11 @@ def generate(
18
  window_length: int = 1,
19
  salt_key: Union[int, None] = None,
20
  private_key: Union[int, None] = None,
 
21
  max_new_tokens_ratio: float = 2,
22
  num_beams: int = 4,
23
  repetition_penalty: float = 1.0,
 
24
  ):
25
  """
26
  Generate the sequence containing the hidden data.
@@ -36,7 +38,6 @@ def generate(
36
  window_length: length of window to compute the seed.
37
  salt_key: salt to add to the seed.
38
  private_key: private key used to compute the seed.
39
-
40
  """
41
  if len(start_pos_p) == 1:
42
  start_pos = start_pos_p[0]
@@ -47,9 +48,10 @@ def generate(
47
  start_pos = int(start_pos) + window_length
48
 
49
  tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
50
- prompt_size = tokenized_input.input_ids.size(1)
 
51
  logits_processor = EncryptorLogitsProcessor(
52
- prompt_ids=tokenized_input.input_ids,
53
  msg=msg,
54
  start_pos=start_pos,
55
  delta=delta,
@@ -62,14 +64,21 @@ def generate(
62
  salt_key=salt_key,
63
  private_key=private_key,
64
  )
65
- min_length = prompt_size + start_pos + logits_processor.get_message_len()
66
- max_length = prompt_size + int(
67
- start_pos + logits_processor.get_message_len() * max_new_tokens_ratio
 
 
 
 
 
 
68
  )
69
  max_length = min(max_length, tokenizer.model_max_length)
70
  min_length = min(min_length, max_length)
71
  output_tokens = model.generate(
72
- **tokenized_input,
 
73
  logits_processor=transformers.LogitsProcessorList([logits_processor]),
74
  min_length=min_length,
75
  max_length=max_length,
@@ -79,10 +88,12 @@ def generate(
79
  )
80
 
81
  output_tokens = output_tokens[:, prompt_size:]
82
- output_text = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
83
- output_tokens_post = tokenizer(output_text, return_tensors="pt", add_special_tokens=False).to(
84
- model.device
85
- )
 
 
86
  msg_rates, tokens_infos = logits_processor.validate(
87
  output_tokens_post.input_ids
88
  )
 
18
  window_length: int = 1,
19
  salt_key: Union[int, None] = None,
20
  private_key: Union[int, None] = None,
21
+ min_new_tokens_ratio: float = 1,
22
  max_new_tokens_ratio: float = 2,
23
  num_beams: int = 4,
24
  repetition_penalty: float = 1.0,
25
+ prompt_size: int = -1,
26
  ):
27
  """
28
  Generate the sequence containing the hidden data.
 
38
  window_length: length of window to compute the seed.
39
  salt_key: salt to add to the seed.
40
  private_key: private key used to compute the seed.
 
41
  """
42
  if len(start_pos_p) == 1:
43
  start_pos = start_pos_p[0]
 
48
  start_pos = int(start_pos) + window_length
49
 
50
  tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
51
+ if prompt_size == -1:
52
+ prompt_size = tokenized_input.input_ids.size(1)
53
  logits_processor = EncryptorLogitsProcessor(
54
+ prompt_ids=tokenized_input.input_ids[:prompt_size],
55
  msg=msg,
56
  start_pos=start_pos,
57
  delta=delta,
 
64
  salt_key=salt_key,
65
  private_key=private_key,
66
  )
67
+ min_length = (
68
+ prompt_size
69
+ + start_pos
70
+ + logits_processor.get_message_len() * min_new_tokens_ratio
71
+ )
72
+ max_length = (
73
+ prompt_size
74
+ + start_pos
75
+ + logits_processor.get_message_len() * max_new_tokens_ratio
76
  )
77
  max_length = min(max_length, tokenizer.model_max_length)
78
  min_length = min(min_length, max_length)
79
  output_tokens = model.generate(
80
+ input_ids=tokenized_input.input_ids[:, :prompt_size],
81
+ attention_mask=tokenized_input.attention_mask[:, :prompt_size],
82
  logits_processor=transformers.LogitsProcessorList([logits_processor]),
83
  min_length=min_length,
84
  max_length=max_length,
 
88
  )
89
 
90
  output_tokens = output_tokens[:, prompt_size:]
91
+ output_text = tokenizer.batch_decode(
92
+ output_tokens, skip_special_tokens=True
93
+ )[0]
94
+ output_tokens_post = tokenizer(
95
+ output_text, return_tensors="pt", add_special_tokens=False
96
+ ).to(model.device)
97
  msg_rates, tokens_infos = logits_processor.validate(
98
  output_tokens_post.input_ids
99
  )
utils.py CHANGED
@@ -55,3 +55,4 @@ def static_init(cls):
55
  if getattr(cls, "__static_init__", None):
56
  cls.__static_init__()
57
  return cls
 
 
55
  if getattr(cls, "__static_init__", None):
56
  cls.__static_init__()
57
  return cls
58
+