#!/usr/bin/env python
# A demo of linguistic steganography with the patient-Huffman algorithm.
# We use symmetric key cryptography to en/decrypt.
#
# Reference:
# Dai FZ, Cai Z. Towards Near-imperceptible Steganographic Text. ACL 2019.

import nacl.exceptions
import nacl.secret
import nacl.utils
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import gradio as gr
import numpy as np
import torch as th

from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree


# model_name = 'gpt2-xl'
# XXX Use GPT-2-small for less compute
model_name = 'gpt2'
lm = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)


def bits_to_recover(max_plaintext_length):
    '''Return the number of ciphertext bits the receiver has to extract.

    encrypt() pads every plaintext to exactly `max_plaintext_length` bytes,
    and SecretBox.encrypt adds a 24-byte nonce plus a 16-byte MAC, hence the
    constant 40 bytes of overhead.
    '''
    # int() in case the slider value arrives as a float.
    return (int(max_plaintext_length) + 40) * 8


def p_next_token(prefix, cache=None, allow_eos=True):
    '''Return the LM next-token distribution after `prefix` and the new KV cache.

    Args:
        prefix: all token indices so far (BOS + shared prefix + generated).
        cache: `past_key_values` from a previous call on `prefix[:-1]`. When
            given, only the last token is fed (incremental decoding).
        allow_eos: if False, the last vocabulary entry (assumed to be EOS) is
            dropped before the softmax, so the distribution is one shorter.

    Returns:
        (probabilities as a 1-D numpy array, past_key_values)
    '''
    t_prefix = th.as_tensor(prefix)
    with th.no_grad():
        if cache is not None:
            # Incremental decoding. Input one token at a time with cache.
            lm_out = lm(input_ids=t_prefix[-1:], use_cache=True, past_key_values=cache)
        else:
            lm_out = lm(input_ids=t_prefix, use_cache=True)
        # NOTE(review): unbatched 1-D input_ids give logits indexed as
        # (seq, vocab) here; confirm this still holds on transformers upgrades.
        # Assume EOS is the last token in the vocabulary.
        logits = lm_out.logits[-1] if allow_eos else lm_out.logits[-1, :-1]
        probs = logits.softmax(dim=-1)
    return probs.numpy(), lm_out.past_key_values


def embed_bits(coin_flips, prefix, tv_threshold=0.1, max_sequence_length=400):
    '''Hide a sequence of bits in text generated by the language model.

    We use a sequence of coin flips to control the generation of token
    indices from a language model. At every step a Huffman code is built for
    the next-token distribution. If the total variation between the Huffman
    distribution and the LM distribution is below `tv_threshold`, the token
    is chosen by Huffman-decoding bits from `coin_flips`. Otherwise the step
    is skipped ("patient") and the token is sampled from the LM as usual.

    Args:
        coin_flips: list of '0'/'1' strings. Consumed in place (popped from
            the front). After the generator is exhausted it holds exactly the
            bits that could NOT be hidden within `max_sequence_length`.
        prefix: text shared by sender and receiver that conditions the LM.
        tv_threshold: maximum total variation allowed for a controlled step.
        max_sequence_length: cap on (BOS + prefix + generated) tokens.

    Yields:
        The growing list of generated token indices (prefix excluded) after
        every step, and once more at the end.
    '''
    hidden_prefix_ind = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    n_hidden_prefix_ind = len(hidden_prefix_ind)
    # Nothing to hide => generate nothing.
    done_hiding = not coin_flips
    # EOS is disallowed while hiding so the text cannot end early.
    p, kv = p_next_token(hidden_prefix_ind, allow_eos=done_hiding)
    n_skips = 0
    n_bits_encoded = 0
    n_tail_fill = 0
    ind = None
    prefix_inds = []
    # Terminate the generation after we generate the EOS token.
    # XXX to save computation, we terminate as soon as all bits are hidden.
    while (
        not done_hiding
        and n_hidden_prefix_ind + len(prefix_inds) < max_sequence_length
        and ind != tokenizer.eos_token_id
    ):
        # There is still some cipher text to hide (done_hiding is False).
        # Build Huffman codes for the conditional distribution.
        hc = huffman_tree(build_min_heap(p))
        if tv_huffman(hc, p)[0] < tv_threshold:
            # Huffman-decode the cipher text into a token: consume bits
            # until the decoder reaches a leaf.
            decoder_state = hc
            while type(decoder_state) is tuple:
                left, right = decoder_state
                if coin_flips:
                    bit = coin_flips.pop(0)
                    n_bits_encoded += 1
                else:
                    # No more cipher text. Pad with uniformly random bits.
                    # Must be a '0'/'1' string like the cipher bits: an int
                    # never equals '0' and would always steer right.
                    bit = str(np.random.randint(2))
                    n_tail_fill += 1
                # '0' => left, '1' => right
                decoder_state = left if bit == '0' else right
            # Decoder settles in a leaf node.
            ind = decoder_state
            done_hiding = not coin_flips
        else:
            # Forward sample according to LM normally (p excludes EOS here).
            n_skips += 1
            ind = np.random.choice(len(p), p=p)
        prefix_inds.append(ind)
        yield prefix_inds
        if not done_hiding:
            # Only pay for another LM forward pass if the loop may continue.
            p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
    print(prefix_inds)
    print(len(prefix_inds), n_skips, n_bits_encoded, n_tail_fill)
    # Drop the EOS index.
    if prefix_inds and prefix_inds[-1] == tokenizer.eos_token_id:
        prefix_inds = prefix_inds[:-1]
    yield prefix_inds


def recover_bits(token_inds, tv_threshold, bits_to_recover, prefix):
    '''Extract the bits hidden by embed_bits() from stegotext token indices.

    Replays the sender's procedure. For every token the LM distribution is
    rebuilt from the shared prefix and the tokens seen so far. At steps whose
    Huffman code passes the `tv_threshold` test, the sender was in control
    and the token's Huffman code word is the hidden bits. Other tokens were
    sampled normally and carry no bits.

    Args:
        token_inds: stegotext token indices. Consumed in place (popped from
            the front).
        tv_threshold: must equal the sender's threshold.
        bits_to_recover: number of bits to extract. Anything past it is the
            sender's random tail padding and is discarded.
        prefix: the shared prefix text; must equal the sender's.

    Yields:
        The list of recovered bits after each controlled step, and once more
        at the end.
    '''
    # int() so the slice below works even if a float is passed in.
    remaining_bits = int(bits_to_recover)
    hidden_prefix_inds = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    # The sender disallows EOS while hiding, so the receiver must too.
    p, kv = p_next_token(hidden_prefix_inds, allow_eos=False)
    cipher_text = []
    # Terminate after we have consumed all indices or extracted all bits.
    while token_inds and remaining_bits > 0:
        # Build Huffman codes for the conditional distribution.
        hc = huffman_tree(build_min_heap(p))
        ind = token_inds.pop(0)
        if tv_huffman(hc, p)[0] < tv_threshold:
            # We have controlled this step. Some bits are hidden.
            # Look up the Huffman code word for the token (left => 0, right => 1).
            cipher_text_fragment = invert_code_tree(hc)[ind]
            # Truncate possible trailing padding.
            cipher_text += cipher_text_fragment[:remaining_bits]
            remaining_bits -= len(cipher_text_fragment)
            yield cipher_text
        # Otherwise we did not control this step: the token carries no bits.
        hidden_prefix_inds.append(ind)
        if token_inds and remaining_bits > 0:
            # Only pay for another LM forward pass if the loop continues.
            p, kv = p_next_token(hidden_prefix_inds, cache=kv, allow_eos=False)
    print(cipher_text, len(cipher_text), bits_to_recover)
    yield cipher_text


with gr.Blocks() as demo:
    gr.Markdown('''
    # Linguistic steganography demo with ``patient-Huffman`` algorithm
    Instead of sending secrets in plaintext or in ciphertext, we can "hide the hiding" by embedding the encrypted secret in a natural looking message.

    ## Usage for message sender
    1. Type a short message. Click Encrypt to generate the ciphertext (encrypted text).
    2. Click Hide to generate the stegotext/covertext.

    ## Usage for message receiver
    1. Copy-paste the received stegotext/covertext into the stegotext box. Click Recover to extract the hidden ciphertext.
    2. Click Decrypt to decipher the original message.
    ''')

    with gr.Accordion(
        'Secrets shared between sender and receiver',
        open=False,
    ):
        # Shared secrets and parameters.
        gr.Markdown('''
        - The proposed stegosystem is agnostic to the choice of cryptosystem. We use the symmetric key encryption implemented in `pyNaCl` library.
        - An encryption key is randomly generated, you can refresh the page to get a different one.
        - The _choice_ of language model is a shared secret. Due to computation resource constraints, we use GPT-2 as an example.
        - The communicating parties can share a prefix to further control the stegotext to appear more appropriate for the channel, e.g., blog posts, social media messages. Take extra care of the whitespaces.
        - Imperceptibility threshold controls how much the distribution of stegotexts is allowed to deviate from the language model. Lower imperceptibility threshold produces longer stegotext.

        Reference: Dai FZ, Cai Z. [Towards Near-imperceptible Steganographic Text](https://arxiv.org/abs/1907.06679). ACL 2019.
        ''')
        state = gr.State()
        with gr.Row():
            tb_shared_key = gr.Textbox(
                label='encryption key (hex)',
                # Callable default: a fresh key on every page load.
                value=lambda: nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE).hex(),
                interactive=True,
                scale=1,
                lines=3,
            )
            s_shared_imp = gr.Slider(
                label='imperceptibility threshold',
                minimum=0,
                maximum=1,
                value=0.4,
                scale=1,
            )
            s_shared_max_plaintext_len = gr.Slider(
                label='max plaintext length',
                minimum=4,
                maximum=32,
                step=1,
                value=18,
                scale=1,
            )
            with gr.Column(scale=1):
                tb_shared_prefix = gr.Textbox(
                    label='prefix',
                    value='',
                )
                gr.Examples(
                    [
                        'best dessert recipe: ',
                        'def solve(x):',
                        'breaking news ',
                        '🤗🔒',
                    ],
                    tb_shared_prefix,
                    cache_examples=False,
                )

    with gr.Row():
        with gr.Box():
            with gr.Column():
                # Sender
                gr.Markdown('## Sender')

                # Plain text
                tb_sender_plaintext = gr.Textbox(
                    label='plaintext',
                    value='gold in top drawer',
                )
                btn_encrypt = gr.Button('🔒 Encrypt')

                # Encrypt
                # Cipher text
                tb_sender_ciphertext = gr.Textbox(
                    label='ciphertext (hex)',
                )
                btn_hide = gr.Button('🫣 Hide', interactive=False)

                # Hide
                # Cover text
                tb_sender_stegotext = gr.Textbox(
                    label='stegotext',
                )

        with gr.Box():
            with gr.Column():
                # Receiver
                gr.Markdown('## Receiver')

                # Cover text
                tb_receiver_stegotext = gr.Textbox(
                    label='stegotext',
                )
                btn_recover = gr.Button('🔎 Recover')

                # Cipher text
                tb_receiver_ciphertext = gr.Textbox(
                    label='recovered ciphertext (hex)',
                )
                btn_decrypt = gr.Button('🔓 Decrypt', interactive=True)

                # Plain text
                tb_receiver_plaintext = gr.Textbox(
                    label='deciphered plaintext',
                )

    gr.Markdown('''
    ## Known issues
    1. The ciphertext recovered by the receiver might not match the original ciphertext. This is due to LLM tokenization mismatch. This is a fundamental challenge and for now, just Encrypt again (to use a different nonce) and go through the sender's process again.
    2. The stegotext looks incoherent. GPT-2 small is used for the demo and its fluency is quite limited. A stronger LLM will alleviate this problem. A smaller imperceptibility threshold should also help.
    ''')

    # Link the UI to handlers

    def _make_boxes(key_in_hex):
        '''Build the (sender_box, receiver_box) SecretBox pair from a hex key.

        Raises:
            gr.Error: if the key is not valid hex or has the wrong length.
        '''
        try:
            shared_key = bytes.fromhex(key_in_hex)
            sender_box = nacl.secret.SecretBox(shared_key)
            receiver_box = nacl.secret.SecretBox(shared_key)
        except (ValueError, TypeError) as err:
            raise gr.Error(
                f'Invalid encryption key. Expected {nacl.secret.SecretBox.KEY_SIZE} bytes in hex.'
            ) from err
        return sender_box, receiver_box

    def encrypt(saved_state, key_in_hex, plaintext, max_plaintext_length):
        '''Pad `plaintext` to the max length and encrypt it with the shared key.

        The boxes are rebuilt from the *current* key on every call (the
        incoming `saved_state` is ignored), so edits to the key take effect.

        Returns:
            [new state, ciphertext as hex, update enabling the Hide button]
        '''
        saved_state = _make_boxes(key_in_hex)
        sender_box, _ = saved_state
        max_plaintext_length = int(max_plaintext_length)
        utf8_encoded_plaintext = bytes(plaintext, 'utf8')
        print('Encode:', utf8_encoded_plaintext, len(utf8_encoded_plaintext))
        if len(utf8_encoded_plaintext) > max_plaintext_length:
            raise gr.Error('Plaintext is too long. Try a shorter one or increase the max plaintext length.')
        # Pad the plaintext to the maximum length so that the ciphertext size,
        # and hence bits_to_recover(), depends only on the shared max length.
        utf8_encoded_plaintext += b' ' * (max_plaintext_length - len(utf8_encoded_plaintext))
        ciphertext = sender_box.encrypt(utf8_encoded_plaintext)
        print('Encrypt:', plaintext, len(plaintext), ciphertext, len(ciphertext), len(ciphertext.hex()))
        return [
            saved_state,
            ciphertext.hex(),
            gr.Button.update(interactive=True),
        ]

    def decrypt(saved_state, ciphertext, key_in_hex):
        '''Decrypt a hex ciphertext with the shared key.

        The boxes are rebuilt from the *current* key on every call (the
        incoming `saved_state` is ignored), so edits to the key take effect.

        Returns:
            [new state, deciphered plaintext]

        Raises:
            gr.Error: on a bad key, non-hex input, failed authentication or
                non-UTF-8 plaintext.
        '''
        saved_state = _make_boxes(key_in_hex)
        _, receiver_box = saved_state
        try:
            utf8_encoded_plaintext = receiver_box.decrypt(bytes.fromhex(ciphertext))
            plaintext = utf8_encoded_plaintext.decode('utf8')
        except (nacl.exceptions.CryptoError, ValueError, TypeError) as err:
            # ValueError also covers bad hex and UnicodeDecodeError.
            raise gr.Error('Decryption failed. Likely due to tokenization mismatch. Try Encrypting again.') from err
        print('Decrypt:', ciphertext, len(ciphertext), utf8_encoded_plaintext, len(utf8_encoded_plaintext))
        return [
            saved_state,
            plaintext,
        ]

    def hide(ciphertext, tv_threshold, shared_prefix):
        '''Stream the stegotext that hides the hex `ciphertext`.

        Raises:
            gr.Error: if the ciphertext is not non-empty hex, or if the length
                limit was reached before every bit was hidden.
        '''
        # Convert hex to bits
        try:
            ba = bytes.fromhex(ciphertext)
        except ValueError as err:
            raise gr.Error('Ciphertext must be a hex string. Click Encrypt first.') from err
        bits = [b for h in ba for b in f'{h:08b}']
        if not bits:
            raise gr.Error('Nothing to hide. Click Encrypt first.')
        print('Hide:', ciphertext, bits, len(bits))
        embed_gen = embed_bits(bits, shared_prefix, tv_threshold, lm.config.n_ctx // 2)
        for inds in embed_gen:
            yield tokenizer.decode(inds)
        # embed_bits consumes `bits` in place; leftovers were never hidden.
        if bits:
            raise gr.Error(
                'Reached the maximum stegotext length before the whole ciphertext was hidden. '
                'Try a higher imperceptibility threshold or a shorter max plaintext length.'
            )

    def recover(stegotext, tv_threshold, max_plaintext_length, shared_prefix):
        '''Extract the hidden ciphertext from `stegotext`.

        Streams the recovered bits as a '0'/'1' string while working, then
        yields the final ciphertext as hex.
        '''
        inds = tokenizer.encode(stegotext)
        print('Recover:', stegotext, inds, len(inds))
        n_bits_to_recover = bits_to_recover(max_plaintext_length)
        bits = []
        for bits in recover_bits(inds, tv_threshold, n_bits_to_recover, shared_prefix):
            yield ''.join(bits)
        # Convert bits to bytes. An incomplete trailing byte (only possible
        # when the stegotext was too short) is dropped rather than guessed.
        n_whole_bits = len(bits) - len(bits) % 8
        ba = bytes(int(''.join(bits[i:i + 8]), 2) for i in range(0, n_whole_bits, 8))
        yield ba.hex()

    btn_encrypt.click(
        encrypt,
        [state, tb_shared_key, tb_sender_plaintext, s_shared_max_plaintext_len],
        [state, tb_sender_ciphertext, btn_hide],
    )
    btn_hide.click(
        hide,
        [tb_sender_ciphertext, s_shared_imp, tb_shared_prefix],
        [tb_sender_stegotext],
    )
    btn_recover.click(
        recover,
        [tb_receiver_stegotext, s_shared_imp, s_shared_max_plaintext_len, tb_shared_prefix],
        [tb_receiver_ciphertext],
    )
    btn_decrypt.click(
        decrypt,
        [state, tb_receiver_ciphertext, tb_shared_key],
        [state, tb_receiver_plaintext],
    )


if __name__ == '__main__':
    demo.queue(concurrency_count=10)
    demo.launch()
    # For a public link, use: demo.launch(share=True)