tnk2908 committed on
Commit
341de97
·
1 Parent(s): 0fa95f3

Finish baseline

Browse files
Files changed (5) hide show
  1. main.py +143 -0
  2. processors.py +165 -0
  3. seed_schemes.py +39 -0
  4. stegno.py +103 -0
  5. utils.py +52 -0
main.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from argparse import ArgumentParser
3
+
4
+ import torch
5
+
6
+ from stegno import generate, decrypt
7
+ from utils import load_model
8
+
9
+
10
def create_args():
    """
    Build the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed argument namespace. ``--seed-scheme``, ``--msg`` and
        ``--prompt`` are mandatory; every other option has a default.
    """
    parser = ArgumentParser()

    # Generative model
    parser.add_argument(
        "--gen-model",
        type=str,
        default="openai-community/gpt2",
        help="Generative model (LLM) used to generate text",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device to load LLM"
    )
    # Steganography params
    parser.add_argument(
        "--gamma",
        type=float,
        default=2.0,
        help="Bias added to scores of tokens in valid list",
    )
    parser.add_argument(
        "--msg-base",
        type=int,
        default=2,
        help="Base of message",
    )
    parser.add_argument(
        "--seed-scheme",
        type=str,
        required=True,
        help="Scheme used to compute the seed",
    )
    parser.add_argument(
        "--window-length",
        type=int,
        default=1,
        help="Length of window to compute the seed",
    )
    parser.add_argument(
        "--salt-key", type=str, default="", help="Path to salt key"
    )
    parser.add_argument(
        "--private-key", type=str, default="", help="Path to private key"
    )
    # Input
    parser.add_argument(
        "--msg", type=str, required=True, help="Path to file containing message"
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt used to generate text"
    )
    # Optional so existing invocations keep working. Without it the namespace
    # has no ``text`` attribute unless --encrypt ran first, which made
    # decrypt-only runs impossible.
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Text to extract the message from when --decrypt is used "
        "without --encrypt",
    )
    # Mode
    parser.add_argument(
        "--encrypt",
        action="store_true",
        help="Generate text that hides the message",
    )
    parser.add_argument(
        "--decrypt",
        action="store_true",
        help="Extract the hidden message from the text",
    )

    return parser.parse_args()
72
+
73
+
74
def _read_int_key(path):
    """Return the integer on the first line of ``path``, or None if ``path`` is not a file."""
    if not os.path.isfile(path):
        return None
    with open(path, "r") as f:
        return int(f.readline())


def main(args):
    """
    Run the encryption and/or decryption pipeline selected on the command line.

    Args:
        args: parsed command-line namespace. ``args.device`` is replaced by a
            ``torch.device`` and, after encryption, ``args.text`` holds the
            generated text.

    Raises:
        ValueError: if ``--encrypt`` is set and ``args.msg`` is not a file, or
            if ``--decrypt`` is set and there is no text to decrypt.
    """
    # Function-scope import keeps this block self-contained.
    import shutil

    args.device = torch.device(args.device)
    model, tokenizer = load_model(args.gen_model, args.device)

    salt_key = _read_int_key(args.salt_key)
    private_key = _read_int_key(args.private_key)

    if args.encrypt:
        if not os.path.isfile(args.msg):
            raise ValueError(f"Message file {args.msg} is not a file")
        with open(args.msg, "rb") as f:
            msg = f.read()

        # shutil.get_terminal_size() falls back to a default width, whereas
        # os.get_terminal_size() raises OSError when stdout is redirected.
        separator = "=" * shutil.get_terminal_size().columns
        print(separator)
        print("Encryption Parameters:")
        print(f" GenModel: {args.gen_model}")
        print(f" Prompt: {args.prompt}")
        print(f" Message: {msg}")
        print(f" Gamma: {args.gamma}")
        print(f" Message Base: {args.msg_base}")
        print(f" Seed Scheme: {args.seed_scheme}")
        print(f" Window Length: {args.window_length}")
        print(f" Salt Key: {salt_key}")
        print(f" Private Key: {private_key}")
        print(separator)
        text = generate(
            tokenizer=tokenizer,
            model=model,
            prompt=args.prompt,
            msg=msg,
            gamma=args.gamma,
            msg_base=args.msg_base,
            seed_scheme=args.seed_scheme,
            window_length=args.window_length,
            salt_key=salt_key,
            private_key=private_key,
        )
        args.text = text

        print(f"Text contains message:\n{text}")

    if args.decrypt:
        # ``text`` only exists on the namespace if encryption just ran or the
        # caller supplied it; fail with a clear message otherwise.
        text = getattr(args, "text", None)
        if text is None:
            raise ValueError(
                "No text to decrypt: run with --encrypt or provide the text"
            )
        msgs = decrypt(
            tokenizer=tokenizer,
            device=args.device,
            text=text,
            msg_base=args.msg_base,
            seed_scheme=args.seed_scheme,
            window_length=args.window_length,
            # Pass the loaded integer keys, as encryption does, rather than
            # the key file paths.
            salt_key=salt_key,
            private_key=private_key,
        )
        print("Message:")
        for shift, decoded in enumerate(msgs):
            print(f"Shift {shift}: {decoded}")
139
+
140
+
141
# Script entry point: parse the command line and hand it to main().
if __name__ == "__main__":
    main(create_args())
processors.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union
3
+
4
+ import torch
5
+ from transformers import LogitsProcessor
6
+
7
+ from seed_schemes import seed_scheme_factory
8
+ from utils import bytes_to_base, base_to_bytes, get_values_per_byte
9
+
10
+
11
class BaseProcessor(object):
    """
    Shared machinery for hiding and recovering base-``msg_base`` digits.

    For each position the vocabulary is shuffled with an RNG seeded from the
    preceding tokens, and the shuffled vocabulary is cut into ``msg_base``
    contiguous buckets, one per digit value.
    """

    def __init__(
        self,
        msg_base: int,
        vocab: list[int],
        device: torch.device,
        seed_scheme: str,
        window_length: int = 1,
        salt_key: Union[int, None] = None,
        private_key: Union[int, None] = None,
    ):
        """
        Args:
            msg_base: base of the message.
            vocab: vocabulary list.
            device: device to load processor.
            seed_scheme: scheme used to compute the seed.
            window_length: length of window to compute the seed.
            salt_key: salt to add to the seed.
            private_key: private key used to compute the seed.
        """
        # Universal parameters
        self.msg_base = msg_base
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.device = device

        # Seed parameters
        self.seed_fn = seed_scheme_factory.get(
            seed_scheme,
            salt_key=salt_key,
            private_key=private_key,
        )
        self.window_length = window_length

        # Initialize RNG
        self.rng = torch.Generator(device=device)

        # Bucket boundaries: digit v owns positions ranges[v]:ranges[v + 1]
        # of the permuted vocabulary.
        self.ranges = self._compute_ranges(
            self.vocab_size, self.msg_base, device
        )

    @staticmethod
    def _compute_ranges(vocab_size: int, msg_base: int, device=None):
        """
        Split ``vocab_size`` positions into ``msg_base`` near-equal buckets.

        Uses integer division (the bucket sizes are integers) and gives the
        first ``vocab_size % msg_base`` buckets one extra element.

        Returns:
            int64 tensor of ``msg_base + 1`` cumulative boundaries, starting
            at 0 and ending at ``vocab_size``, placed on ``device``.
        """
        chunk_size, remainder = divmod(vocab_size, msg_base)
        sizes = torch.full(
            (msg_base + 1,), chunk_size, dtype=torch.int64, device=device
        )
        sizes[0] = 0
        sizes[1 : remainder + 1] += 1
        return torch.cumsum(sizes, dim=0)

    def _seed_rng(self, input_ids: torch.Tensor):
        """
        Set the seed for the rng based on the current sequences.

        Args:
            input_ids: id in the input sequence.
        """
        seed = self.seed_fn(input_ids[-self.window_length :])
        self.rng.manual_seed(seed)

    def _get_valid_list_ids(self, input_ids: torch.Tensor, value: int):
        """
        Get ids of tokens in the valid list for the current sequences.

        Args:
            input_ids: tokens generated so far; they determine the seed.
            value: digit to encode at the next position.

        Returns:
            Tensor holding the bucket of the permuted vocabulary for ``value``.
        """
        self._seed_rng(input_ids)
        # The permutation is created on the generator's device.
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device=self.device
        )
        vocab_list = vocab_perm[self.ranges[value] : self.ranges[value + 1]]

        return vocab_list

    def _get_value(self, input_ids: torch.Tensor):
        """
        Recover the digit encoded by the last token of ``input_ids``.

        The RNG is seeded from the tokens before the last one, the vocabulary
        is permuted the same way as during encoding, and the bucket that
        contains the last token gives the digit.

        Returns:
            The digit value as a Python int.
        """
        self._seed_rng(input_ids[:-1])
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device=self.device
        )

        cur_token = input_ids[-1]
        cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
        # The first boundary greater than the position marks the end of the
        # bucket; the index before it is the digit.
        value = (cur_id < self.ranges).type(torch.int).argmax().item() - 1

        return value
89
+
90
+
91
class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
    """Logits processor that biases generation toward tokens encoding a message."""

    def __init__(
        self,
        prompt_ids: torch.Tensor,
        msg: bytes,
        gamma: float,
        *args,
        **kwargs
    ):
        """
        Args:
            prompt_ids: tokenized prompt(s); the length of each row is
                recorded as the position where the message starts.
            msg: message to hide in the text.
            gamma: bias add to scores of token in valid list.
            *args, **kwargs: forwarded to the base processor.
        """
        super().__init__(*args, **kwargs)

        self.start_pos = [row.size(0) for row in prompt_ids]
        self.msg = bytes_to_base(msg, self.msg_base)
        self.gamma = gamma

    def __call__(
        self, input_ids_batch: torch.LongTensor, scores_batch: torch.FloatTensor
    ):
        """
        Add the bias to the scores of every row that still has a digit to hide.

        Rows whose message pointer has run past the end of the message keep
        their raw scores. ``scores_batch`` is updated in place and returned.
        """
        total_digits = len(self.msg)
        for row, input_ids in enumerate(input_ids_batch):
            # The offset is always measured from the first prompt's length.
            msg_ptr = input_ids.size(0) - self.start_pos[0]
            if msg_ptr < total_digits:
                scores_batch[row] = self._add_bias_to_valid_list(
                    input_ids, scores_batch[row], self.msg[msg_ptr]
                )

        return scores_batch

    def _add_bias_to_valid_list(
        self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
    ):
        """
        Add the bias (gamma) to the scores of the tokens that encode ``value``.
        """
        valid_ids = self._get_valid_list_ids(input_ids, value)
        biased = scores[valid_ids] + self.gamma
        scores[valid_ids] = biased
        return scores

    def get_message_len(self):
        """Return the number of base-``msg_base`` digits in the message."""
        return len(self.msg)
141
+
142
+
143
class DecryptorProcessor(BaseProcessor):
    """Recovers the digits hidden in token sequences and packs them into bytes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def decrypt(self, input_ids_batch: torch.Tensor):
        """
        Decrypt the text sequences.

        The start of the message is not known, so decoding is attempted once
        per possible digit offset within a byte.

        Returns:
            A list with one entry per shift; each entry is a list holding the
            decoded bytes of every sequence in the batch.
        """
        shift_msg = []
        for shift in range(get_values_per_byte(self.msg_base)):
            first = self.window_length + shift
            decoded_rows = []
            for input_ids in input_ids_batch:
                # TODO: this could be slow. Considering reimplement this.
                digits = [
                    self._get_value(input_ids[: end + 1])
                    for end in range(first, len(input_ids))
                ]
                decoded_rows.append(base_to_bytes(digits, self.msg_base))
            shift_msg.append(decoded_rows)

        return shift_msg
seed_schemes.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Callable
2
+
3
+ import torch
4
+
5
+
6
class SeedSchemeFactory:
    """Registry that maps seed scheme names to the classes implementing them."""

    def __init__(self):
        self.seed_scheme_dict = dict()

    def register(self, name: str, seed_scheme: type):
        """
        Register the hash scheme by name. Hash scheme must be callable.

        Args:
            name: name of seed scheme.
            seed_scheme: class implementing the scheme; it is instantiated
                with keyword arguments by ``get``.
        """
        self.seed_scheme_dict[name] = seed_scheme

    def get(self, name: str, **kwargs):
        """
        Get the hash scheme by name.

        Args:
            name: name of seed scheme.
            **kwargs: forwarded to the scheme's constructor.

        Returns:
            A new instance of the registered scheme.

        Raises:
            KeyError: if no scheme is registered under ``name``; the message
                lists the registered names.
        """
        # Only the lookup is guarded, so a KeyError raised by the scheme's own
        # constructor is not mistaken for an unknown name.
        try:
            scheme_cls = self.seed_scheme_dict[name]
        except KeyError:
            available = ", ".join(sorted(self.seed_scheme_dict)) or "<none>"
            raise KeyError(
                f"Unknown seed scheme {name!r}; registered schemes: {available}"
            ) from None
        return scheme_cls(**kwargs)
28
+
29
+
30
class DummyHash:
    """Trivial seed scheme: the seed is simply the last token id of the window."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments (e.g. salt and private keys)."""

    def __call__(self, input_ids: torch.Tensor):
        """Return the last element of ``input_ids`` as a Python scalar."""
        last_token = input_ids[-1]
        return last_token.item()
36
+
37
+
38
# Module-level registry of the available seed schemes.
seed_scheme_factory = SeedSchemeFactory()
seed_scheme_factory.register(name="dummy_hash", seed_scheme=DummyHash)
stegno.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import torch
4
+ import transformers
5
+
6
+ from processors import EncryptorLogitsProcessor, DecryptorProcessor
7
+
8
+
9
def generate(
    tokenizer,
    model,
    prompt: str,
    msg: bytes,
    gamma: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
):
    """
    Generate a continuation of ``prompt`` that carries the hidden message.

    Args:
        tokenizer: tokenizer to use.
        model: generative model to use.
        prompt: input prompt.
        msg: message to hide in the text.
        gamma: bias add to scores of token in valid list.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.

    Returns:
        The first decoded sequence produced by ``model.generate``.
    """
    device = model.device
    tokenized_input = tokenizer(prompt, return_tensors="pt").to(device)

    logits_processor = EncryptorLogitsProcessor(
        prompt_ids=tokenized_input.input_ids,
        msg=msg,
        gamma=gamma,
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        device=device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )

    # Generate at least one token per message digit and at most twice as many.
    msg_len = logits_processor.get_message_len()
    processors = transformers.LogitsProcessorList([logits_processor])
    output_tokens = model.generate(
        **tokenized_input,
        logits_processor=processors,
        min_new_tokens=msg_len,
        max_new_tokens=msg_len * 2,
        do_sample=True,
        num_beams=4,
    )

    decoded = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
    return decoded[0]
63
+
64
+
65
def decrypt(
    tokenizer,
    device: torch.device,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
):
    """
    Extract the hidden data from the generated sequence.

    Args:
        tokenizer: tokenizer to use.
        device: device the tokenized text and the processor are placed on.
        text: text to decode.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.

    Returns:
        Whatever ``DecryptorProcessor.decrypt`` returns for the tokenized text.
    """
    input_ids = tokenizer(text, return_tensors="pt").to(device).input_ids

    processor_kwargs = dict(
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        device=device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )
    decryptor = DecryptorProcessor(**processor_kwargs)

    return decryptor.decrypt(input_ids)
103
+
utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
+
5
def get_values_per_byte(base: int):
    """
    Return how many base-``base`` digits are needed to represent one byte.

    Args:
        base: numeral base of the message, at least 2.

    Returns:
        The number of digits needed to write 255 in ``base``
        (8 for base 2, 2 for base 16, 1 for base 256 and above).

    Raises:
        ValueError: if ``base`` is smaller than 2. Base 1 would never
            terminate the loop below and base 0 would divide by zero.
    """
    if base < 2:
        raise ValueError(f"base must be at least 2, got {base}")

    values_per_byte = 1
    remaining = 255 // base
    while remaining > 0:
        values_per_byte += 1
        remaining //= base
    return values_per_byte
12
+
13
+
14
def bytes_to_base(m: bytes, base: int) -> list[int]:
    """
    Expand every byte of ``m`` into a fixed-width group of base-``base`` digits.

    Each byte produces ``get_values_per_byte(base)`` digits, most significant
    digit first, and the groups are concatenated in byte order.
    """
    width = get_values_per_byte(base)
    digits = []
    for byte in m:
        group = [0] * width
        # Fill from the least significant position backwards.
        for pos in range(width - 1, -1, -1):
            byte, group[pos] = divmod(byte, base)
        digits.extend(group)

    return digits
25
+
26
+
27
def base_to_bytes(values: list[int], base: int) -> bytes:
    """
    Pack base-``base`` digits back into bytes.

    Digits are consumed in groups of ``get_values_per_byte(base)``, most
    significant digit first. A shorter final group is packed from the digits
    that remain.

    When ``base`` is not a power of two, a group can encode a number above
    255 (six base-3 digits reach 728). Such groups come from digits that were
    never produced by ``bytes_to_base``, for example when decoding at the
    wrong offset. They are reduced modulo 256 so decoding noisy input cannot
    fail with ``ValueError`` from ``bytearray.append``.

    Args:
        values: digits to pack.
        base: numeral base of the digits.

    Returns:
        The packed bytes.
    """
    values_per_byte = get_values_per_byte(base)

    arr = bytearray()
    for start in range(0, len(values), values_per_byte):
        acc = 0
        for digit in values[start : start + values_per_byte]:
            acc = acc * base + digit
        arr.append(acc % 256)

    return bytes(arr)
43
+
44
+
45
def load_model(name: str, device: torch.device):
    """
    Load a pretrained causal language model and its tokenizer.

    Args:
        name: model identifier passed to ``from_pretrained``.
        device: device the model is moved to.

    Returns:
        Tuple ``(model, tokenizer)`` with the model switched to eval mode.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(name)
    causal_lm.to(device)
    causal_lm.eval()

    return causal_lm, AutoTokenizer.from_pretrained(name)