tnk2908's picture
Add parameters to RestAPI
c231729
raw
history blame
5.14 kB
from typing import Union
import torch
import transformers
from processors import EncryptorLogitsProcessor, DecryptorProcessor
def generate(
    tokenizer,
    model,
    prompt: str | list[str],
    msg: bytes,
    start_pos_p: list[int],
    delta: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
    min_new_tokens_ratio: float = 1,
    max_new_tokens_ratio: float = 2,
    do_sample: bool = True,
    num_beams: int = 1,
    repetition_penalty: float = 1.0,
    generator: torch.Generator | None = None,
):
    """
    Generate the sequence containing the hidden data. This supports batch input/output.

    Side effects on ``tokenizer``: its pad token is set to its EOS token, and
    its padding side is "right" when this function exits (normally or with
    an exception).

    Args:
        tokenizer: tokenizer to use.
        model: generative model to use.
        prompt: input prompt.
        msg: message to hide in the text.
        start_pos_p: start position to hide message. A single-element list
            fixes the position; otherwise the position is drawn uniformly
            from ``start_pos_p[0]`` to ``start_pos_p[1]`` inclusive.
            ``window_length`` is added to the chosen value.
        delta: bias add to scores of token in valid list.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
        min_new_tokens_ratio: ratio between min generated tokens and required token length.
        max_new_tokens_ratio: ratio between max generated tokens and required token length.
        do_sample: whether to do sampling or greedy generation.
        num_beams: number of beams used in beam search.
        repetition_penalty: penalty to avoid repetitiveness.
        generator: generator forwarded to ``model.generate``. This is mainly
            used to produce deterministic results.

    Returns:
        generated texts, hidden message rates, tokens information
    """
    if len(start_pos_p) == 1:
        start_pos = start_pos_p[0]
    else:
        # NOTE: this draw uses torch's global RNG, not `generator`; seed the
        # global RNG as well if the start position must be reproducible.
        # `high` is exclusive in torch.randint, hence the + 1.
        start_pos = torch.randint(
            start_pos_p[0], start_pos_p[1] + 1, (1,)
        ).item()
    start_pos = int(start_pos) + window_length

    tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps every prompt flush against the generated tokens in
    # a batch, so the prompt can be cut off with a single slice below.
    tokenizer.padding_side = "left"
    try:
        tokenized_input = tokenizer(
            prompt, return_tensors="pt", padding=True
        ).to(model.device)
        prompt_size = tokenized_input.input_ids.size(1)

        logits_processor = EncryptorLogitsProcessor(
            prompt_ids=tokenized_input.input_ids,
            msg=msg,
            start_pos=start_pos,
            delta=delta,
            msg_base=msg_base,
            vocab=list(tokenizer.get_vocab().values()),
            tokenizer=tokenizer,
            device=model.device,
            seed_scheme=seed_scheme,
            window_length=window_length,
            salt_key=salt_key,
            private_key=private_key,
        )

        # Length bounds are total sequence lengths: prompt, the offset before
        # the message starts, and the message length scaled by the ratios.
        min_length = int(
            prompt_size
            + start_pos
            + logits_processor.get_message_len() * min_new_tokens_ratio
        )
        max_length = int(
            prompt_size
            + start_pos
            + logits_processor.get_message_len() * max_new_tokens_ratio
        )
        # Cap at the tokenizer's declared maximum and keep min <= max.
        max_length = min(max_length, tokenizer.model_max_length)
        min_length = min(min_length, max_length)

        output_tokens = model.generate(
            **tokenized_input,
            logits_processor=transformers.LogitsProcessorList([logits_processor]),
            min_length=min_length,
            max_length=max_length,
            do_sample=do_sample,
            num_beams=num_beams,
            repetition_penalty=float(repetition_penalty),
            pad_token_id=tokenizer.eos_token_id,
            generator=generator,
        )
    finally:
        # Always leave the shared tokenizer in right-padding mode, even when
        # tokenization or generation fails part-way through.
        tokenizer.padding_side = "right"

    # Drop the (left-padded) prompt so only newly generated tokens remain.
    output_tokens = output_tokens[:, prompt_size:]
    output_texts = tokenizer.batch_decode(
        output_tokens, skip_special_tokens=True
    )
    # Re-tokenize the decoded text so validation runs on the ids obtained
    # from the text itself rather than on the raw generated ids.
    output_tokens_post = tokenizer(
        output_texts,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    ).to(model.device)
    msg_rates, tokens_infos = logits_processor.validate(
        output_tokens_post.input_ids
    )

    return output_texts, msg_rates, tokens_infos
def decrypt(
    tokenizer,
    device: torch.device,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
):
    """
    Extract the hidden data from the generated sequence.

    Args:
        tokenizer: tokenizer to use.
        device: device the tokenized text is moved to and that is handed to
            the decryptor.
        text: text to decode.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.

    Returns:
        shifted versions of the message
    """
    encoded = tokenizer(text, return_tensors="pt").to(device)

    # Collect the decryptor settings in one place before building it.
    decryptor_config = {
        "msg_base": msg_base,
        "vocab": list(tokenizer.get_vocab().values()),
        "device": device,
        "seed_scheme": seed_scheme,
        "window_length": window_length,
        "salt_key": salt_key,
        "private_key": private_key,
    }
    processor = DecryptorProcessor(**decryptor_config)

    return processor.decrypt(encoded.input_ids)