tnk2908 committed on
Commit
9da31aa
1 Parent(s): 52c67ef

Update statistical analysis

Browse files
Files changed (2) hide show
  1. analyse.py +90 -73
  2. stegno.py +1 -0
analyse.py CHANGED
@@ -3,6 +3,7 @@ import json
3
  import base64
4
  from argparse import ArgumentParser
5
 
 
6
  import numpy as np
7
  from matplotlib import pyplot as plt
8
 
@@ -33,7 +34,7 @@ def load_msgs(msg_lens: list[int], file: str | None = None):
33
  c4_en = load_dataset("allenai/c4", "en", split="validation", streaming=True)
34
  iterator = iter(c4_en)
35
 
36
- for length in msg_lens:
37
  random_msg = torch.randint(256, (length,), generator=rng)
38
  base64_msg = base64.b64encode(bytes(random_msg.tolist())).decode(
39
  "ascii"
@@ -48,7 +49,7 @@ def load_msgs(msg_lens: list[int], file: str | None = None):
48
  return msgs
49
 
50
 
51
- def load_prompts(n: int, min_length: int, file: str | None = None):
52
  prompts = None
53
  if file is not None and os.path.isfile(file):
54
  with open(file, "r") as f:
@@ -60,11 +61,13 @@ def load_prompts(n: int, min_length: int, file: str | None = None):
60
  c4_en = load_dataset("allenai/c4", "en", split="train", streaming=True)
61
  iterator = iter(c4_en)
62
 
63
- while len(prompts) < n:
64
- text = next(iterator)["text"]
65
- if len(text) < min_length:
66
- continue
67
- prompts.append(text)
 
 
68
 
69
  return prompts
70
 
@@ -80,7 +83,7 @@ def create_args():
80
  "--msgs-lengths",
81
  nargs=3,
82
  type=int,
83
- help="Range of messages' lengths. This is parsed in form: <start> <end> <step>",
84
  )
85
  parser.add_argument(
86
  "--msgs-per-length",
@@ -105,13 +108,7 @@ def create_args():
105
  "--prompt-size",
106
  type=int,
107
  default=50,
108
- help="Size of prompts",
109
- )
110
- parser.add_argument(
111
- "--prompts-min-length",
112
- type=int,
113
- default=100,
114
- help="Min length of prompts",
115
  )
116
  # Others
117
  parser.add_argument(
@@ -131,13 +128,13 @@ def create_args():
131
  "--deltas",
132
  nargs=3,
133
  type=float,
134
- help="Range of delta. This is parsed in form: <start> <end> <step>",
135
  )
136
  parser.add_argument(
137
  "--bases",
138
- nargs=3,
139
  type=int,
140
- help="Range of base. This is parsed in form: <start> <end> <step>",
141
  )
142
  parser.add_argument(
143
  "--judge-model",
@@ -177,54 +174,59 @@ def create_args():
177
  def get_results(args, prompts, msgs):
178
  model, tokenizer = ModelFactory.load_model(args.gen_model)
179
  results = []
 
 
 
 
 
 
 
180
 
181
- for prompt in prompts[:1]:
182
- for delta in np.arange(
183
- args.deltas[0], args.deltas[1] + args.deltas[2], args.deltas[2]
184
- ):
185
- for base in np.arange(
186
- args.bases[0],
187
- args.bases[1] + args.bases[2],
188
- args.bases[2],
189
- dtype=np.int32,
190
  ):
191
- for k in msgs:
192
- msg_type = k
193
- for msg in msgs[k]:
194
- msg_bytes = (
195
- msg.encode("ascii")
196
- if k == "readable"
197
- else base64.b64decode(msg)
198
- )
199
- for _ in range(args.repeat):
200
- text, msg_rate, tokens_info = generate(
201
- tokenizer=tokenizer,
202
- model=model,
203
- prompt=prompt,
204
- msg=msg_bytes,
205
- start_pos_p=[0],
206
- delta=delta,
207
- msg_base=base,
208
- seed_scheme="sha_left_hash",
209
- window_length=1,
210
- private_key=0,
211
- min_new_tokens_ratio=1,
212
- max_new_tokens_ratio=2,
213
- num_beams=4,
214
- repetition_penalty=1.5,
215
- prompt_size=args.prompt_size,
216
- )
217
- results.append(
218
- {
219
- "msg_type": msg_type,
220
- "delta": delta.item(),
221
- "base": base.item(),
222
- "perplexity": ModelFactory.compute_perplexity(
223
- args.judge_model, text
224
- ),
225
- "msg_rate": msg_rate,
226
- }
227
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  return results
229
 
230
 
@@ -300,9 +302,15 @@ def process_results(results, save_dir):
300
  delta_set.add(k[1])
301
  for metric in data:
302
  for msg_type in data[metric]:
303
- bases[metric][msg_type] = np.array(bases[metric][msg_type], dtype=np.int32)
304
- deltas[metric][msg_type] = np.array(deltas[metric][msg_type], dtype=np.int32)
305
- values[metric][msg_type] = np.array(values[metric][msg_type], dtype=np.float32)
 
 
 
 
 
 
306
 
307
  os.makedirs(save_dir, exist_ok=True)
308
  for metric in data:
@@ -318,6 +326,7 @@ def process_results(results, save_dir):
318
  os.path.join(save_dir, f"{metric}_{msg_type}_scatter.pdf"),
319
  bbox_inches="tight",
320
  )
 
321
 
322
  os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
323
  for metric in data:
@@ -331,9 +340,14 @@ def process_results(results, save_dir):
331
  values[metric][msg_type][mask],
332
  )
333
  plt.savefig(
334
- os.path.join(save_dir, f"delta_effect/{metric}_{msg_type}_base{base_value}.pdf"),
 
 
 
335
  bbox_inches="tight",
336
  )
 
 
337
  os.makedirs(os.path.join(save_dir, "base_effect"), exist_ok=True)
338
  for metric in data:
339
  for msg_type in data[metric]:
@@ -346,23 +360,27 @@ def process_results(results, save_dir):
346
  values[metric][msg_type][mask],
347
  )
348
  plt.savefig(
349
- os.path.join(save_dir, f"base_effect/{metric}_{msg_type}_delta{delta_value}.pdf"),
 
 
 
350
  bbox_inches="tight",
351
  )
 
352
 
353
 
354
  def main(args):
355
  prompts = load_prompts(
356
  args.num_prompts,
357
- args.prompts_min_length,
358
  args.prompts_file if not args.overwrite else None,
359
  )
360
 
361
  msgs_lens = []
362
- for i in np.arange(
363
  args.msgs_lengths[0],
364
- args.msgs_lengths[1] + args.msgs_lengths[2],
365
- args.msgs_lengths[2],
366
  dtype=np.int32,
367
  ):
368
  for _ in range(args.msgs_per_length):
@@ -402,7 +420,6 @@ def main(args):
402
  process_results(results, args.figs_dir)
403
 
404
 
405
-
406
  if __name__ == "__main__":
407
  args = create_args()
408
  main(args)
 
3
  import base64
4
  from argparse import ArgumentParser
5
 
6
+ from tqdm import tqdm
7
  import numpy as np
8
  from matplotlib import pyplot as plt
9
 
 
34
  c4_en = load_dataset("allenai/c4", "en", split="validation", streaming=True)
35
  iterator = iter(c4_en)
36
 
37
+ for length in tqdm(msg_lens, desc="Loading messages"):
38
  random_msg = torch.randint(256, (length,), generator=rng)
39
  base64_msg = base64.b64encode(bytes(random_msg.tolist())).decode(
40
  "ascii"
 
49
  return msgs
50
 
51
 
52
+ def load_prompts(n: int, prompt_size: int, file: str | None = None):
53
  prompts = None
54
  if file is not None and os.path.isfile(file):
55
  with open(file, "r") as f:
 
61
  c4_en = load_dataset("allenai/c4", "en", split="train", streaming=True)
62
  iterator = iter(c4_en)
63
 
64
+ with tqdm(total=n, desc="Loading prompts") as pbar:
65
+ while len(prompts) < n:
66
+ text = next(iterator)["text"]
67
+ if len(text) < prompt_size:
68
+ continue
69
+ prompts.append(text)
70
+ pbar.update()
71
 
72
  return prompts
73
 
 
83
  "--msgs-lengths",
84
  nargs=3,
85
  type=int,
86
+ help="Range of messages' lengths. This is parsed in form: <start> <end> <num>",
87
  )
88
  parser.add_argument(
89
  "--msgs-per-length",
 
108
  "--prompt-size",
109
  type=int,
110
  default=50,
111
+ help="Size of prompts (in tokens)",
 
 
 
 
 
 
112
  )
113
  # Others
114
  parser.add_argument(
 
128
  "--deltas",
129
  nargs=3,
130
  type=float,
131
+ help="Range of delta. This is parsed in form: <start> <end> <num>",
132
  )
133
  parser.add_argument(
134
  "--bases",
135
+ nargs="+",
136
  type=int,
137
+ help="Bases used in base encoding",
138
  )
139
  parser.add_argument(
140
  "--judge-model",
 
174
  def get_results(args, prompts, msgs):
175
  model, tokenizer = ModelFactory.load_model(args.gen_model)
176
  results = []
177
+ total_gen = (
178
+ len(prompts)
179
+ * int(args.deltas[2])
180
+ * len(args.bases)
181
+ * args.repeat
182
+ * sum([len(msgs[k]) for k in msgs])
183
+ )
184
 
185
+ with tqdm(total=total_gen, desc="Generating") as pbar:
186
+ for prompt in prompts:
187
+ for delta in np.linspace(
188
+ args.deltas[0], args.deltas[1], int(args.deltas[2])
 
 
 
 
 
189
  ):
190
+ for base in args.bases:
191
+ for k in msgs:
192
+ msg_type = k
193
+ for msg in msgs[k]:
194
+ msg_bytes = (
195
+ msg.encode("ascii")
196
+ if k == "readable"
197
+ else base64.b64decode(msg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  )
199
+ for _ in range(args.repeat):
200
+ text, msg_rate, tokens_info = generate(
201
+ tokenizer=tokenizer,
202
+ model=model,
203
+ prompt=prompt,
204
+ msg=msg_bytes,
205
+ start_pos_p=[0],
206
+ delta=delta,
207
+ msg_base=base,
208
+ seed_scheme="sha_left_hash",
209
+ window_length=1,
210
+ private_key=0,
211
+ min_new_tokens_ratio=1,
212
+ max_new_tokens_ratio=2,
213
+ num_beams=4,
214
+ repetition_penalty=1.5,
215
+ prompt_size=args.prompt_size,
216
+ )
217
+ results.append(
218
+ {
219
+ "msg_type": msg_type,
220
+ "delta": delta.item(),
221
+ "base": base,
222
+ "perplexity": ModelFactory.compute_perplexity(
223
+ args.judge_model, text
224
+ ),
225
+ "msg_rate": msg_rate,
226
+ "msg_len": len(msg_bytes),
227
+ }
228
+ )
229
+ pbar.update()
230
  return results
231
 
232
 
 
302
  delta_set.add(k[1])
303
  for metric in data:
304
  for msg_type in data[metric]:
305
+ bases[metric][msg_type] = np.array(
306
+ bases[metric][msg_type], dtype=np.int32
307
+ )
308
+ deltas[metric][msg_type] = np.array(
309
+ deltas[metric][msg_type], dtype=np.int32
310
+ )
311
+ values[metric][msg_type] = np.array(
312
+ values[metric][msg_type], dtype=np.float32
313
+ )
314
 
315
  os.makedirs(save_dir, exist_ok=True)
316
  for metric in data:
 
326
  os.path.join(save_dir, f"{metric}_{msg_type}_scatter.pdf"),
327
  bbox_inches="tight",
328
  )
329
+ plt.close(fig)
330
 
331
  os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
332
  for metric in data:
 
340
  values[metric][msg_type][mask],
341
  )
342
  plt.savefig(
343
+ os.path.join(
344
+ save_dir,
345
+ f"delta_effect/{metric}_{msg_type}_base{base_value}.pdf",
346
+ ),
347
  bbox_inches="tight",
348
  )
349
+ plt.close(fig)
350
+
351
  os.makedirs(os.path.join(save_dir, "base_effect"), exist_ok=True)
352
  for metric in data:
353
  for msg_type in data[metric]:
 
360
  values[metric][msg_type][mask],
361
  )
362
  plt.savefig(
363
+ os.path.join(
364
+ save_dir,
365
+ f"base_effect/{metric}_{msg_type}_delta{delta_value}.pdf",
366
+ ),
367
  bbox_inches="tight",
368
  )
369
+ plt.close(fig)
370
 
371
 
372
  def main(args):
373
  prompts = load_prompts(
374
  args.num_prompts,
375
+ args.prompt_size,
376
  args.prompts_file if not args.overwrite else None,
377
  )
378
 
379
  msgs_lens = []
380
+ for i in np.linspace(
381
  args.msgs_lengths[0],
382
+ args.msgs_lengths[1],
383
+ int(args.msgs_lengths[2]),
384
  dtype=np.int32,
385
  ):
386
  for _ in range(args.msgs_per_length):
 
420
  process_results(results, args.figs_dir)
421
 
422
 
 
423
  if __name__ == "__main__":
424
  args = create_args()
425
  main(args)
stegno.py CHANGED
@@ -85,6 +85,7 @@ def generate(
85
  do_sample=True,
86
  num_beams=num_beams,
87
  repetition_penalty=float(repetition_penalty),
 
88
  )
89
 
90
  output_tokens = output_tokens[:, prompt_size:]
 
85
  do_sample=True,
86
  num_beams=num_beams,
87
  repetition_penalty=float(repetition_penalty),
88
+ pad_token_id=tokenizer.eos_token_id,
89
  )
90
 
91
  output_tokens = output_tokens[:, prompt_size:]