import torch
import transformers

from processors import EncryptorLogitsProcessor, DecryptorProcessor


def generate(
    tokenizer,
    model,
    prompt: str | list[str],
    msg: bytes,
    start_pos_p: list[int],
    delta: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: int | None = None,
    private_key: int | None = None,
    min_new_tokens_ratio: float = 1,
    max_new_tokens_ratio: float = 2,
    do_sample: bool = True,
    num_beams: int = 1,
    repetition_penalty: float = 1.0,
    generator: torch.Generator | None = None,
):
    """
    Generate the sequence containing the hidden data. This supports batch input/output.

    Args:
        tokenizer: tokenizer to use.
        model: generative model to use.
        prompt: input prompt, or a batch of prompts.
        msg: message to hide in the text.
        start_pos_p: start position of the hidden message; a single fixed position, or a [low, high] pair to sample from uniformly.
        delta: bias added to the scores of tokens in the valid list.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
        min_new_tokens_ratio: ratio between min generated tokens and required token length.
        max_new_tokens_ratio: ratio between max generated tokens and required token length.
        do_sample: whether to do sampling or greedy generation.
        num_beams: number of beams used in beam search.
        repetition_penalty: penalty to avoid repetitiveness.
        generator: torch generator used for sampling; mainly used to produce deterministic results.
    Returns:
        generated texts, hidden message rates, and token information
    """
    if len(start_pos_p) == 1:
        start_pos = start_pos_p[0]
    else:
        start_pos = torch.randint(
            start_pos_p[0], start_pos_p[1] + 1, (1,)
        ).item()
    start_pos = int(start_pos) + window_length
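    # Left-pad batched prompts with EOS so each continuation starts
    # immediately after its own prompt tokens.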
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    tokenized_input = tokenizer(prompt, return_tensors="pt", padding=True).to(
        model.device
    )
    prompt_size = tokenized_input.input_ids.size(1)

    logits_processor = EncryptorLogitsProcessor(
        prompt_ids=tokenized_input.input_ids,
        msg=msg,
        start_pos=start_pos,
        delta=delta,
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        tokenizer=tokenizer,
        device=model.device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )
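    # Bound the generation length around the number of tokens needed to
    # embed the whole message, scaled by the min/max ratios and clamped
    # to the tokenizer's maximum supported length.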
    min_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * min_new_tokens_ratio
    )
    max_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * max_new_tokens_ratio
    )
    max_length = min(max_length, tokenizer.model_max_length)
    min_length = min(min_length, max_length)

    output_tokens = model.generate(
        **tokenized_input,
        logits_processor=transformers.LogitsProcessorList([logits_processor]),
        min_length=min_length,
        max_length=max_length,
        do_sample=do_sample,
        num_beams=num_beams,
        repetition_penalty=float(repetition_penalty),
        pad_token_id=tokenizer.eos_token_id,
        generator=generator,
    )
    tokenizer.padding_side = "right"

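    # Strip the prompt, then re-tokenize the decoded text so validation
    # sees the token ids a receiver would recover from the raw text.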
    output_tokens = output_tokens[:, prompt_size:]
    output_texts = tokenizer.batch_decode(
        output_tokens, skip_special_tokens=True
    )
    output_tokens_post = tokenizer(
        output_texts,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    ).to(model.device)

    msg_rates, tokens_infos = logits_processor.validate(
        output_tokens_post.input_ids
    )

    return output_texts, msg_rates, tokens_infos


def decrypt(
    tokenizer,
    device: torch.device,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: int | None = None,
    private_key: int | None = None,
):
    """
    Extract the hidden data from the generated sequence.

    Args:
        tokenizer: tokenizer to use.
        device: device to run the tokenized input on.
        text: text to decode.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
    Returns:
        shifted versions of the message
    """
    tokenized_input = tokenizer(text, return_tensors="pt").to(device)

    decryptor = DecryptorProcessor(
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        device=device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )

    msg = decryptor.decrypt(tokenized_input.input_ids)

    return msg
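

# --- Minimal usage sketch (illustrative, not part of the module API) ---
# The model/tokenizer name ("gpt2"), the message, and every parameter value
# below are assumptions chosen for this example; in particular,
# "dummy_hash" is a placeholder seed_scheme name, and the valid scheme
# strings are defined in processors.py.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    texts, rates, _ = generate(
        tokenizer=tok,
        model=lm,
        prompt="Once upon a time",
        msg=b"hi",
        start_pos_p=[0],  # fixed start position
        delta=5.0,  # strength of the bias toward valid tokens
        msg_base=2,
        seed_scheme="dummy_hash",  # assumed name; see processors.py
        window_length=1,
    )
    print(texts[0], rates[0])

    # The receiver only needs the text plus the shared keying parameters.
    recovered = decrypt(
        tokenizer=tok,
        device=lm.device,
        text=texts[0],
        msg_base=2,
        seed_scheme="dummy_hash",
        window_length=1,
    )
    print(recovered)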