tnk2908's picture
Improve UI and reduce repetitiveness of generation
ee83d59
raw
history blame
3.95 kB
from typing import Union
import torch
import transformers
from processors import EncryptorLogitsProcessor, DecryptorProcessor
def generate(
tokenizer,
model,
prompt: str,
msg: bytes,
start_pos_p: list[int],
gamma: float,
msg_base: int,
seed_scheme: str,
window_length: int = 1,
salt_key: Union[int, None] = None,
private_key: Union[int, None] = None,
max_new_tokens_ratio: float = 2,
num_beams: int = 4,
repetition_penalty: float = 1.0,
):
"""
Generate the sequence containing the hidden data.
Args:
tokenizer: tokenizer to use.
model: generative model to use.
prompt: input prompt.
msg: message to hide in the text.
gamma: bias add to scores of token in valid list.
msg_base: base of the message.
seed_scheme: scheme used to compute the seed.
window_length: length of window to compute the seed.
salt_key: salt to add to the seed.
private_key: private key used to compute the seed.
"""
if len(start_pos_p) == 1:
start_pos = start_pos_p[0]
else:
start_pos = torch.randint(
start_pos_p[0], start_pos_p[1] + 1, (1,)
).item()
start_pos = int(start_pos) + window_length
tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
prompt_size = tokenized_input.input_ids.size(1)
logits_processor = EncryptorLogitsProcessor(
prompt_ids=tokenized_input.input_ids,
msg=msg,
start_pos=start_pos,
gamma=gamma,
msg_base=msg_base,
vocab=list(tokenizer.get_vocab().values()),
tokenizer=tokenizer,
device=model.device,
seed_scheme=seed_scheme,
window_length=window_length,
salt_key=salt_key,
private_key=private_key,
)
min_length = prompt_size + start_pos + logits_processor.get_message_len()
max_length = prompt_size + int(
start_pos + logits_processor.get_message_len() * max_new_tokens_ratio
)
max_length = min(max_length, tokenizer.model_max_length)
min_length = min(min_length, max_length)
output_tokens = model.generate(
**tokenized_input,
logits_processor=transformers.LogitsProcessorList([logits_processor]),
min_length=min_length,
max_length=max_length,
do_sample=True,
num_beams=num_beams,
repetition_penalty=float(repetition_penalty),
)
output_tokens = output_tokens[:, prompt_size:]
output_text = tokenizer.batch_decode(
output_tokens, skip_special_tokens=True
)[0]
output_tokens_post = tokenizer(output_text, return_tensors="pt").to(
model.device
)
msg_rates, tokens_infos = logits_processor.validate(output_tokens_post.input_ids)
return output_text, msg_rates[0], tokens_infos[0]
def decrypt(
tokenizer,
device: torch.device,
text: str,
msg_base: int,
seed_scheme: str,
window_length: int = 1,
salt_key: Union[int, None] = None,
private_key: Union[int, None] = None,
):
"""
Extract the hidden data from the generated sequence.
Args:
tokenizer: tokenizer to use.
text: text to decode.
msg_base: base of the message.
gamma: bias added to scores of valid list.
seed_scheme: scheme used to compute the seed.
window_length: length of window to compute the seed.
salt_key: salt to add to the seed.
private_key: private key used to compute the seed.
"""
tokenized_input = tokenizer(text, return_tensors="pt").to(device)
decryptor = DecryptorProcessor(
msg_base=msg_base,
vocab=list(tokenizer.get_vocab().values()),
device=device,
seed_scheme=seed_scheme,
window_length=window_length,
salt_key=salt_key,
private_key=private_key,
)
msg = decryptor.decrypt(tokenized_input.input_ids)
return msg