Finish baseline
Files changed:
- main.py +143 -0
- processors.py +165 -0
- seed_schemes.py +39 -0
- stegno.py +103 -0
- utils.py +52 -0
main.py
ADDED
@@ -0,0 +1,143 @@
import os
from argparse import ArgumentParser

import torch

from stegno import generate, decrypt
from utils import load_model


def create_args():
    parser = ArgumentParser()

    # Generative model
    parser.add_argument(
        "--gen-model",
        type=str,
        default="openai-community/gpt2",
        help="Generative model (LLM) used to generate text",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device to load LLM"
    )
    # Steganography params
    parser.add_argument(
        "--gamma",
        type=float,
        default=2.0,
        help="Bias added to scores of tokens in the valid list",
    )
    parser.add_argument(
        "--msg-base",
        type=int,
        default=2,
        help="Base of the message",
    )
    parser.add_argument(
        "--seed-scheme",
        type=str,
        required=True,
        help="Scheme used to compute the seed",
    )
    parser.add_argument(
        "--window-length",
        type=int,
        default=1,
        help="Length of the window used to compute the seed",
    )
    parser.add_argument(
        "--salt-key", type=str, default="", help="Path to salt key"
    )
    parser.add_argument(
        "--private-key", type=str, default="", help="Path to private key"
    )
    # Input
    parser.add_argument(
        "--msg", type=str, required=True, help="Path to file containing the message"
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt used to generate text"
    )
    # Mode
    parser.add_argument(
        "--encrypt",
        action="store_true",
    )
    parser.add_argument(
        "--decrypt",
        action="store_true",
    )

    return parser.parse_args()


def main(args):
    args.device = torch.device(args.device)
    model, tokenizer = load_model(args.gen_model, args.device)

    if os.path.isfile(args.salt_key):
        with open(args.salt_key, "r") as f:
            salt_key = int(f.readline())
    else:
        salt_key = None

    if os.path.isfile(args.private_key):
        with open(args.private_key, "r") as f:
            private_key = int(f.readline())
    else:
        private_key = None

    if args.encrypt:
        if os.path.isfile(args.msg):
            with open(args.msg, "rb") as f:
                msg = f.read()
        else:
            raise ValueError(f"Message file {args.msg} is not a file")

        print("=" * os.get_terminal_size().columns)
        print("Encryption Parameters:")
        print(f"  GenModel: {args.gen_model}")
        print(f"  Prompt: {args.prompt}")
        print(f"  Message: {msg}")
        print(f"  Gamma: {args.gamma}")
        print(f"  Message Base: {args.msg_base}")
        print(f"  Seed Scheme: {args.seed_scheme}")
        print(f"  Window Length: {args.window_length}")
        print(f"  Salt Key: {salt_key}")
        print(f"  Private Key: {private_key}")
        print("=" * os.get_terminal_size().columns)
        text = generate(
            tokenizer=tokenizer,
            model=model,
            prompt=args.prompt,
            msg=msg,
            gamma=args.gamma,
            msg_base=args.msg_base,
            seed_scheme=args.seed_scheme,
            window_length=args.window_length,
            salt_key=salt_key,
            private_key=private_key,
        )
        args.text = text

        print(f"Text contains message:\n{text}")

    if args.decrypt:
        # Note: decryption reuses args.text produced by the encrypt branch
        # above, so this mode is meant to run together with --encrypt.
        msgs = decrypt(
            tokenizer=tokenizer,
            device=args.device,
            text=args.text,
            msg_base=args.msg_base,
            seed_scheme=args.seed_scheme,
            window_length=args.window_length,
            salt_key=salt_key,  # pass the parsed integer keys, not the file paths
            private_key=private_key,
        )
        print("Message:")
        for s, msg in enumerate(msgs):
            print(f"Shift {s}: {msg}")


if __name__ == "__main__":
    args = create_args()
    main(args)
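For reference, a round-trip invocation consistent with the flags above could look like this (`secret.txt` is a hypothetical message file; `--decrypt` reuses the text produced by `--encrypt` in the same run):

python main.py --seed-scheme dummy_hash --msg secret.txt \
    --prompt "Once upon a time" --encrypt --decrypt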
processors.py
ADDED
@@ -0,0 +1,165 @@
from typing import Union

import torch
from transformers import LogitsProcessor

from seed_schemes import seed_scheme_factory
from utils import bytes_to_base, base_to_bytes, get_values_per_byte


class BaseProcessor:
    def __init__(
        self,
        msg_base: int,
        vocab: list[int],
        device: torch.device,
        seed_scheme: str,
        window_length: int = 1,
        salt_key: Union[int, None] = None,
        private_key: Union[int, None] = None,
    ):
        """
        Args:
            msg_base: base of the message.
            vocab: vocabulary list.
            device: device to load the processor on.
            seed_scheme: scheme used to compute the seed.
            window_length: length of the window used to compute the seed.
            salt_key: salt to add to the seed.
            private_key: private key used to compute the seed.
        """
        # Universal parameters
        self.msg_base = msg_base
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.device = device

        # Seed parameters
        self.seed_fn = seed_scheme_factory.get(
            seed_scheme,
            salt_key=salt_key,
            private_key=private_key,
        )
        self.window_length = window_length

        # Initialize RNG
        self.rng = torch.Generator(device=device)

        # Compute the range of permuted positions assigned to each value in
        # the base: the vocabulary is split into msg_base chunks, with the
        # remainder r spread one-per-chunk over the first r chunks.
        self.ranges = torch.zeros(self.msg_base + 1, dtype=torch.int64)
        chunk_size = self.vocab_size // self.msg_base
        r = self.vocab_size % self.msg_base
        self.ranges[1:] = chunk_size
        self.ranges[1 : r + 1] += 1
        self.ranges = torch.cumsum(self.ranges, dim=0)

    def _seed_rng(self, input_ids: torch.Tensor):
        """
        Seed the RNG from the last `window_length` tokens of the sequence.

        Args:
            input_ids: ids in the input sequence.
        """
        seed = self.seed_fn(input_ids[-self.window_length :])
        self.rng.manual_seed(seed)

    def _get_valid_list_ids(self, input_ids: torch.Tensor, value: int):
        """
        Get ids of tokens in the valid list for the current sequence.
        """
        self._seed_rng(input_ids)
        # Keep the permutation on the same device as the generator.
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device=self.device
        )
        vocab_list = vocab_perm[self.ranges[value] : self.ranges[value + 1]]

        return vocab_list

    def _get_value(self, input_ids: torch.Tensor):
        """
        Recover the base-`msg_base` value encoded by the last token.
        """
        self._seed_rng(input_ids[:-1])
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device=self.device
        )

        cur_token = input_ids[-1]
        cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
        value = (cur_id < self.ranges).type(torch.int).argmax().item() - 1

        return value


class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
    def __init__(
        self,
        prompt_ids: torch.Tensor,
        msg: bytes,
        gamma: float,
        *args,
        **kwargs,
    ):
        """
        Args:
            msg: message to hide in the text.
            gamma: bias added to scores of tokens in the valid list.
        """
        super().__init__(*args, **kwargs)

        self.start_pos = []
        for i in range(prompt_ids.size(0)):
            self.start_pos.append(prompt_ids[i].size(0))
        self.msg = bytes_to_base(msg, self.msg_base)
        self.gamma = gamma

    def __call__(
        self, input_ids_batch: torch.LongTensor, scores_batch: torch.FloatTensor
    ):
        for i, input_ids in enumerate(input_ids_batch):
            cur_pos = input_ids.size(0)
            # Use the first prompt's start position for every row in the batch.
            msg_ptr = cur_pos - self.start_pos[0]
            # If the whole message is hidden already, leave the raw scores as-is.
            if msg_ptr >= len(self.msg):
                continue
            scores_batch[i] = self._add_bias_to_valid_list(
                input_ids, scores_batch[i], self.msg[msg_ptr]
            )

        return scores_batch

    def _add_bias_to_valid_list(
        self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
    ):
        """
        Add the bias (gamma) to the scores of tokens in the valid list.
        """
        ids = self._get_valid_list_ids(input_ids, value)
        scores[ids] = scores[ids] + self.gamma
        return scores

    def get_message_len(self):
        return len(self.msg)


class DecryptorProcessor(BaseProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def decrypt(self, input_ids_batch: torch.Tensor):
        """
        Decrypt the text sequences.
        """
        shift_msg = []
        for s in range(get_values_per_byte(self.msg_base)):
            msg = []
            bytes_msg = []
            for i, input_ids in enumerate(input_ids_batch):
                msg.append(list())
                for j in range(self.window_length + s, len(input_ids)):
                    # TODO: this could be slow. Consider reimplementing it.
                    value = self._get_value(input_ids[: j + 1])
                    msg[i].append(value)

                bytes_msg.append(base_to_bytes(msg[i], self.msg_base))
            shift_msg.append(bytes_msg)

        return shift_msg
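To make the `self.ranges` bookkeeping above concrete, here is a minimal standalone sketch with toy numbers (a 10-token vocabulary and base 3, both invented for illustration):

import torch

vocab_size, msg_base = 10, 3
ranges = torch.zeros(msg_base + 1, dtype=torch.int64)
chunk_size = vocab_size // msg_base      # 3
r = vocab_size % msg_base                # 1 leftover token
ranges[1:] = chunk_size                  # [0, 3, 3, 3]
ranges[1 : r + 1] += 1                   # [0, 4, 3, 3]
ranges = torch.cumsum(ranges, dim=0)     # [0, 4, 7, 10]
# Value 0 owns permuted positions [0, 4), value 1 owns [4, 7),
# value 2 owns [7, 10): each hidden value maps to roughly
# vocab_size / msg_base candidate tokens.

Each hidden value thus owns a contiguous slice of a freshly seeded permutation of the vocabulary, which is exactly the slice `_get_valid_list_ids` returns.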
seed_schemes.py
ADDED
@@ -0,0 +1,39 @@
import torch


class SeedSchemeFactory:
    def __init__(self):
        self.seed_scheme_dict = dict()

    def register(self, name: str, seed_scheme: type):
        """
        Register a seed scheme by name. Instances of the scheme must be
        callable and return an integer seed.

        Args:
            name: name of the seed scheme.
            seed_scheme: seed scheme class.
        """
        self.seed_scheme_dict[name] = seed_scheme

    def get(self, name: str, **kwargs):
        """
        Construct a registered seed scheme by name.

        Args:
            name: name of the seed scheme.
        """
        return self.seed_scheme_dict[name](**kwargs)


class DummyHash:
    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, input_ids: torch.Tensor):
        # Use the last token id in the window directly as the seed.
        return input_ids[-1].item()


seed_scheme_factory = SeedSchemeFactory()
seed_scheme_factory.register("dummy_hash", DummyHash)
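Since the factory only requires a callable class that accepts the `salt_key`/`private_key` keyword arguments passed by `BaseProcessor`, new schemes can be registered alongside `DummyHash`. A hedged sketch (`sum_hash` and `SumHash` are hypothetical, not part of this commit):

class SumHash:
    def __init__(self, salt_key=None, private_key=None):
        # The factory forwards salt_key/private_key as keyword arguments;
        # this toy scheme folds the salt into the seed and ignores the rest.
        self.salt_key = salt_key or 0

    def __call__(self, input_ids):
        # Sum every token id in the window, offset by the salt.
        return int(input_ids.sum().item()) + self.salt_key


seed_scheme_factory.register("sum_hash", SumHash)
seed_fn = seed_scheme_factory.get("sum_hash", salt_key=7, private_key=None)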
stegno.py
ADDED
@@ -0,0 +1,103 @@
from typing import Union

import torch
import transformers

from processors import EncryptorLogitsProcessor, DecryptorProcessor


def generate(
    tokenizer,
    model,
    prompt: str,
    msg: bytes,
    gamma: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
):
    """
    Generate a sequence containing the hidden data.

    Args:
        tokenizer: tokenizer to use.
        model: generative model to use.
        prompt: input prompt.
        msg: message to hide in the text.
        gamma: bias added to scores of tokens in the valid list.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of the window used to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
    """
    tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
    logits_processor = EncryptorLogitsProcessor(
        prompt_ids=tokenized_input.input_ids,
        msg=msg,
        gamma=gamma,
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        device=model.device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )
    output_tokens = model.generate(
        **tokenized_input,
        logits_processor=transformers.LogitsProcessorList([logits_processor]),
        min_new_tokens=logits_processor.get_message_len(),
        max_new_tokens=logits_processor.get_message_len() * 2,
        do_sample=True,
        num_beams=4,
    )
    output_text = tokenizer.batch_decode(
        output_tokens, skip_special_tokens=True
    )[0]

    return output_text


def decrypt(
    tokenizer,
    device: torch.device,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
):
    """
    Extract the hidden data from the generated sequence.

    Args:
        tokenizer: tokenizer to use.
        device: device to run decryption on.
        text: text to decode.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of the window used to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
    """
    tokenized_input = tokenizer(text, return_tensors="pt").to(device)

    decryptor = DecryptorProcessor(
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        device=device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )

    msg = decryptor.decrypt(tokenized_input.input_ids)

    return msg
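For orientation, a minimal Python round trip through these two functions might look as follows (prompt, message, and the CPU device are illustrative choices; `dummy_hash` is the one scheme this commit registers):

import torch
from utils import load_model
from stegno import generate, decrypt

model, tokenizer = load_model("openai-community/gpt2", torch.device("cpu"))
text = generate(
    tokenizer=tokenizer, model=model,
    prompt="The weather today", msg=b"hi",
    gamma=2.0, msg_base=2, seed_scheme="dummy_hash",
)
msgs = decrypt(
    tokenizer=tokenizer, device=torch.device("cpu"), text=text,
    msg_base=2, seed_scheme="dummy_hash",
)
# msgs holds one hypothesis per window shift; the gamma bias makes the
# valid-list tokens likely but not guaranteed, so recovery is probabilistic.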
utils.py
ADDED
@@ -0,0 +1,52 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_values_per_byte(base: int):
    # Number of base-`base` digits needed to represent one byte (0..255).
    values_per_byte = 1
    tmp = 255 // base
    while tmp > 0:
        values_per_byte += 1
        tmp = tmp // base
    return values_per_byte


def bytes_to_base(m: bytes, base: int) -> list[int]:
    # Convert each byte to a fixed-width, most-significant-digit-first
    # sequence of base-`base` values.
    values_per_byte = get_values_per_byte(base)
    values = []
    for b in m:
        tmp = []
        for _ in range(values_per_byte):
            tmp.append(b % base)
            b = b // base
        values.extend(tmp[::-1])

    return values


def base_to_bytes(values: list[int], base: int) -> bytes:
    # Inverse of bytes_to_base: pack fixed-width groups of values back into
    # bytes. A trailing partial group is packed as-is.
    values_per_byte = get_values_per_byte(base)

    arr = bytearray()

    i = 0
    while i < len(values):
        tmp = 0
        for _ in range(values_per_byte):
            tmp = tmp * base + values[i]
            i += 1
            if i >= len(values):
                break
        arr.append(tmp)

    return bytes(arr)


def load_model(name: str, device: torch.device):
    model = AutoModelForCausalLM.from_pretrained(name)
    model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(name)

    return model, tokenizer
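A quick sanity check of the conversion helpers (values are illustrative):

assert get_values_per_byte(2) == 8          # 8 binary digits per byte
vals = bytes_to_base(b"\x05", 2)            # [0, 0, 0, 0, 0, 1, 0, 1]
assert base_to_bytes(vals, 2) == b"\x05"    # round-trips exactly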