import torch
import transformers

from processors import EncryptorLogitsProcessor, DecryptorProcessor


def generate(
    tokenizer,
    model,
    prompt: str | list[str],
    msg: bytes,
    start_pos_p: list[int],
    delta: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: int | None = None,
    private_key: int | None = None,
    min_new_tokens_ratio: float = 1,
    max_new_tokens_ratio: float = 2,
    do_sample: bool = True,
    num_beams: int = 1,
    repetition_penalty: float = 1.0,
    generator: torch.Generator | None = None,
):
"""
Generate the sequence containing the hidden data. This supports batch input/output.
Args:
tokenizer: tokenizer to use.
model: generative model to use.
prompt: input prompt.
msg: message to hide in the text.
start_pos_p: start position to hide message.
delta: bias add to scores of token in valid list.
msg_base: base of the message.
seed_scheme: scheme used to compute the seed.
window_length: length of window to compute the seed.
salt_key: salt to add to the seed.
private_key: private key used to compute the seed.
min_new_tokens_ratio: ratio between min generated tokens and required token length.
min_new_tokens_ratio: ratio between max generated tokens and required token length.
do_sample: whether to do sampling or greedy generation.
num_beams: number of beams used in beam search.
repetition_penalty: penalty to avoid repetitiveness.
generator: generation used to genereate. This is mainly used to produce deterministic results.
Returns:
generated texts, hidden message rates, tokens information
"""
    if len(start_pos_p) == 1:
        start_pos = start_pos_p[0]
    else:
        start_pos = torch.randint(
            start_pos_p[0], start_pos_p[1] + 1, (1,)
        ).item()
    start_pos = int(start_pos) + window_length
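
    # Left padding aligns all prompts in the batch at the right edge, so newly
    # generated tokens start at the same index for every sequence.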
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenized_input = tokenizer(prompt, return_tensors="pt", padding=True).to(
        model.device
    )
    prompt_size = tokenized_input.input_ids.size(1)
    logits_processor = EncryptorLogitsProcessor(
        prompt_ids=tokenized_input.input_ids,
        msg=msg,
        start_pos=start_pos,
        delta=delta,
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        tokenizer=tokenizer,
        device=model.device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )

    min_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * min_new_tokens_ratio
    )
    max_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * max_new_tokens_ratio
    )
    max_length = min(max_length, tokenizer.model_max_length)
    min_length = min(min_length, max_length)

    output_tokens = model.generate(
        **tokenized_input,
        logits_processor=transformers.LogitsProcessorList([logits_processor]),
        min_length=min_length,
        max_length=max_length,
        do_sample=do_sample,
        num_beams=num_beams,
        repetition_penalty=float(repetition_penalty),
        pad_token_id=tokenizer.eos_token_id,
        generator=generator,
    )
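
    # Strip the prompt, decode, then re-tokenize the decoded text (without
    # special tokens) so validation sees the same token ids a receiver would
    # obtain from the raw output text.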
    tokenizer.padding_side = "right"
    output_tokens = output_tokens[:, prompt_size:]
    output_texts = tokenizer.batch_decode(
        output_tokens, skip_special_tokens=True
    )
    output_tokens_post = tokenizer(
        output_texts,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    ).to(model.device)
    msg_rates, tokens_infos = logits_processor.validate(
        output_tokens_post.input_ids
    )

    return output_texts, msg_rates, tokens_infos


def decrypt(
    tokenizer,
    device: torch.device,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: int | None = None,
    private_key: int | None = None,
):
"""
Extract the hidden data from the generated sequence.
Args:
tokenizer: tokenizer to use.
text: text to decode.
msg_base: base of the message.
delta: bias added to scores of valid list.
seed_scheme: scheme used to compute the seed.
window_length: length of window to compute the seed.
salt_key: salt to add to the seed.
private_key: private key used to compute the seed.
Returns:
shifted versions of the message
"""
    tokenized_input = tokenizer(text, return_tensors="pt").to(device)
    decryptor = DecryptorProcessor(
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        device=device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )
    msg = decryptor.decrypt(tokenized_input.input_ids)
    return msg
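

# A minimal round-trip sketch, assuming a locally available causal LM and the
# processors module on the path. The model name, start position, delta, and
# seed_scheme value below are illustrative assumptions, not values mandated by
# this module; seed_scheme in particular must be one of the schemes the
# processors module actually implements.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # assumption: any causal LM should work here
    tok = AutoTokenizer.from_pretrained(model_name)
    lm = AutoModelForCausalLM.from_pretrained(model_name)

    texts, rates, infos = generate(
        tokenizer=tok,
        model=lm,
        prompt="The weather today is",
        msg=b"hi",
        start_pos_p=[0],
        delta=5.0,
        msg_base=2,
        seed_scheme="dummy_hash",  # hypothetical scheme name
        window_length=1,
    )
    print(texts[0], rates[0])

    # Decryption only needs the tokenizer and the shared seed parameters,
    # which must match the values used at generation time.
    recovered = decrypt(
        tokenizer=tok,
        device=lm.device,
        text=texts[0],
        msg_base=2,
        seed_scheme="dummy_hash",
        window_length=1,
    )
    print(recovered)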