stego-lm / app.py
dai
first release
178b66b
#!/usr/bin/env python
# A demo of linguistic steganography with the patient-Huffman algorithm.
# The secret is encrypted and decrypted with symmetric-key cryptography.
#
# Reference:
# Dai FZ, Cai Z. Towards Near-imperceptible Steganographic Text. ACL 2019.
from collections import deque

import gradio as gr
import nacl.exceptions
import nacl.secret
import nacl.utils
import numpy as np
import torch as th
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree
# Shared language model (its choice is part of the shared secret).
# GPT-2 small ('gpt2') keeps compute low; a stronger checkpoint such as
# 'gpt2-xl' should give more fluent stegotext at a higher cost.
model_name = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
lm = GPT2LMHeadModel.from_pretrained(model_name)
def bits_to_recover(max_plaintext_length):
    '''Return the number of hidden bits the receiver must extract.

    The sender pads every plaintext to ``max_plaintext_length`` bytes before
    encrypting, so the ciphertext length is fixed: the padded plaintext plus a
    constant 40-byte overhead (presumably pyNaCl SecretBox's 24-byte nonce and
    16-byte authentication tag -- confirm against the pyNaCl docs). Each
    ciphertext byte is hidden as 8 bits.
    '''
    ciphertext_length = max_plaintext_length + 40
    return 8 * ciphertext_length
def p_next_token(prefix, cache=None, allow_eos=True):
    '''Return the language model's next-token distribution after ``prefix``.

    Args:
        prefix: token indices conditioning the model (the hidden prefix plus
            any tokens generated so far). Assumed to be a flat 1-D sequence.
        cache: ``past_key_values`` returned by the previous call. When given,
            only the last token of ``prefix`` is fed to the model, so the cache
            is expected to already cover ``prefix[:-1]`` (the callers in this
            file extend the prefix by exactly one token per call).
        allow_eos: when False, the last vocabulary entry is dropped before the
            softmax, so the returned distribution has one fewer entry.

    Returns:
        ``(probs, past_key_values)``: the next-token probabilities as a numpy
        array and the updated cache to pass back as ``cache`` next time.
    '''
    input_ids = th.as_tensor(prefix)
    if cache is not None:
        # Incremental decoding: feed one token at a time and reuse the cache.
        input_ids = input_ids[-1:]
    with th.no_grad():
        lm_out = lm.forward(input_ids=input_ids, use_cache=True, past_key_values=cache)
        # Logits for the position after the last input token.
        last_logits = lm_out.logits[-1]
        if not allow_eos:
            # Assume EOS is the last token in the vocabulary.
            last_logits = last_logits[:-1]
        probs = last_logits.softmax(dim=-1)
    return probs.numpy(), lm_out.past_key_values
def embed_bits(coin_flips, prefix, tv_threshold=0.1, max_sequence_length=400):
    '''Hide a bit string in text sampled from the language model.

    The bits ("coin flips") steer generation with the patient-Huffman scheme.
    At each step a Huffman code is built for the LM's next-token distribution.
    If its total variation from that distribution is below ``tv_threshold``,
    the next token is chosen by Huffman-decoding the upcoming bits. Otherwise
    the token is sampled from the LM as usual and no bits are consumed.

    Generation stops once every bit has been embedded, when the sequence
    (hidden prefix included) reaches ``max_sequence_length`` tokens, or when
    the EOS token is produced.

    Args:
        coin_flips: sequence of '0'/'1' characters to hide. Not modified.
        prefix: shared text that conditions the LM; its tokens are not part
            of the yielded output.
        tv_threshold: maximum total variation for a step to carry bits.
        max_sequence_length: token budget, counting the hidden prefix.

    Yields:
        The list of generated token indices after every step, and once more at
        the end with a trailing EOS index (if any) removed.
    '''
    hidden_prefix_ind = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    n_hidden_prefix_ind = len(hidden_prefix_ind)
    # Work on a copy so the caller's list is not consumed; a deque makes
    # popping from the front O(1) instead of O(n) with list.pop(0).
    pending_bits = deque(coin_flips)
    done_hiding = False
    p, kv = p_next_token(hidden_prefix_ind, allow_eos=done_hiding)
    n_skips = 0
    n_bits_encoded = 0
    n_tail_fill = 0
    ind = None
    prefix_inds = []
    # Stop on EOS or when the token budget is used up. To save computation we
    # also stop as soon as all bits are hidden.
    while (
        not done_hiding
        and n_hidden_prefix_ind + len(prefix_inds) < max_sequence_length
        and ind != tokenizer.eos_token_id
    ):
        if pending_bits:
            # Build Huffman codes for the conditional distribution.
            hc = huffman_tree(build_min_heap(p))
            # Only steer this step if the Huffman-induced distribution is close
            # enough (in total variation) to the LM's distribution.
            if tv_huffman(hc, p)[0] < tv_threshold:
                # Huffman-decode the pending bits into a token: walk from the
                # root (0 => left, 1 => right) until a leaf is reached.
                decoder_state = hc
                while isinstance(decoder_state, tuple):
                    left, right = decoder_state
                    if pending_bits:
                        bit = pending_bits.popleft()
                        n_bits_encoded += 1
                    else:
                        # Ran out of bits mid-code-word: pad with random bits.
                        # Convert to str so it can match '0' below; the raw
                        # integer never compares equal to '0'.
                        bit = str(np.random.choice(2))
                        n_tail_fill += 1
                    decoder_state = left if bit == '0' else right
                # The decoder settled in a leaf: that is the token index.
                ind = decoder_state
                prefix_inds.append(ind)
                yield prefix_inds
                done_hiding = not pending_bits
                if done_hiding:
                    # All bits hidden; the next distribution is never used.
                    break
                p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, allow_eos=done_hiding)
                continue
            # Total variation too high: this step carries no bits.
            n_skips += 1
        # Forward-sample from the LM normally; len(p) matches the (possibly
        # EOS-less) distribution returned by p_next_token.
        ind = np.random.choice(len(p), p=p)
        prefix_inds.append(ind)
        yield prefix_inds
        p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, allow_eos=done_hiding)
    print(prefix_inds)
    print(len(prefix_inds), n_skips, n_bits_encoded, n_tail_fill)
    # Drop a trailing EOS index. Guard against an empty generation, e.g. when
    # the hidden prefix alone already fills ``max_sequence_length``.
    if prefix_inds and prefix_inds[-1] == tokenizer.eos_token_id:
        prefix_inds = prefix_inds[:-1]
    yield prefix_inds
def recover_bits(token_inds, tv_threshold, bits_to_recover, prefix):
    '''Extract the bits hidden by ``embed_bits`` from a token sequence.

    This replays the sender's decisions. At each step the same Huffman code is
    built for the LM's next-token distribution. If its total variation is
    below ``tv_threshold``, the step carried bits and the observed token's
    Huffman code word is appended to the output. Otherwise the token is
    skipped.

    Args:
        token_inds: token indices of the stegotext. Not modified.
        tv_threshold: must equal the threshold the sender used.
        bits_to_recover: number of bits to extract; any trailing padding beyond
            this count is dropped.
        prefix: shared text that conditioned the sender's LM.

    Yields:
        The list of recovered bits after every controlled step, and once more
        when recovery ends. The elements are whatever ``invert_code_tree``'s
        code words contain ('0'/'1' characters, as consumed by the caller).
    '''
    remaining_bits = bits_to_recover
    # Copy into a deque: O(1) pops from the front, caller's list untouched.
    pending_tokens = deque(token_inds)
    hidden_prefix_inds = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    p, kv = p_next_token(hidden_prefix_inds, allow_eos=False)
    cipher_text = []
    # Terminate after all indices are consumed or all bits are extracted.
    while pending_tokens and 0 < remaining_bits:
        ind = pending_tokens.popleft()
        # Build Huffman codes for the conditional distribution.
        hc = huffman_tree(build_min_heap(p))
        if tv_huffman(hc, p)[0] < tv_threshold:
            # The sender controlled this step, so some bits are hidden here.
            # Look up the token's Huffman code word (left => 0, right => 1).
            # NOTE(review): raises KeyError if ``ind`` has no code word (e.g.
            # an EOS index in pasted text) -- confirm this is acceptable.
            code = invert_code_tree(hc)
            cipher_text_fragment = code[ind]
            # Truncate possible trailing padding.
            cipher_text += cipher_text_fragment[:remaining_bits]
            remaining_bits -= len(cipher_text_fragment)
            yield cipher_text
        # Controlled or not, the token extends the LM context.
        hidden_prefix_inds.append(ind)
        # Skip the LM forward pass when the loop is about to terminate.
        if pending_tokens and 0 < remaining_bits:
            p, kv = p_next_token(hidden_prefix_inds, cache=kv, allow_eos=False)
    print(cipher_text, len(cipher_text), bits_to_recover)
    yield cipher_text
with gr.Blocks() as demo:
    gr.Markdown('''
# Linguistic steganography demo with ``patient-Huffman`` algorithm
Instead of sending secrets in plaintext or in ciphertext, we can "hide the hiding" by embedding the encrypted secret in a natural looking message.
## Usage for message sender
1. Type a short message. Click Encrypt to generate the ciphertext (encrypted text).
2. Click Hide to generate the stegotext/covertext.
## Usage for message receiver
1. Copy-paste the received stegotext/covertext into the stegotext box. Click Recover to extract the hidden ciphertext.
2. Click Decrypt to decipher the original message.
''')
    with gr.Accordion(
        'Secrets shared between sender and receiver',
        open=False,
    ):
        # Shared secrets and parameters.
        gr.Markdown('''
- The proposed stegosystem is agnostic to the choice of cryptosystem. We use the symmetric key encryption implemented in `pyNaCl` library.
- An encryption key is randomly generated, you can refresh the page to get a different one.
- The _choice_ of language model is a shared secret. Due to computation resource constraints, we use GPT-2 as an example.
- The communicating parties can share a prefix to further control the stegotext to appear more appropriate for the channel, e.g., blog posts, social media messages. Take extra care of the whitespaces.
- Imperceptibility threshold controls how much the distribution of stegotexts is allowed to deviate from the language model. Lower imperceptibility threshold produces longer stegotext.
Reference: Dai FZ, Cai Z. [Towards Near-imperceptible Steganographic Text](https://arxiv.org/abs/1907.06679). ACL 2019.
''')
        # Per-session cache of (key, sender SecretBox, receiver SecretBox).
        state = gr.State()
        with gr.Row():
            tb_shared_key = gr.Textbox(
                label='encryption key (hex)',
                value=lambda : nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE).hex(),
                interactive=True,
                scale=1,
                lines=3,
            )
            # The language model is fixed to GPT-2 (see ``model_name``).
            s_shared_imp = gr.Slider(
                label='imperceptibility threshold',
                minimum=0,
                maximum=1,
                value=0.4,
                scale=1,
            )
            s_shared_max_plaintext_len = gr.Slider(
                label='max plaintext length',
                minimum=4,
                maximum=32,
                step=1,
                value=18,
                scale=1,
            )
            with gr.Column(scale=1):
                tb_shared_prefix = gr.Textbox(
                    label='prefix',
                    value='',
                )
                gr.Examples(
                    [
                        'best dessert recipe: ',
                        'def solve(x):',
                        'breaking news ',
                        '🤗🔒',
                    ],
                    tb_shared_prefix,
                    cache_examples=False,
                )
    with gr.Row():
        with gr.Box():
            with gr.Column():
                # Sender: plaintext -> (Encrypt) -> ciphertext -> (Hide) -> stegotext.
                gr.Markdown('## Sender')
                tb_sender_plaintext = gr.Textbox(
                    label='plaintext',
                    value='gold in top drawer',
                )
                btn_encrypt = gr.Button('🔒 Encrypt')
                tb_sender_ciphertext = gr.Textbox(
                    label='ciphertext (hex)',
                )
                # Enabled by encrypt() once a ciphertext exists.
                btn_hide = gr.Button('🫣 Hide', interactive=False)
                tb_sender_stegotext = gr.Textbox(
                    label='stegotext',
                )
        with gr.Box():
            with gr.Column():
                # Receiver: stegotext -> (Recover) -> ciphertext -> (Decrypt) -> plaintext.
                gr.Markdown('## Receiver')
                tb_receiver_stegotext = gr.Textbox(
                    label='stegotext',
                )
                btn_recover = gr.Button('🔎 Recover')
                tb_receiver_ciphertext = gr.Textbox(
                    label='recovered ciphertext (hex)',
                )
                btn_decrypt = gr.Button('🔓 Decrypt', interactive=True)
                tb_receiver_plaintext = gr.Textbox(
                    label='deciphered plaintext',
                )
    gr.Markdown('''
## Known issues
1. The ciphertext recovered by the receiver might not match the original ciphertext. This is due to LLM tokenization mismatch. This is a fundamental challenge and for now, just Encrypt again (to use a different nonce) and go through the sender's process again.
2. The stegotext looks incoherent. GPT-2 small is used for the demo and its fluency is quite limited. A stronger LLM will alleviate this problem. A smaller imperceptibility threshold should also help.
''')

    # Link the UI to handlers
    def _secret_boxes(saved_state, key_in_hex):
        '''Return ``(saved_state, sender_box, receiver_box)`` for the current key.

        The SecretBoxes are cached in the session state together with the key
        they were built from. They are rebuilt whenever the key changes, so an
        edited key takes effect immediately instead of being ignored.

        Raises:
            gr.Error: if the key is not valid hex or has the wrong length.
        '''
        try:
            shared_key = bytes.fromhex(key_in_hex)
            if saved_state is None or saved_state[0] != shared_key:
                saved_state = (
                    shared_key,
                    nacl.secret.SecretBox(shared_key),
                    nacl.secret.SecretBox(shared_key),
                )
        except (ValueError, TypeError, nacl.exceptions.CryptoError) as err:
            raise gr.Error(
                f'Invalid encryption key: expected {nacl.secret.SecretBox.KEY_SIZE} bytes in hex.'
            ) from err
        _, sender_box, receiver_box = saved_state
        return saved_state, sender_box, receiver_box

    def encrypt(saved_state, key_in_hex, plaintext, max_plaintext_length):
        '''Pad the plaintext to a fixed length and encrypt it with the shared key.

        Returns the updated session state, the ciphertext as hex, and an update
        that enables the Hide button.
        '''
        saved_state, sender_box, _ = _secret_boxes(saved_state, key_in_hex)
        # Defensive: the slider value is used as a byte count.
        max_plaintext_length = int(max_plaintext_length)
        utf8_encoded_plaintext = plaintext.encode('utf8')
        print('Encode:', utf8_encoded_plaintext, len(utf8_encoded_plaintext))
        if len(utf8_encoded_plaintext) > max_plaintext_length:
            raise gr.Error('Plaintext is too long. Try a shorter one or increase the max plaintext length.')
        # Pad with spaces to the maximum length so every ciphertext has the same
        # size and the receiver knows how many bits to recover.
        utf8_encoded_plaintext += b' ' * (max_plaintext_length - len(utf8_encoded_plaintext))
        ciphertext = sender_box.encrypt(utf8_encoded_plaintext)
        print('Encrypt:', plaintext, len(plaintext), ciphertext, len(ciphertext), len(ciphertext.hex()))
        return [
            saved_state,
            ciphertext.hex(),
            gr.Button.update(interactive=True),
        ]

    def decrypt(saved_state, ciphertext, key_in_hex):
        '''Decrypt a hex ciphertext with the shared key.

        Returns the updated session state and the (space-padded) plaintext.

        Raises:
            gr.Error: if the ciphertext is malformed, fails authentication, or
                does not decode as UTF-8.
        '''
        saved_state, _, receiver_box = _secret_boxes(saved_state, key_in_hex)
        try:
            utf8_encoded_plaintext = receiver_box.decrypt(bytes.fromhex(ciphertext))
            plaintext = utf8_encoded_plaintext.decode('utf8')
        except (ValueError, TypeError, nacl.exceptions.CryptoError) as err:
            raise gr.Error('Decryption failed. Likely due to tokenization mismatch. Try Encrypting again.') from err
        print('Decrypt:', ciphertext, len(ciphertext), utf8_encoded_plaintext, len(utf8_encoded_plaintext))
        return [
            saved_state,
            plaintext,
        ]

    def hide(ciphertext, tv_threshold, shared_prefix):
        '''Stream the stegotext as the ciphertext bits are embedded.'''
        try:
            ciphertext_bytes = bytes.fromhex(ciphertext)
        except (ValueError, TypeError) as err:
            raise gr.Error('Ciphertext must be a hex string. Click Encrypt first.') from err
        # Expand each byte into 8 '0'/'1' characters, most significant bit first.
        bits = [bit for byte in ciphertext_bytes for bit in f'{byte:08b}']
        print('Hide:', ciphertext, bits, len(bits))
        # Use at most half of the model context for prefix plus stegotext.
        embed_gen = embed_bits(bits, shared_prefix, tv_threshold, lm.config.n_ctx // 2)
        for inds in embed_gen:
            yield tokenizer.decode(inds)

    def recover(stegotext, tv_threshold, max_plaintext_length, shared_prefix):
        '''Stream the recovered bits, then yield the recovered ciphertext as hex.'''
        inds = tokenizer.encode(stegotext)
        print('Recover:', stegotext, inds, len(inds))
        n_bits_to_recover = bits_to_recover(int(max_plaintext_length))
        # Initialized so an empty generator cannot leave ``bits`` unbound.
        bits = []
        for bits in recover_bits(inds, tv_threshold, n_bits_to_recover, shared_prefix):
            yield ''.join(bits)
        # Pack the bits into bytes, 8 bits (MSB first) per byte.
        ba = bytearray(int(''.join(bits[i:i + 8]), 2) for i in range(0, len(bits), 8))
        yield ba.hex()

    btn_encrypt.click(
        encrypt,
        [state, tb_shared_key, tb_sender_plaintext, s_shared_max_plaintext_len],
        [state, tb_sender_ciphertext, btn_hide],
    )
    btn_hide.click(
        hide,
        [tb_sender_ciphertext, s_shared_imp, tb_shared_prefix],
        [tb_sender_stegotext],
    )
    btn_recover.click(
        recover,
        [tb_receiver_stegotext, s_shared_imp, s_shared_max_plaintext_len, tb_shared_prefix],
        [tb_receiver_ciphertext],
    )
    btn_decrypt.click(
        decrypt,
        [state, tb_receiver_ciphertext, tb_shared_key],
        [state, tb_receiver_plaintext],
    )
if __name__ == '__main__':
    # Turn on the request queue (concurrency_count=10) and serve the app.
    # For a public link, launch with share=True instead.
    demo.queue(concurrency_count=10).launch()