|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import nacl.secret |
|
import nacl.utils |
|
from transformers import GPT2TokenizerFast, GPT2LMHeadModel |
|
import gradio as gr |
|
import numpy as np |
|
import torch as th |
|
|
|
from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree |
|
|
|
|
|
|
|
# Checkpoint name shared by the tokenizer and the language model; both are
# loaded once at import time and used as module-level globals below.
model_name = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
lm = GPT2LMHeadModel.from_pretrained(model_name)
|
|
|
def bits_to_recover(max_plaintext_length, overhead_bytes=40):
    '''Return the number of ciphertext bits the receiver must extract.

    Args:
        max_plaintext_length: padded plaintext length in bytes.
        overhead_bytes: bytes the cryptosystem adds on top of the plaintext.
            The default of 40 presumably corresponds to the 24-byte nonce
            plus the 16-byte authenticator of a PyNaCl ``SecretBox`` message
            -- confirm if the cryptosystem is swapped.

    Returns:
        Total message length in bits, i.e. 8 bits per byte of plaintext and
        overhead.
    '''
    return (max_plaintext_length + overhead_bytes) * 8
|
|
|
def p_next_token(prefix, cache=None, allow_eos=True):
    '''Return the language model's next-token distribution for ``prefix``.

    Args:
        prefix: sequence of token indices generated so far.
        cache: ``past_key_values`` returned by a previous call, or ``None``.
            When given, only the last token of ``prefix`` is fed to the model.
        allow_eos: when false, the last vocabulary entry is removed before
            the softmax, so the returned vector is one element shorter.
            (Callers treat that last entry as EOS -- presumably because
            GPT-2's EOS id is the final vocabulary index; confirm if the
            model is changed.)

    Returns:
        A ``(probabilities, past_key_values)`` pair, where ``probabilities``
        is a NumPy array and ``past_key_values`` can be passed back as
        ``cache`` on the next call.
    '''
    with th.no_grad():
        if cache is not None:
            # The cache already covers everything but the newest token, so
            # only that token needs to be converted and fed forward.
            lm_out = lm.forward(
                input_ids=th.as_tensor(prefix[-1:]),
                use_cache=True,
                past_key_values=cache,
            )
        else:
            lm_out = lm.forward(input_ids=th.as_tensor(prefix), use_cache=True)

        # Logits of the last position only.
        last_logits = lm_out.logits[-1]
        if not allow_eos:
            # Exclude the final vocabulary entry from the distribution.
            last_logits = last_logits[:-1]
        probs = last_logits.softmax(dim=-1)
    return probs.numpy(), lm_out.past_key_values
|
|
|
def embed_bits(coin_flips, prefix, tv_threshold=0.1, max_sequence_length=400):
    '''Hide a sequence of bits in tokens generated by the language model.

    We use a sequence of coin flips to control the generation of token
    indices from a language model. At each step a Huffman code is built for
    the next-token distribution; if ``tv_huffman(...)[0]`` is below
    ``tv_threshold`` the next bits select the token by walking the Huffman
    tree, otherwise a token is sampled from the model and no bits are hidden.

    Args:
        coin_flips: mutable list of bits to hide, given as the characters
            ``'0'`` / ``'1'``. It is consumed in place (``pop(0)``).
        prefix: text shared by sender and receiver; it conditions the model
            but is not part of the yielded tokens.
        tv_threshold: a step is used for hiding only when the first value
            returned by ``tv_huffman`` is strictly below this.
        max_sequence_length: upper bound on the number of prefix tokens plus
            generated tokens.

    Yields:
        The list of generated token indices after every new token (the same
        list object, growing), and finally the list once more with a trailing
        EOS index dropped if present.
    '''
    hidden_prefix_ind = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    n_hidden_prefix_ind = len(hidden_prefix_ind)
    done_hiding = False
    p, kv = p_next_token(hidden_prefix_ind, allow_eos=done_hiding)
    n_skips = 0
    n_bits_encoded = 0
    n_tail_fill = 0
    ind = None
    prefix_inds = []

    # Generation stops as soon as every bit is hidden, the length budget is
    # used up, or EOS is produced.
    while (
        not done_hiding
        and n_hidden_prefix_ind + len(prefix_inds) < max_sequence_length
        and ind != tokenizer.eos_token_id
    ):
        if coin_flips:
            # There are still bits to hide: build the Huffman code for the
            # current conditional distribution.
            heap = build_min_heap(p)
            hc = huffman_tree(heap)

            if tv_huffman(hc, p)[0] < tv_threshold:
                # This step is controlled by the message: walk the tree from
                # the root, one bit per internal (tuple) node.
                decoder_state = hc
                while type(decoder_state) is tuple:
                    left, right = decoder_state
                    try:
                        bit = coin_flips.pop(0)
                        n_bits_encoded += 1
                    except IndexError:
                        # Out of message bits in the middle of a codeword:
                        # pad with a random bit. Convert to str so it is
                        # comparable with the '0'/'1' characters below;
                        # the raw int never equals '0', which made the
                        # padding always take the right branch.
                        bit = str(np.random.choice(2))
                        n_tail_fill += 1
                    # '0' => left, anything else => right
                    decoder_state = left if bit == '0' else right

                # The walk ended in a leaf, which is the token index.
                ind = decoder_state
                prefix_inds.append(ind)
                yield prefix_inds
                done_hiding = not bool(coin_flips)
                p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
                continue

        # Not controlled: sample from the model as usual. Count it as a
        # skipped step only while bits remain to be hidden.
        n_skips += 1 if coin_flips else 0
        ind = np.random.choice(
            tokenizer.vocab_size if done_hiding else tokenizer.vocab_size - 1,
            p=p,
        )
        prefix_inds.append(ind)
        yield prefix_inds
        p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)

    print(prefix_inds)
    print(len(prefix_inds), n_skips, n_bits_encoded, n_tail_fill)
    # Drop a trailing EOS index. The emptiness check covers the case where
    # the loop never ran (e.g. the prefix alone reaches max_sequence_length).
    if prefix_inds and prefix_inds[-1] == tokenizer.eos_token_id:
        prefix_inds = prefix_inds[:-1]
    yield prefix_inds
|
|
|
def recover_bits(token_inds, tv_threshold, bits_to_recover, prefix):
    '''Extract hidden bits from a sequence of token indices.

    Replays the language model over ``token_inds`` and, at every step whose
    Huffman code satisfies ``tv_huffman(...)[0] < tv_threshold``, appends the
    codeword of the observed token to the recovered bits.

    Args:
        token_inds: mutable list of token indices; consumed in place
            (``pop(0)``).
        tv_threshold: same threshold the sender used.
        bits_to_recover: number of bits to extract before stopping.
        prefix: text shared by sender and receiver.

    Yields:
        The list of recovered bits after each controlled step (the same list
        object, growing), and once more when the loop ends.
    '''
    context = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    bits_left = bits_to_recover
    recovered = []
    p, kv = p_next_token(context, allow_eos=False)

    while token_inds and bits_left > 0:
        hc = huffman_tree(build_min_heap(p))
        controlled = tv_huffman(hc, p)[0] < tv_threshold
        # Only controlled steps need the token -> codeword table.
        code = invert_code_tree(hc) if controlled else None

        ind = token_inds.pop(0)
        if controlled:
            fragment = code[ind]
            # Keep only as many bits as are still wanted; the remainder of
            # the last codeword is padding.
            recovered += fragment[:bits_left]
            bits_left -= len(fragment)
            yield recovered

        # Controlled or not, the token becomes part of the model context.
        context.append(ind)
        p, kv = p_next_token(context, cache=kv, allow_eos=False)

    print(recovered, len(recovered), bits_to_recover)
    yield recovered
|
|
|
|
|
# Gradio UI: shared secrets, sender column, receiver column, and the event
# handlers that connect them.
# NOTE(review): the nesting of the layout containers and the indentation of
# the Markdown text were reconstructed from context (the text is written
# indented and relies on gr.Markdown stripping common leading whitespace) --
# confirm against the rendered UI.
# NOTE(review): several button/example labels below look like mis-encoded
# emoji (mojibake). They are kept byte-for-byte here; restore them from the
# original source.
with gr.Blocks() as demo:
    gr.Markdown('''
    # Linguistic steganography demo with ``patient-Huffman`` algorithm
    Instead of sending secrets in plaintext or in ciphertext, we can "hide the hiding" by embedding the encrypted secret in a natural looking message.

    ## Usage for message sender
    1. Type a short message. Click Encrypt to generate the ciphertext (encrypted text).
    2. Click Hide to generate the stegotext/covertext.

    ## Usage for message receiver
    1. Copy-paste the received stegotext/covertext into the stegotext box. Click Recover to extract the hidden ciphertext.
    2. Click Decrypt to decipher the original message.
    ''')

    with gr.Accordion(
        'Secrets shared between sender and receiver',
        open=False,
    ):
        gr.Markdown('''
        - The proposed stegosystem is agnostic to the choice of cryptosystem. We use the symmetric key encryption implemented in `pyNaCl` library.
        - An encryption key is randomly generated, you can refresh the page to get a different one.
        - The _choice_ of language model is a shared secret. Due to computation resource constraints, we use GPT-2 as an example.
        - The communicating parties can share a prefix to further control the stegotext to appear more appropriate for the channel, e.g., blog posts, social media messages. Take extra care of the whitespaces.
        - Imperceptibility threshold controls how much the distribution of stegotexts is allowed to deviate from the language model. Lower imperceptibility threshold produces longer stegotext.


        Reference: Dai FZ, Cai Z. [Towards Near-imperceptible Steganographic Text](https://arxiv.org/abs/1907.06679). ACL 2019.
        ''')
        # Per-session state passed through the encrypt/decrypt handlers.
        state = gr.State()
        with gr.Row():
            tb_shared_key = gr.Textbox(
                label='encryption key (hex)',
                # A callable value is evaluated to produce the initial key.
                value=lambda: nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE).hex(),
                interactive=True,
                scale=1,
                lines=3,
            )
            s_shared_imp = gr.Slider(
                label='imperceptibility threshold',
                minimum=0,
                maximum=1,
                value=0.4,
                scale=1,
            )
            s_shared_max_plaintext_len = gr.Slider(
                label='max plaintext length',
                minimum=4,
                maximum=32,
                step=1,
                value=18,
                scale=1,
            )
            with gr.Column(scale=1):
                tb_shared_prefix = gr.Textbox(
                    label='prefix',
                    value='',
                )
                gr.Examples(
                    [
                        'best dessert recipe: ',
                        'def solve(x):',
                        'breaking news ',
                        'π€π',
                    ],
                    tb_shared_prefix,
                    cache_examples=False,
                )

    with gr.Row():
        with gr.Box():
            with gr.Column():
                gr.Markdown('## Sender')

                tb_sender_plaintext = gr.Textbox(
                    label='plaintext',
                    value='gold in top drawer',
                )
                btn_encrypt = gr.Button('π Encrypt')

                tb_sender_ciphertext = gr.Textbox(
                    label='ciphertext (hex)',
                )
                # Enabled by the encrypt handler once a ciphertext exists.
                btn_hide = gr.Button('π«£ Hide', interactive=False)

                tb_sender_stegotext = gr.Textbox(
                    label='stegotext',
                )

        with gr.Box():
            with gr.Column():
                gr.Markdown('## Receiver')

                tb_receiver_stegotext = gr.Textbox(
                    label='stegotext',
                )
                btn_recover = gr.Button('π Recover')

                tb_receiver_ciphertext = gr.Textbox(
                    label='recovered ciphertext (hex)',
                )
                btn_decrypt = gr.Button('π Decrypt', interactive=True)

                tb_receiver_plaintext = gr.Textbox(
                    label='deciphered plaintext',
                )

    gr.Markdown('''
    ## Known issues
    1. The ciphertext recovered by the receiver might not match the original ciphertext. This is due to LLM tokenization mismatch. This is a fundamental challenge and for now, just Encrypt again (to use a different nonce) and go through the sender's process again.
    2. The stegotext looks incoherent. GPT-2 small is used for the demo and its fluency is quite limited. A stronger LLM will alleviate this problem. A smaller imperceptibility threshold should also help.
    ''')

    def _secret_boxes(key_in_hex):
        '''Build the (sender_box, receiver_box) pair for the key in the UI.

        The boxes are rebuilt on every call instead of being reused from the
        session state, so that editing the key textbox takes effect.

        Raises:
            gr.Error: if the key is not valid hex or has the wrong length.
        '''
        try:
            shared_key = bytes.fromhex(key_in_hex)
            return nacl.secret.SecretBox(shared_key), nacl.secret.SecretBox(shared_key)
        except (TypeError, ValueError) as err:
            raise gr.Error(
                f'Invalid encryption key. It must be {nacl.secret.SecretBox.KEY_SIZE} bytes in hex.'
            ) from err

    def encrypt(saved_state, key_in_hex, plaintext, max_plaintext_length):
        '''Pad the plaintext with spaces to the max length and encrypt it.

        Returns the updates for ``[state, ciphertext textbox, hide button]``:
        the box pair, the ciphertext in hex, and an update enabling Hide.

        Raises:
            gr.Error: if the UTF-8 encoded plaintext exceeds
                ``max_plaintext_length`` bytes, or the key is invalid.
        '''
        # ``saved_state`` is accepted for interface compatibility; the boxes
        # are always derived from the current key.
        sender_box, receiver_box = saved_state = _secret_boxes(key_in_hex)

        utf8_encoded_plaintext = bytes(plaintext, 'utf8')
        print('Encode:', utf8_encoded_plaintext, len(utf8_encoded_plaintext))
        if len(utf8_encoded_plaintext) > max_plaintext_length:
            raise gr.Error('Plaintext is too long. Try a shorter one or increase the max plaintext length.')

        # Fixed-length messages: the receiver derives the number of bits to
        # recover from the max plaintext length alone.
        utf8_encoded_plaintext += bytes(' ' * (max_plaintext_length - len(utf8_encoded_plaintext)), encoding='utf8')
        ciphertext = sender_box.encrypt(utf8_encoded_plaintext)
        print('Encrypt:', plaintext, len(plaintext), ciphertext, len(ciphertext), len(ciphertext.hex()))
        return [
            saved_state,
            ciphertext.hex(),
            gr.Button.update(interactive=True),
        ]

    def decrypt(saved_state, ciphertext, key_in_hex):
        '''Decrypt a hex ciphertext and return ``[state, plaintext]``.

        Raises:
            gr.Error: if the key is invalid, or if hex decoding, decryption
                or UTF-8 decoding of the ciphertext fails.
        '''
        sender_box, receiver_box = saved_state = _secret_boxes(key_in_hex)
        try:
            utf8_encoded_plaintext = receiver_box.decrypt(bytes.fromhex(ciphertext))
            print('Decrypt:', ciphertext, len(ciphertext), utf8_encoded_plaintext, len(utf8_encoded_plaintext))
            return [
                saved_state,
                utf8_encoded_plaintext.decode('utf8'),
            ]
        except Exception as err:
            # Handler boundary: report any failure to the user, but do not
            # swallow KeyboardInterrupt/SystemExit as a bare except would.
            raise gr.Error('Decryption failed. Likely due to tokenization mismatch. Try Encrypting again.') from err

    def hide(ciphertext, tv_threshold, shared_prefix):
        '''Stream the stegotext that hides the hex ``ciphertext``.'''
        ba = bytes.fromhex(ciphertext)
        # One '0'/'1' character per bit, most significant bit first.
        bits = [b for h in ba for b in f'{h:08b}']
        print('Hide:', ciphertext, bits, len(bits))
        embed_gen = embed_bits(bits, shared_prefix, tv_threshold, lm.config.n_ctx // 2)
        for inds in embed_gen:
            yield tokenizer.decode(inds)

    def recover(stegotext, tv_threshold, max_plaintext_length, shared_prefix):
        '''Stream the recovered bit string, then yield the ciphertext in hex.'''
        inds = tokenizer.encode(stegotext)
        print('Recover:', stegotext, inds, len(inds))
        n_bits_to_recover = bits_to_recover(max_plaintext_length)
        recover_gen = recover_bits(inds, tv_threshold, n_bits_to_recover, shared_prefix)
        bits = []  # stays defined even if the generator yields nothing
        for bits in recover_gen:
            yield ''.join(bits)

        # Pack the final bit list into bytes, 8 bits at a time.
        ba = bytearray()
        for i in range(0, len(bits), 8):
            ba.append(int(''.join(bits[i:i+8]), 2))
        yield ba.hex()

    btn_encrypt.click(
        encrypt,
        [state, tb_shared_key, tb_sender_plaintext, s_shared_max_plaintext_len],
        [state, tb_sender_ciphertext, btn_hide],
    )
    btn_hide.click(
        hide,
        [tb_sender_ciphertext, s_shared_imp, tb_shared_prefix],
        [tb_sender_stegotext],
    )
    btn_recover.click(
        recover,
        [tb_receiver_stegotext, s_shared_imp, s_shared_max_plaintext_len, tb_shared_prefix],
        [tb_receiver_ciphertext],
    )
    btn_decrypt.click(
        decrypt,
        [state, tb_receiver_ciphertext, tb_shared_key],
        [state, tb_receiver_plaintext],
    )
|
|
|
|
|
if __name__ == '__main__':
    # Turn on the request queue with 10 concurrent workers, then start the
    # server. ``queue`` returns the Blocks instance, so the calls are chained.
    demo.queue(concurrency_count=10).launch()
|
|