File size: 15,213 Bytes
178b66b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 |
#!/usr/bin/env python
# A demo of linguistic steganography with the patient-Huffman algorithm.
# We use symmetric-key cryptography to encrypt and decrypt.
#
# Reference:
# Dai FZ, Cai Z. Towards Near-imperceptible Steganographic Text. ACL 2019.
import nacl.secret
import nacl.utils
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import gradio as gr
import numpy as np
import torch as th
from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree
# The language model and its tokenizer are loaded once, at import time, and
# are used as module-level globals by the functions below.
# model_name = 'gpt2-xl'
# XXX Use GPT-2-small for less compute
model_name = 'gpt2'
# Causal LM used to obtain next-token distributions.
lm = GPT2LMHeadModel.from_pretrained(model_name)
# Matching tokenizer; converts between text and token indices.
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
def bits_to_recover(max_plaintext_length):
    '''Return the number of ciphertext bits for a padded plaintext.

    The ciphertext is taken to be ``max_plaintext_length`` bytes plus a fixed
    40-byte overhead (presumably the SecretBox nonce and authenticator --
    confirm against the cryptosystem in use), expressed in bits.
    '''
    overhead_bytes = 40
    bits_per_byte = 8
    ciphertext_bytes = max_plaintext_length + overhead_bytes
    return ciphertext_bytes * bits_per_byte
def p_next_token(prefix, cache=None, allow_eos=True):
    '''Return the next-token distribution after ``prefix`` and the new cache.

    prefix: token indices of the text so far (the callers in this file pass
        a flat list of indices).
    cache: ``past_key_values`` returned by a previous call, or a falsy value
        to run the whole prefix through the model.
    allow_eos: when False, the last vocabulary entry is dropped before the
        softmax, so it can never be produced.

    Returns a ``(probabilities, past_key_values)`` pair, where the
    probabilities are a NumPy array.
    '''
    token_ids = th.as_tensor(prefix)
    with th.no_grad():
        if not cache:
            # No cache yet: feed the complete prefix.
            outputs = lm.forward(input_ids=token_ids, use_cache=True)
        else:
            # Incremental decoding. Only the newest token is fed; the cache
            # stands in for everything before it.
            outputs = lm.forward(
                input_ids=token_ids[-1:],
                use_cache=True,
                past_key_values=cache,
            )
        if not allow_eos:
            # Assume EOS is the last token in the vocabulary and exclude it.
            probs = outputs.logits[-1, :-1].softmax(dim=-1)
        else:
            probs = outputs.logits[-1].softmax(dim=-1)
    return probs.numpy(), outputs.past_key_values
def embed_bits(coin_flips, prefix, tv_threshold=0.1, max_sequence_length=400):
    '''We use a sequence of coin flips to control the generation of token
    indices from a language model. This returns _a sequence_ as defined by
    the language model, e.g. sentence, paragraph.

    This is a generator: after every generated token it yields the list of
    token indices produced so far (the same list object each time), and it
    yields the final list once more at the end.

    coin_flips: sequence of bits to hide, given as '0'/'1' characters or as
        0/1 integers. The sequence is copied; the caller's object is not
        modified.
    prefix: text shared by both parties; it conditions the language model
        but is not part of the yielded indices.
    tv_threshold: a step is used to hide bits only when the total variation
        reported by ``tv_huffman`` is below this value; otherwise the token
        is sampled from the language model as usual.
    max_sequence_length: upper bound on hidden prefix plus generated tokens.
        Generation stops at this bound even if some bits are still unhidden.
    '''
    from collections import deque

    # Private queue: O(1) pops from the left, caller's sequence untouched.
    pending_bits = deque(coin_flips)
    hidden_prefix_ind = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    n_hidden_prefix_ind = len(hidden_prefix_ind)
    # With nothing to hide we are done before generating any token.
    done_hiding = not pending_bits
    p, kv = p_next_token(hidden_prefix_ind, allow_eos=done_hiding)
    n_skips = 0
    n_bits_encoded = 0
    n_tail_fill = 0
    ind = None
    prefix_inds = []
    # Terminate the generation after we generate the EOS token
    # XXX to save computation, we terminate as soon as all bits are hidden.
    while (
        not done_hiding
        and n_hidden_prefix_ind + len(prefix_inds) < max_sequence_length
        and ind != tokenizer.eos_token_id
    ):
        # Invariant: there is still some cipher text to hide.
        # Build Huffman codes for the conditional distribution
        hc = huffman_tree(build_min_heap(p))
        # Check if the total variation is low enough
        if tv_huffman(hc, p)[0] < tv_threshold:
            # Huffman-decode the cipher text into a token
            # Consume the cipher text until a token is generated
            decoder_state = hc
            while isinstance(decoder_state, tuple):
                left, right = decoder_state
                if pending_bits:
                    bit = pending_bits.popleft()
                    n_bits_encoded += 1
                else:
                    # No more cipher text. Pad with random bits
                    bit = np.random.choice(2)
                    n_tail_fill += 1
                # 0 => left, 1 => right
                # Compare numerically so that both '0'/'1' characters and
                # the integer padding bits select the intended branch.
                decoder_state = left if int(bit) == 0 else right
            # Decoder settles in a leaf node
            ind = decoder_state
            done_hiding = not pending_bits
        else:
            # Forward sample according to LM normally
            n_skips += 1
            ind = np.random.choice(len(p), p=p)
        prefix_inds.append(ind)
        yield prefix_inds
        p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
    print(prefix_inds)
    print(len(prefix_inds), n_skips, n_bits_encoded, n_tail_fill)
    # Drop the EOS index (nothing to drop when no token was generated).
    if prefix_inds and prefix_inds[-1] == tokenizer.eos_token_id:
        prefix_inds = prefix_inds[:-1]
    yield prefix_inds
def recover_bits(token_inds, tv_threshold, bits_to_recover, prefix):
    '''Extract the hidden bits from a sequence of token indices.

    This is a generator: it yields the list of bits recovered so far after
    every step that carried hidden bits, and yields it once more at the end.

    token_inds: token indices of the received text; consumed from the front
        (the list is modified in place).
    tv_threshold: same threshold the sender used to decide which steps hide
        bits.
    bits_to_recover: number of bits to extract before stopping.
    prefix: text shared by both parties, used to condition the model.
    '''
    bits_left = bits_to_recover
    context = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
    p, kv = p_next_token(context, allow_eos=False)
    cipher_text = []
    # Stop once every index is consumed or every bit is extracted.
    while token_inds and bits_left > 0:
        # Huffman codes for the current conditional distribution.
        tree = huffman_tree(build_min_heap(p))
        # The sender controlled this step only if the total variation is low.
        controlled = tv_huffman(tree, p)[0] < tv_threshold
        codebook = invert_code_tree(tree) if controlled else None
        token = token_inds.pop(0)
        if controlled:
            # The token's Huffman code is the hidden bits
            # (left => 0, right => 1).
            fragment = codebook[token]
            # Cut off any trailing padding beyond the requested length.
            cipher_text += fragment[:bits_left]
            bits_left -= len(fragment)
            yield cipher_text
        # Controlled or skipped, the token extends the context.
        context.append(token)
        p, kv = p_next_token(context, cache=kv, allow_eos=False)
    print(cipher_text, len(cipher_text), bits_to_recover)
    yield cipher_text
# Build the Gradio UI; `demo` is the object launched at the bottom of the file.
# NOTE(review): leading indentation is missing in this copy of the file, so
# which statements are nested under each `with` must be confirmed.
with gr.Blocks() as demo:
# Introductory help text (Markdown) with usage steps for sender and receiver.
gr.Markdown('''
# Linguistic steganography demo with ``patient-Huffman`` algorithm
Instead of sending secrets in plaintext or in ciphertext, we can "hide the hiding" by embedding the encrypted secret in a natural looking message.
## Usage for message sender
1. Type a short message. Click Encrypt to generate the ciphertext (encrypted text).
2. Click Hide to generate the stegotext/covertext.
## Usage for message receiver
1. Copy-paste the received stegotext/covertext into the stegotext box. Click Recover to extract the hidden ciphertext.
2. Click Decrypt to decipher the original message.
''')
# Accordion (closed by default) for the settings both parties must agree on.
with gr.Accordion(
'Secrets shared between sender and receiver',
open=False,
):
# Shared secrets and parameters.
# Explanatory notes about the key, model, prefix and threshold.
gr.Markdown('''
- The proposed stegosystem is agnostic to the choice of cryptosystem. We use the symmetric key encryption implemented in `pyNaCl` library.
- An encryption key is randomly generated, you can refresh the page to get a different one.
- The _choice_ of language model is a shared secret. Due to computation resource constraints, we use GPT-2 as an example.
- The communicating parties can share a prefix to further control the stegotext to appear more appropriate for the channel, e.g., blog posts, social media messages. Take extra care of the whitespaces.
- Imperceptibility threshold controls how much the distribution of stegotexts is allowed to deviate from the language model. Lower imperceptibility threshold produces longer stegotext.
Reference: Dai FZ, Cai Z. [Towards Near-imperceptible Steganographic Text](https://arxiv.org/abs/1907.06679). ACL 2019.
''')
# State component passed to and returned from the encrypt/decrypt handlers.
state = gr.State()
# Row of shared controls: key, threshold, max plaintext length and prefix.
with gr.Row():
# Hex-encoded symmetric key; the default is a callable that draws a fresh
# random key of SecretBox.KEY_SIZE bytes. The user may overwrite it.
tb_shared_key = gr.Textbox(
label='encryption key (hex)',
value=lambda : nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE).hex(),
interactive=True,
scale=1,
lines=3,
)
# dp_shared_lm = gr.Dropdown(
# label='language model',
# choices=[
# 'GPT-2',
# # 'GPT-3',
# ],
# value='GPT-2',
# )
# Threshold passed as `tv_threshold` to the hide and recover handlers.
s_shared_imp = gr.Slider(
label='imperceptibility threshold',
minimum=0,
maximum=1,
value=0.4,
scale=1,
)
# Plaintext is padded to this many bytes before encryption; the receiver
# uses the same value to know how many bits to recover.
s_shared_max_plaintext_len = gr.Slider(
label='max plaintext length',
minimum=4,
maximum=32,
step=1,
value=18,
scale=1,
)
# Shared prefix textbox with a few clickable example prefixes.
with gr.Column(scale=1):
tb_shared_prefix = gr.Textbox(
label='prefix',
value='',
)
gr.Examples(
[
'best dessert recipe: ',
'def solve(x):',
'breaking news ',
'π€π',
],
tb_shared_prefix,
cache_examples=False,
)
# Sender and receiver panels.
# NOTE(review): presumably laid out side by side in this row -- confirm once
# the indentation is restored.
with gr.Row():
with gr.Box():
with gr.Column():
# Sender
gr.Markdown('## Sender')
# Plain text
tb_sender_plaintext = gr.Textbox(
label='plaintext',
value='gold in top drawer',
)
btn_encrypt = gr.Button('π Encrypt')
# Encrypt
# Cipher text
tb_sender_ciphertext = gr.Textbox(
label='ciphertext (hex)',
)
# Starts disabled; the encrypt handler returns a Button update that
# enables it.
btn_hide = gr.Button('π«£ Hide', interactive=False)
# Hide
# Cover text
tb_sender_stegotext = gr.Textbox(
label='stegotext',
)
with gr.Box():
with gr.Column():
# Receiver
gr.Markdown('## Receiver')
# Cover text
tb_receiver_stegotext = gr.Textbox(
label='stegotext',
)
btn_recover = gr.Button('π Recover')
# Cipher text
tb_receiver_ciphertext = gr.Textbox(
label='recovered ciphertext (hex)',
)
btn_decrypt = gr.Button('π Decrypt', interactive=True)
# Plain text
tb_receiver_plaintext = gr.Textbox(
label='deciphered plaintext',
)
# Footer listing the known limitations of the demo.
gr.Markdown('''
## Known issues
1. The ciphertext recovered by the receiver might not match the original ciphertext. This is due to LLM tokenization mismatch. This is a fundamental challenge and for now, just Encrypt again (to use a different nonce) and go through the sender's process again.
2. The stegotext looks incoherent. GPT-2 small is used for the demo and its fluency is quite limited. A stronger LLM will alleviate this problem. A smaller imperceptibility threshold should also help.
''')
# Link the UI to handlers
def encrypt(saved_state, key_in_hex, plaintext, max_plaintext_length):
    '''Encrypt the plaintext with the shared key for the Encrypt button.

    saved_state: previous session state. It is not reused: the secret boxes
        are rebuilt from ``key_in_hex`` on every call so that editing the key
        textbox takes effect immediately.
    key_in_hex: symmetric key as a hex string.
    plaintext: message to encrypt.
    max_plaintext_length: the UTF-8 encoded message is padded with spaces to
        exactly this many bytes.

    Returns ``[state, ciphertext_hex, button_update]`` where ``state`` is the
    ``(sender_box, receiver_box)`` pair and ``button_update`` enables the
    Hide button.

    Raises gr.Error if the key is malformed or the plaintext is too long.
    '''
    try:
        shared_key = bytes.fromhex(key_in_hex)
        sender_box = nacl.secret.SecretBox(shared_key)
        receiver_box = nacl.secret.SecretBox(shared_key)
    except (TypeError, ValueError) as err:
        raise gr.Error('Invalid encryption key. It must be a hex string of the right length.') from err
    saved_state = sender_box, receiver_box
    # The slider may deliver a float; byte counts must be integers.
    max_plaintext_length = int(max_plaintext_length)
    utf8_encoded_plaintext = bytes(plaintext, 'utf8')
    print('Encode:', utf8_encoded_plaintext, len(utf8_encoded_plaintext))
    if len(utf8_encoded_plaintext) > max_plaintext_length:
        raise gr.Error('Plaintext is too long. Try a shorter one or increase the max plaintext length.')
    # Pad the plaintext to the maximum length so that every ciphertext has
    # the same, predictable size.
    utf8_encoded_plaintext = utf8_encoded_plaintext.ljust(max_plaintext_length, b' ')
    ciphertext = sender_box.encrypt(utf8_encoded_plaintext)
    print('Encrypt:', plaintext, len(plaintext), ciphertext, len(ciphertext), len(ciphertext.hex()))
    return [
        saved_state,
        ciphertext.hex(),
        gr.Button.update(interactive=True),
    ]
def decrypt(saved_state, ciphertext, key_in_hex):
    '''Decrypt a hex ciphertext with the shared key for the Decrypt button.

    saved_state: previous session state. It is not reused: the secret boxes
        are rebuilt from ``key_in_hex`` on every call so that editing the key
        textbox takes effect immediately.
    ciphertext: ciphertext as a hex string.
    key_in_hex: symmetric key as a hex string.

    Returns ``[state, plaintext]`` where ``state`` is the
    ``(sender_box, receiver_box)`` pair and ``plaintext`` is the decrypted
    message (still carrying the space padding added at encryption time).

    Raises gr.Error if the key is malformed or decryption fails.
    '''
    try:
        shared_key = bytes.fromhex(key_in_hex)
        sender_box = nacl.secret.SecretBox(shared_key)
        receiver_box = nacl.secret.SecretBox(shared_key)
    except (TypeError, ValueError) as err:
        raise gr.Error('Invalid encryption key. It must be a hex string of the right length.') from err
    saved_state = sender_box, receiver_box
    try:
        utf8_encoded_plaintext = receiver_box.decrypt(bytes.fromhex(ciphertext))
        plaintext = utf8_encoded_plaintext.decode('utf8')
    except Exception as err:
        # UI boundary: malformed hex, failed authentication and undecodable
        # bytes are all reported to the user the same way.
        raise gr.Error('Decryption failed. Likely due to tokenization mismatch. Try Encrypting again.') from err
    print('Decrypt:', ciphertext, len(ciphertext), utf8_encoded_plaintext, len(utf8_encoded_plaintext))
    return [
        saved_state,
        plaintext,
    ]
def hide(ciphertext, tv_threshold, shared_prefix):
    '''Stream the stegotext that hides the given hex ciphertext.

    The ciphertext is expanded into a list of '0'/'1' characters (eight per
    byte, most significant bit first) and handed to ``embed_bits``. Every
    list of token indices it produces is decoded to text and yielded.
    '''
    # Convert hex to bits
    raw = bytes.fromhex(ciphertext)
    bits = list(''.join(format(octet, '08b') for octet in raw))
    print('Hide:', ciphertext, bits, len(bits))
    # Generated length is capped at half of the model's context size.
    length_cap = lm.config.n_ctx // 2
    token_stream = embed_bits(bits, shared_prefix, tv_threshold, length_cap)
    for text in map(tokenizer.decode, token_stream):
        yield text
def recover(stegotext, tv_threshold, max_plaintext_length, shared_prefix):
    '''Stream the bits recovered from the stegotext, then yield them as hex.

    While ``recover_bits`` is running, each intermediate result is yielded
    as a string of '0'/'1' characters. After it finishes, the last result is
    packed eight bits per byte and yielded as a hex string.
    '''
    inds = tokenizer.encode(stegotext)
    print('Recover:', stegotext, inds, len(inds))
    n_bits_to_recover = bits_to_recover(max_plaintext_length)
    for bits in recover_bits(inds, tv_threshold, n_bits_to_recover, shared_prefix):
        yield ''.join(bits)
    # Convert bits to bytes, one group of (up to) eight bits at a time.
    groups = (bits[start:start + 8] for start in range(0, len(bits), 8))
    packed = bytes(int(''.join(group), 2) for group in groups)
    yield packed.hex()
# Each click() call takes the handler, the list of input components and the
# list of output components, in that order.
# Encrypt: reads the key, plaintext and max length; writes the session state,
# the ciphertext box and the Hide button.
btn_encrypt.click(
encrypt,
[state, tb_shared_key, tb_sender_plaintext, s_shared_max_plaintext_len],
[state, tb_sender_ciphertext, btn_hide],
)
# Hide: turns the sender's ciphertext into stegotext.
btn_hide.click(
hide,
[tb_sender_ciphertext, s_shared_imp, tb_shared_prefix],
[tb_sender_stegotext],
)
# Recover: extracts the ciphertext from the receiver's stegotext.
btn_recover.click(
recover,
[tb_receiver_stegotext, s_shared_imp, s_shared_max_plaintext_len, tb_shared_prefix],
[tb_receiver_ciphertext],
)
# Decrypt: reads the recovered ciphertext and the key; writes the session
# state and the deciphered plaintext box.
btn_decrypt.click(
decrypt,
[state, tb_receiver_ciphertext, tb_shared_key],
[state, tb_receiver_plaintext],
)
# Start the web app only when the file is run as a script.
if __name__ == '__main__':
# Enable the request queue before launching.
demo.queue(concurrency_count=10)
demo.launch()
# demo.launch(share=True)