#!/usr/bin/env python
# A demo of linguistic steganography with the patient-Huffman algorithm.
# We use symmetric-key cryptography to encrypt and decrypt the hidden message.
#
# Reference:
# Dai FZ, Cai Z. Towards Near-imperceptible Steganographic Text. ACL 2019.

import nacl.exceptions
import nacl.secret
import nacl.utils
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import gradio as gr
import numpy as np
import torch as th

from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree

# model_name = 'gpt2-xl'
# XXX Use GPT-2-small for less compute
model_name = 'gpt2'
lm = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)

def bits_to_recover(max_plaintext_length):
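    # PyNaCl's SecretBox adds 40 bytes of overhead to every message:
    # a 24-byte nonce prepended to the ciphertext plus a 16-byte
    # authentication tag. A quick sanity check (a sketch, not run by the app):
    #   box = nacl.secret.SecretBox(nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE))
    #   assert len(box.encrypt(b'x' * 18)) == 18 + 40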
    return (max_plaintext_length + 40) * 8

def p_next_token(prefix, cache=None, allow_eos=True):
    t_prefix = th.as_tensor(prefix)
    with th.no_grad():
        if cache:
            # Incremental decoding: feed only the newest token along with the cache.
            lm_out = lm(input_ids=t_prefix[-1:], use_cache=True, past_key_values=cache)
        else:
            lm_out = lm(input_ids=t_prefix, use_cache=True)
        if allow_eos:
            # EOS is the last token in GPT-2's vocabulary.
            probs = lm_out.logits[-1].softmax(dim=-1)
        else:
            # Drop the EOS logit before normalizing so it can never be sampled.
            probs = lm_out.logits[-1, :-1].softmax(dim=-1)
    return probs.numpy(), lm_out.past_key_values
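
# Usage sketch (hypothetical prefix and token id, for illustration only):
#   inds = [tokenizer.bos_token_id] + tokenizer.encode('Hello')
#   p, kv = p_next_token(inds, allow_eos=False)  # p has vocab_size - 1 entries
#   p, kv = p_next_token(inds + [42], cache=kv, allow_eos=False)  # incremental step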

def embed_bits(coin_flips, prefix, tv_threshold=0.1, max_sequence_length=400):
    '''Use a sequence of coin flips ('0'/'1' strings, the ciphertext bits) to
    control the generation of token indices from a language model. Yields the
    growing list of token indices after each generated token, so callers can
    stream the stegotext, e.g. a sentence or paragraph.'''
    # ind = tokenizer.bos_token_id
    # prefix = [ind]

    hidden_prefix_ind = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    n_hidden_prefix_ind = len(hidden_prefix_ind)
    done_hiding = False
    p, kv = p_next_token(hidden_prefix_ind, allow_eos=done_hiding)
    n_skips = 0
    n_bits_encoded = 0
    n_tail_fill = 0
    ind = None
    prefix_inds = []
    # Terminate the generation after we generate the EOS token
    # XXX to save computation, we terminate as soon as all bits are hidden.
    while not done_hiding and n_hidden_prefix_ind + len(prefix_inds) < max_sequence_length and ind != tokenizer.eos_token_id:
        # There is still some cipher text to hide
        if coin_flips:
            # Build Huffman codes for the conditional distribution
            heap = build_min_heap(p)
            hc = huffman_tree(heap)
            # print(hc)
            # Check if the total variation is low enough
            # print(len(prefix_inds) - 1, tv_huffman(hc, p))
            # print(tv_huffman(hc, p)[0], tv_threshold)
            if tv_huffman(hc, p)[0] < tv_threshold:
                # Huffman-decode the cipher text into a token
                # Consume the cipher text until a token is generated
                decoder_state = hc
                while isinstance(decoder_state, tuple):
                    left, right = decoder_state
                    try:
                        bit = coin_flips.pop(0)
                        n_bits_encoded += 1
                    except IndexError:
                        # No more cipher text; pad with random bits. Keep the
                        # bit a string to match the '0'/'1' coin flips below
                        # (an int 0/1 would never compare equal to '0').
                        bit = str(np.random.choice(2))
                        n_tail_fill += 1
                    # '0' => left, '1' => right
                    decoder_state = left if bit == '0' else right
                # Decoder settles in a leaf node
                ind = decoder_state
                prefix_inds.append(ind)
                yield prefix_inds
                done_hiding = not bool(coin_flips)
                p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
                continue
        # Forward sample according to LM normally
        n_skips += 1 if coin_flips else 0
        ind = np.random.choice(tokenizer.vocab_size if done_hiding else tokenizer.vocab_size - 1, p=p)
        prefix_inds.append(ind)
        yield prefix_inds
        p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
    # Drop the EOS index
    print(prefix_inds)
    print(len(prefix_inds), n_skips, n_bits_encoded, n_tail_fill)
    if prefix_inds[-1] == tokenizer.eos_token_id:
        prefix_inds = prefix_inds[:-1]
    yield prefix_inds
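
# Usage sketch (hypothetical bits, for illustration): the coin flips are
# '0'/'1' strings, and each yield is the growing list of token indices.
#   for inds in embed_bits(list('01101001'), 'breaking news ', tv_threshold=0.4):
#       print(tokenizer.decode(inds))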

def recover_bits(token_inds, tv_threshold, bits_to_recover, prefix):
    remaining_bits = bits_to_recover
    hidden_prefix_inds = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    p, kv = p_next_token(hidden_prefix_inds, allow_eos=False)
    cipher_text = []
    # Terminate the generation after we have consumed all indices or
    # have extracted all bits
    while token_inds and 0 < remaining_bits:
        # Build Huffman codes for the conditional distribution
        heap = build_min_heap(p)
        hc = huffman_tree(heap)
        # Check if the total variation is low enough
        if tv_huffman(hc, p)[0] < tv_threshold:
            # We have controlled this step. Some bits are hidden.
            code = invert_code_tree(hc)
            # Look up the Huffman code for the token.
            ind = token_inds.pop(0)
            # Convert the Huffman code into bits
            # left => 0, right => 1
            cipher_text_fragment = code[ind]
            # Truncate possible trailing paddings
            cipher_text += cipher_text_fragment[:remaining_bits]
            remaining_bits -= len(cipher_text_fragment)
            yield cipher_text
            # print(remaining_bits)
            hidden_prefix_inds.append(ind)
            p, kv = p_next_token(hidden_prefix_inds, cache=kv, allow_eos=False)
        else:
            # We did not control this step. Skip.
            hidden_prefix_inds.append(token_inds.pop(0))
            p, kv = p_next_token(hidden_prefix_inds, cache=kv, allow_eos=False)
    print(cipher_text, len(cipher_text), bits_to_recover)
    yield cipher_text
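
# Round-trip sketch (hypothetical values): the receiver rebuilds the same
# Huffman trees by re-running the shared LM with the shared prefix and
# threshold, so the sender's token choices decode back into bits.
#   inds = tokenizer.encode(stegotext)
#   for bits in recover_bits(inds, 0.4, bits_to_recover(18), 'breaking news '):
#       pass  # `bits` is the ciphertext bits recovered so far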


with gr.Blocks() as demo:
    gr.Markdown('''
        # Linguistic steganography demo with ``patient-Huffman`` algorithm
        Instead of sending secrets in plaintext or in ciphertext, we can "hide the hiding" by embedding the encrypted secret in a natural-looking message.

        ## Usage for message sender
        1. Type a short message. Click Encrypt to generate the ciphertext (encrypted text).
        2. Click Hide to generate the stegotext/covertext.

        ## Usage for message receiver
        1. Copy-paste the received stegotext/covertext into the stegotext box. Click Recover to extract the hidden ciphertext.
        2. Click Decrypt to decipher the original message.
    ''')

    with gr.Accordion(
        'Secrets shared between sender and receiver',
        open=False,
    ):
        # Shared secrets and parameters.
        gr.Markdown('''
        - The proposed stegosystem is agnostic to the choice of cryptosystem. We use the symmetric-key encryption implemented in the `PyNaCl` library.
        - An encryption key is randomly generated; refresh the page to get a different one.
        - The _choice_ of language model is a shared secret. Due to computation resource constraints, we use GPT-2 as an example.
        - The communicating parties can share a prefix to further steer the stegotext so that it appears more appropriate for the channel, e.g., blog posts or social media messages. Take extra care with whitespace.
        - The imperceptibility threshold controls how much the distribution of stegotexts is allowed to deviate from the language model. A lower threshold means bits are hidden only at steps where the Huffman code closely matches the model's distribution, so more steps are skipped and the stegotext gets longer.


        Reference: Dai FZ, Cai Z. [Towards Near-imperceptible Steganographic Text](https://arxiv.org/abs/1907.06679). ACL 2019.
        ''')
        state = gr.State()
        with gr.Row():
            tb_shared_key = gr.Textbox(
                label='encryption key (hex)',
                value=lambda : nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE).hex(),
                interactive=True,
                scale=1,
                lines=3,
            )
            # dp_shared_lm = gr.Dropdown(
            #     label='language model',
            #     choices=[
            #         'GPT-2',
            #         # 'GPT-3',
            #     ],
            #     value='GPT-2',
            # )
            s_shared_imp = gr.Slider(
                label='imperceptibility threshold',
                minimum=0,
                maximum=1,
                value=0.4,
                scale=1,
            )
            s_shared_max_plaintext_len = gr.Slider(
                label='max plaintext length',
                minimum=4,
                maximum=32,
                step=1,
                value=18,
                scale=1,
            )
            with gr.Column(scale=1):
                tb_shared_prefix = gr.Textbox(
                    label='prefix',
                    value='',
                )
                gr.Examples(
                    [
                        'best dessert recipe: ',
                        'def solve(x):',
                        'breaking news ',
                        'πŸ€—πŸ”’',
                    ],
                    tb_shared_prefix,
                    cache_examples=False,
                )

    with gr.Row():
        with gr.Box():
            with gr.Column():
                # Sender
                gr.Markdown('## Sender')

                # Plain text
                tb_sender_plaintext = gr.Textbox(
                    label='plaintext',
                    value='gold in top drawer',
                )
                btn_encrypt = gr.Button('πŸ”’ Encrypt')

                # Encrypt
                # Cipher text
                tb_sender_ciphertext = gr.Textbox(
                    label='ciphertext (hex)',
                )
                btn_hide = gr.Button('🫣 Hide', interactive=False)

                # Hide
                # Cover text
                tb_sender_stegotext = gr.Textbox(
                    label='stegotext',
                )

        with gr.Box():
            with gr.Column():
                # Receiver
                gr.Markdown('## Receiver')
                # Cover text
                tb_receiver_stegotext = gr.Textbox(
                    label='stegotext',
                )
                btn_recover = gr.Button('πŸ”Ž Recover')
                # Cipher text
                tb_receiver_ciphertext = gr.Textbox(
                    label='recovered ciphertext (hex)',
                )
                btn_decrypt = gr.Button('πŸ”“ Decrypt', interactive=True)
                # Plain text
                tb_receiver_plaintext = gr.Textbox(
                    label='deciphered plaintext',
                )

    gr.Markdown('''
        ## Known issues
        1. The ciphertext recovered by the receiver might not match the original ciphertext. This is caused by a tokenization mismatch: the receiver's tokenizer may split the stegotext differently from how it was generated. This is a fundamental challenge; for now, just click Encrypt again (to use a different nonce) and go through the sender's process once more.
        2. The stegotext looks incoherent. The demo uses GPT-2 small, whose fluency is quite limited. A stronger LM will alleviate this problem, and a smaller imperceptibility threshold should also help.
    ''')

    # Link the UI to handlers
    def encrypt(saved_state, key_in_hex, plaintext, max_plaintext_length):
        shared_key = bytes.fromhex(key_in_hex)
        # print(saved_state)
        if saved_state is None:
            # Create the secret boxes if they have not been created.
            sender_box = nacl.secret.SecretBox(shared_key)
            receiver_box = nacl.secret.SecretBox(shared_key)
            saved_state = sender_box, receiver_box
        else:
            sender_box, receiver_box = saved_state
        print('Encode:', bytes(plaintext, 'utf8'), len(bytes(plaintext, 'utf8')))
        utf8_encoded_plaintext = bytes(plaintext, 'utf8')
        if len(utf8_encoded_plaintext) > max_plaintext_length:
            raise gr.Error('Plaintext is too long. Try a shorter one or increase the max plaintext length.')
        else:
            # Pad the plaintext to the maximum length.
            utf8_encoded_plaintext += bytes(' ' * (max_plaintext_length - len(utf8_encoded_plaintext)), encoding='utf8')
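            # e.g., with max length 8, b'hi' becomes b'hi      '; the trailing
            # spaces survive decryption and stay visible in the demo output.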
        ciphertext = sender_box.encrypt(utf8_encoded_plaintext)
        print('Encrypt:', plaintext, len(plaintext), ciphertext, len(ciphertext), len(ciphertext.hex()))
        return [
            saved_state,
            ciphertext.hex(),
            gr.Button.update(interactive=True),
        ]

    def decrypt(saved_state, ciphertext, key_in_hex):
        shared_key = bytes.fromhex(key_in_hex)
        if saved_state is None:
            # Create the secret boxes if they have not been created.
            sender_box = nacl.secret.SecretBox(shared_key)
            receiver_box = nacl.secret.SecretBox(shared_key)
            saved_state = sender_box, receiver_box
        else:
            sender_box, receiver_box = saved_state
        try:
            utf8_encoded_plaintext = receiver_box.decrypt(bytes.fromhex(ciphertext))
            print('Decrypt:', ciphertext, len(ciphertext), utf8_encoded_plaintext, len(utf8_encoded_plaintext))
            return [
                saved_state,
                utf8_encoded_plaintext.decode('utf8'),
            ]
        except (nacl.exceptions.CryptoError, ValueError):
            raise gr.Error('Decryption failed. Likely due to a tokenization mismatch. Try encrypting again.')

    def hide(ciphertext, tv_threshold, shared_prefix):
        # Convert hex to bits
        ba = bytes.fromhex(ciphertext)
        bits = [b for h in ba for b in f'{h:08b}']
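        # e.g., 0xa1 -> '10100001' -> ['1', '0', '1', '0', '0', '0', '0', '1']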
        print('Hide:', ciphertext, bits, len(bits))
        embed_gen = embed_bits(bits, shared_prefix, tv_threshold, lm.config.n_ctx // 2)
        for inds in embed_gen:
            yield tokenizer.decode(inds)

    def recover(stegotext, tv_threshold, max_plaintext_length, shared_prefix):
        inds = tokenizer.encode(stegotext)
        print('Recover:', stegotext, inds, len(inds))
        n_bits_to_recover = bits_to_recover(max_plaintext_length)
        recover_gen = recover_bits(inds, tv_threshold, n_bits_to_recover, shared_prefix)
        for bits in recover_gen:
            yield ''.join(bits)
        # Convert the recovered bits to a bytearray, 8 bits at a time
        ba = bytearray()
        for i in range(0, len(bits), 8):
            ba.append(int(''.join(bits[i:i+8]), 2))
        yield ba.hex()

    btn_encrypt.click(
        encrypt,
        [state, tb_shared_key, tb_sender_plaintext, s_shared_max_plaintext_len],
        [state, tb_sender_ciphertext, btn_hide],
    )
    btn_hide.click(
        hide,
        [tb_sender_ciphertext, s_shared_imp, tb_shared_prefix],
        [tb_sender_stegotext],
    )
    btn_recover.click(
        recover,
        [tb_receiver_stegotext, s_shared_imp, s_shared_max_plaintext_len, tb_shared_prefix],
        [tb_receiver_ciphertext],
    )
    btn_decrypt.click(
        decrypt,
        [state, tb_receiver_ciphertext, tb_shared_key],
        [state, tb_receiver_plaintext],
    )


if __name__ == '__main__':
    demo.queue(concurrency_count=10)
    demo.launch()
    # demo.launch(share=True)