dai
commited on
Commit
Β·
178b66b
1
Parent(s):
b3fbbe5
first release
Browse files- README.md +4 -4
- app.py +359 -0
- huffman.py +181 -0
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
---
|
2 |
-
title: Stego
|
3 |
-
emoji:
|
4 |
colorFrom: indigo
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: openrail
|
11 |
---
|
12 |
|
13 |
-
|
|
|
1 |
---
|
2 |
+
title: Stego LM
|
3 |
+
emoji: πππ
|
4 |
colorFrom: indigo
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.47.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: openrail
|
11 |
---
|
12 |
|
13 |
+
Hide the hiding.
|
app.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# An demo of linguistic steganography with patient-Huffman algorithm.
|
3 |
+
# We use symmetric key cryptography to en/decrypt.
|
4 |
+
#
|
5 |
+
# Reference:
|
6 |
+
# Dai FZ, Cai Z. Towards Near-imperceptible Steganographic Text. ACL 2019.
|
7 |
+
|
8 |
+
import nacl.secret
|
9 |
+
import nacl.utils
|
10 |
+
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
|
11 |
+
import gradio as gr
|
12 |
+
import numpy as np
|
13 |
+
import torch as th
|
14 |
+
|
15 |
+
from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree
|
16 |
+
|
17 |
+
# model_name = 'gpt2-xl'
|
18 |
+
# XXX Use GPT-2-small for less compute
|
19 |
+
model_name = 'gpt2'
|
20 |
+
lm = GPT2LMHeadModel.from_pretrained(model_name)
|
21 |
+
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
|
22 |
+
|
23 |
+
def bits_to_recover(max_plaintext_length):
|
24 |
+
return (max_plaintext_length + 40) * 8
|
25 |
+
|
26 |
+
def p_next_token(prefix, cache=None, allow_eos=True):
|
27 |
+
t_prefix = th.as_tensor(prefix)
|
28 |
+
with th.no_grad():
|
29 |
+
if cache:
|
30 |
+
# Incremental decoding. Input one token at a time with cache.
|
31 |
+
lm_out = lm.forward(input_ids=t_prefix[-1:], use_cache=True, past_key_values=cache)
|
32 |
+
else:
|
33 |
+
lm_out = lm.forward(input_ids=t_prefix, use_cache=True)
|
34 |
+
if allow_eos:
|
35 |
+
# Assume EOS is the last token in the vocabulary.
|
36 |
+
p_next_token = lm_out.logits[-1].softmax(dim=-1)
|
37 |
+
else:
|
38 |
+
p_next_token = lm_out.logits[-1, :-1].softmax(dim=-1)
|
39 |
+
return p_next_token.numpy(), lm_out.past_key_values
|
40 |
+
|
41 |
+
def embed_bits(coin_flips, prefix, tv_threshold=0.1, max_sequence_length=400):
|
42 |
+
'''We use a sequence of coin flips to control the generation of token
|
43 |
+
indices from a language model. This returns _a sequence_ as defined by
|
44 |
+
the language model, e.g. sentence, paragraph.'''
|
45 |
+
# ind = tokenizer.bos_token_id
|
46 |
+
# prefix = [ind]
|
47 |
+
|
48 |
+
hidden_prefix_ind = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
|
49 |
+
n_hidden_prefix_ind = len(hidden_prefix_ind)
|
50 |
+
done_hiding = False
|
51 |
+
p, kv = p_next_token(hidden_prefix_ind, allow_eos=done_hiding)
|
52 |
+
n_skips = 0
|
53 |
+
n_bits_encoded = 0
|
54 |
+
n_tail_fill = 0
|
55 |
+
ind = None
|
56 |
+
prefix_inds = []
|
57 |
+
# Terminate the generation after we generate the EOS token
|
58 |
+
# XXX to save computation, we terminate as soon as all bits are hidden.
|
59 |
+
while not done_hiding and n_hidden_prefix_ind + len(prefix_inds) < max_sequence_length and ind != tokenizer.eos_token_id:
|
60 |
+
# There is still some cipher text to hide
|
61 |
+
if coin_flips:
|
62 |
+
# Build Huffman codes for the conditional distribution
|
63 |
+
heap = build_min_heap(p)
|
64 |
+
hc = huffman_tree(heap)
|
65 |
+
# print(hc)
|
66 |
+
# Check if the total variation is low enough
|
67 |
+
# print(len(prefix_inds) - 1, tv_huffman(hc, p))
|
68 |
+
# print(tv_huffman(hc, p)[0], tv_threshold)
|
69 |
+
if tv_huffman(hc, p)[0] < tv_threshold:
|
70 |
+
# Huffman-decode the cipher text into a token
|
71 |
+
# Consume the cipher text until a token is generated
|
72 |
+
decoder_state = hc
|
73 |
+
while type(decoder_state) is tuple:
|
74 |
+
left, right = decoder_state
|
75 |
+
try:
|
76 |
+
bit = coin_flips.pop(0)
|
77 |
+
n_bits_encoded += 1
|
78 |
+
except IndexError:
|
79 |
+
# No more cipher text. Pad with random bits
|
80 |
+
bit = np.random.choice(2)
|
81 |
+
n_tail_fill += 1
|
82 |
+
# 0 => left, 1 => right
|
83 |
+
decoder_state = left if bit == '0' else right
|
84 |
+
# Decoder settles in a leaf node
|
85 |
+
ind = decoder_state
|
86 |
+
prefix_inds.append(ind)
|
87 |
+
yield prefix_inds
|
88 |
+
done_hiding = not bool(coin_flips)
|
89 |
+
p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
|
90 |
+
continue
|
91 |
+
# Forward sample according to LM normally
|
92 |
+
n_skips += 1 if coin_flips else 0
|
93 |
+
ind = np.random.choice(tokenizer.vocab_size if done_hiding else tokenizer.vocab_size - 1, p=p)
|
94 |
+
prefix_inds.append(ind)
|
95 |
+
yield prefix_inds
|
96 |
+
p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
|
97 |
+
# Drop the EOS index
|
98 |
+
print(prefix_inds)
|
99 |
+
print(len(prefix_inds), n_skips, n_bits_encoded, n_tail_fill)
|
100 |
+
if prefix_inds[-1] == tokenizer.eos_token_id:
|
101 |
+
prefix_inds = prefix_inds[:-1]
|
102 |
+
yield prefix_inds
|
103 |
+
|
104 |
+
def recover_bits(token_inds, tv_threshold, bits_to_recover, prefix):
|
105 |
+
remaining_bits = bits_to_recover
|
106 |
+
hidden_prefix_inds = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
|
107 |
+
p, kv = p_next_token(hidden_prefix_inds, allow_eos=False)
|
108 |
+
cipher_text = []
|
109 |
+
# Terminate the generation after we have consumed all indices or
|
110 |
+
# have extracted all bits
|
111 |
+
while token_inds and 0 < remaining_bits:
|
112 |
+
# Build Huffman codes for the conditional distribution
|
113 |
+
heap = build_min_heap(p)
|
114 |
+
hc = huffman_tree(heap)
|
115 |
+
# Check if the total variation is low enough
|
116 |
+
if tv_huffman(hc, p)[0] < tv_threshold:
|
117 |
+
# We have controlled this step. Some bits are hidden.
|
118 |
+
code = invert_code_tree(hc)
|
119 |
+
# Look up the Huffman code for the token.
|
120 |
+
ind = token_inds.pop(0)
|
121 |
+
# Convert the Huffman code into bits
|
122 |
+
# left => 0, right => 1
|
123 |
+
cipher_text_fragment = code[ind]
|
124 |
+
# Truncate possible trailing paddings
|
125 |
+
cipher_text += cipher_text_fragment[:remaining_bits]
|
126 |
+
remaining_bits -= len(cipher_text_fragment)
|
127 |
+
yield cipher_text
|
128 |
+
# print(remaining_bits)
|
129 |
+
hidden_prefix_inds.append(ind)
|
130 |
+
p, kv = p_next_token(hidden_prefix_inds, cache=kv, allow_eos=False)
|
131 |
+
else:
|
132 |
+
# We did not control this step. Skip.
|
133 |
+
hidden_prefix_inds.append(token_inds.pop(0))
|
134 |
+
p, kv = p_next_token(hidden_prefix_inds, cache=kv, allow_eos=False)
|
135 |
+
print(cipher_text, len(cipher_text), bits_to_recover)
|
136 |
+
yield cipher_text
|
137 |
+
|
138 |
+
|
139 |
+
with gr.Blocks() as demo:
|
140 |
+
gr.Markdown('''
|
141 |
+
# Linguistic steganography demo with ``patient-Huffman`` algorithm
|
142 |
+
Instead of sending secrets in plaintext or in ciphertext, we can "hide the hiding" by embedding the encrypted secret in a natural looking message.
|
143 |
+
|
144 |
+
## Usage for message sender
|
145 |
+
1. Type a short message. Click Encrypt to generate the ciphertext (encrypted text).
|
146 |
+
2. Click Hide to generate the stegotext/covertext.
|
147 |
+
|
148 |
+
## Usage for message receiver
|
149 |
+
1. Copy-paste the received stegotext/covertext into the stegotext box. Click Recover to extract the hidden ciphertext.
|
150 |
+
2. Click Decrypt to decipher the original message.
|
151 |
+
''')
|
152 |
+
|
153 |
+
with gr.Accordion(
|
154 |
+
'Secrets shared between sender and receiver',
|
155 |
+
open=False,
|
156 |
+
):
|
157 |
+
# Shared secrets and parameters.
|
158 |
+
gr.Markdown('''
|
159 |
+
- The proposed stegosystem is agnostic to the choice of cryptosystem. We use the symmetric key encryption implemented in `pyNaCl` library.
|
160 |
+
- An encryption key is randomly generated, you can refresh the page to get a different one.
|
161 |
+
- The _choice_ of language model is a shared secret. Due to computation resource constraints, we use GPT-2 as an example.
|
162 |
+
- The communicating parties can share a prefix to further control the stegotext to appear more appropriate for the channel, e.g., blog posts, social media messages. Take extra care of the whitespaces.
|
163 |
+
- Imperceptibility threshold controls how much the distribution of stegotexts is allowed to deviate from the language model. Lower imperceptibility threshold produces longer stegotext.
|
164 |
+
|
165 |
+
|
166 |
+
Reference: Dai FZ, Cai Z. [Towards Near-imperceptible Steganographic Text](https://arxiv.org/abs/1907.06679). ACL 2019.
|
167 |
+
''')
|
168 |
+
state = gr.State()
|
169 |
+
with gr.Row():
|
170 |
+
tb_shared_key = gr.Textbox(
|
171 |
+
label='encryption key (hex)',
|
172 |
+
value=lambda : nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE).hex(),
|
173 |
+
interactive=True,
|
174 |
+
scale=1,
|
175 |
+
lines=3,
|
176 |
+
)
|
177 |
+
# dp_shared_lm = gr.Dropdown(
|
178 |
+
# label='language model',
|
179 |
+
# choices=[
|
180 |
+
# 'GPT-2',
|
181 |
+
# # 'GPT-3',
|
182 |
+
# ],
|
183 |
+
# value='GPT-2',
|
184 |
+
# )
|
185 |
+
s_shared_imp = gr.Slider(
|
186 |
+
label='imperceptibility threshold',
|
187 |
+
minimum=0,
|
188 |
+
maximum=1,
|
189 |
+
value=0.4,
|
190 |
+
scale=1,
|
191 |
+
)
|
192 |
+
s_shared_max_plaintext_len = gr.Slider(
|
193 |
+
label='max plaintext length',
|
194 |
+
minimum=4,
|
195 |
+
maximum=32,
|
196 |
+
step=1,
|
197 |
+
value=18,
|
198 |
+
scale=1,
|
199 |
+
)
|
200 |
+
with gr.Column(scale=1):
|
201 |
+
tb_shared_prefix = gr.Textbox(
|
202 |
+
label='prefix',
|
203 |
+
value='',
|
204 |
+
)
|
205 |
+
gr.Examples(
|
206 |
+
[
|
207 |
+
'best dessert recipe: ',
|
208 |
+
'def solve(x):',
|
209 |
+
'breaking news ',
|
210 |
+
'π€π',
|
211 |
+
],
|
212 |
+
tb_shared_prefix,
|
213 |
+
cache_examples=False,
|
214 |
+
)
|
215 |
+
|
216 |
+
with gr.Row():
|
217 |
+
with gr.Box():
|
218 |
+
with gr.Column():
|
219 |
+
# Sender
|
220 |
+
gr.Markdown('## Sender')
|
221 |
+
|
222 |
+
# Plain text
|
223 |
+
tb_sender_plaintext = gr.Textbox(
|
224 |
+
label='plaintext',
|
225 |
+
value='gold in top drawer',
|
226 |
+
)
|
227 |
+
btn_encrypt = gr.Button('π Encrypt')
|
228 |
+
|
229 |
+
# Encrypt
|
230 |
+
# Cipher text
|
231 |
+
tb_sender_ciphertext = gr.Textbox(
|
232 |
+
label='ciphertext (hex)',
|
233 |
+
)
|
234 |
+
btn_hide = gr.Button('π«£ Hide', interactive=False)
|
235 |
+
|
236 |
+
# Hide
|
237 |
+
# Cover text
|
238 |
+
tb_sender_stegotext = gr.Textbox(
|
239 |
+
label='stegotext',
|
240 |
+
)
|
241 |
+
|
242 |
+
with gr.Box():
|
243 |
+
with gr.Column():
|
244 |
+
# Receiver
|
245 |
+
gr.Markdown('## Receiver')
|
246 |
+
# Cover text
|
247 |
+
tb_receiver_stegotext = gr.Textbox(
|
248 |
+
label='stegotext',
|
249 |
+
)
|
250 |
+
btn_recover = gr.Button('π Recover')
|
251 |
+
# Cipher text
|
252 |
+
tb_receiver_ciphertext = gr.Textbox(
|
253 |
+
label='recovered ciphertext (hex)',
|
254 |
+
)
|
255 |
+
btn_decrypt = gr.Button('π Decrypt', interactive=True)
|
256 |
+
# Plain text
|
257 |
+
tb_receiver_plaintext = gr.Textbox(
|
258 |
+
label='deciphered plaintext',
|
259 |
+
)
|
260 |
+
|
261 |
+
gr.Markdown('''
|
262 |
+
## Known issues
|
263 |
+
1. The ciphertext recovered by the receiver might not match the original ciphertext. This is due to LLM tokenization mismatch. This is a fundamental challenge and for now, just Encrypt again (to use a different nonce) and go through the sender's process again.
|
264 |
+
2. The stegotext looks incoherent. GPT-2 small is used for the demo and its fluency is quite limited. A stronger LLM will alleviate this problem. A smaller imperceptibility threshold should also help.
|
265 |
+
''')
|
266 |
+
|
267 |
+
# Link the UI to handlers
|
268 |
+
def encrypt(saved_state, key_in_hex, plaintext, max_plaintext_length):
|
269 |
+
shared_key = bytes.fromhex(key_in_hex)
|
270 |
+
# print(saved_state)
|
271 |
+
if saved_state is None:
|
272 |
+
# Create the secret boxes if they have not been created.
|
273 |
+
sender_box = nacl.secret.SecretBox(shared_key)
|
274 |
+
receiver_box = nacl.secret.SecretBox(shared_key)
|
275 |
+
saved_state = sender_box, receiver_box
|
276 |
+
else:
|
277 |
+
sender_box, receiver_box = saved_state
|
278 |
+
print('Encode:', bytes(plaintext, 'utf8'), len(bytes(plaintext, 'utf8')))
|
279 |
+
utf8_encoded_plaintext = bytes(plaintext, 'utf8')
|
280 |
+
if len(utf8_encoded_plaintext) > max_plaintext_length:
|
281 |
+
raise gr.Error('Plaintext is too long. Try a shorter one or increase the max plaintext length.')
|
282 |
+
else:
|
283 |
+
# Pad the plaintext to the maximum length.
|
284 |
+
utf8_encoded_plaintext += bytes(' ' * (max_plaintext_length - len(utf8_encoded_plaintext)), encoding='utf8')
|
285 |
+
ciphertext = sender_box.encrypt(utf8_encoded_plaintext)
|
286 |
+
print('Encrypt:', plaintext, len(plaintext), ciphertext, len(ciphertext), len(ciphertext.hex()))
|
287 |
+
return [
|
288 |
+
saved_state,
|
289 |
+
ciphertext.hex(),
|
290 |
+
gr.Button.update(interactive=True),
|
291 |
+
]
|
292 |
+
|
293 |
+
def decrypt(saved_state, ciphertext, key_in_hex):
|
294 |
+
shared_key = bytes.fromhex(key_in_hex)
|
295 |
+
if saved_state is None:
|
296 |
+
# Create the secret boxes if they have not been created.
|
297 |
+
sender_box = nacl.secret.SecretBox(shared_key)
|
298 |
+
receiver_box = nacl.secret.SecretBox(shared_key)
|
299 |
+
saved_state = sender_box, receiver_box
|
300 |
+
else:
|
301 |
+
sender_box, receiver_box = saved_state
|
302 |
+
try:
|
303 |
+
utf8_encoded_plaintext = receiver_box.decrypt(bytes.fromhex(ciphertext))
|
304 |
+
print('Decrypt:', ciphertext, len(ciphertext), utf8_encoded_plaintext, len(utf8_encoded_plaintext))
|
305 |
+
return [
|
306 |
+
saved_state,
|
307 |
+
utf8_encoded_plaintext.decode('utf8'),
|
308 |
+
]
|
309 |
+
except:
|
310 |
+
raise gr.Error('Decryption failed. Likely due to tokenization mismatch. Try Encrypting again.')
|
311 |
+
|
312 |
+
def hide(ciphertext, tv_threshold, shared_prefix):
|
313 |
+
# Convert hex to bits
|
314 |
+
ba = bytes.fromhex(ciphertext)
|
315 |
+
bits = [b for h in ba for b in f'{h:08b}']
|
316 |
+
print('Hide:', ciphertext, bits, len(bits))
|
317 |
+
embed_gen = embed_bits(bits, shared_prefix, tv_threshold, lm.config.n_ctx // 2)
|
318 |
+
for inds in embed_gen:
|
319 |
+
yield tokenizer.decode(inds)
|
320 |
+
|
321 |
+
def recover(stegotext, tv_threshold, max_plaintext_length, shared_prefix):
|
322 |
+
inds = tokenizer.encode(stegotext)
|
323 |
+
print('Recover:', stegotext, inds, len(inds))
|
324 |
+
n_bits_to_recover = bits_to_recover(max_plaintext_length)
|
325 |
+
recover_gen = recover_bits(inds, tv_threshold, n_bits_to_recover, shared_prefix)
|
326 |
+
for bits in recover_gen:
|
327 |
+
yield ''.join(bits)
|
328 |
+
ba = bytearray()
|
329 |
+
# Convert bits to bytearray
|
330 |
+
for i in range(0, len(bits), 8):
|
331 |
+
ba.append(int(''.join(bits[i:i+8]), 2))
|
332 |
+
yield ba.hex()
|
333 |
+
|
334 |
+
btn_encrypt.click(
|
335 |
+
encrypt,
|
336 |
+
[state, tb_shared_key, tb_sender_plaintext, s_shared_max_plaintext_len],
|
337 |
+
[state, tb_sender_ciphertext, btn_hide],
|
338 |
+
)
|
339 |
+
btn_hide.click(
|
340 |
+
hide,
|
341 |
+
[tb_sender_ciphertext, s_shared_imp, tb_shared_prefix],
|
342 |
+
[tb_sender_stegotext],
|
343 |
+
)
|
344 |
+
btn_recover.click(
|
345 |
+
recover,
|
346 |
+
[tb_receiver_stegotext, s_shared_imp, s_shared_max_plaintext_len, tb_shared_prefix],
|
347 |
+
[tb_receiver_ciphertext],
|
348 |
+
)
|
349 |
+
btn_decrypt.click(
|
350 |
+
decrypt,
|
351 |
+
[state, tb_receiver_ciphertext, tb_shared_key],
|
352 |
+
[state, tb_receiver_plaintext],
|
353 |
+
)
|
354 |
+
|
355 |
+
|
356 |
+
if __name__ == '__main__':
|
357 |
+
demo.queue(concurrency_count=10)
|
358 |
+
demo.launch()
|
359 |
+
# demo.launch(share=True)
|
huffman.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import heapq
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def build_min_heap(freqs, inds=None):
|
7 |
+
'''Returns a min-heap of (frequency, token_index).'''
|
8 |
+
inds = inds or range(len(freqs))
|
9 |
+
# Add a counter in tuples for tiebreaking
|
10 |
+
freq_index = [(freqs[ind], i, ind) for i, ind in enumerate(inds)]
|
11 |
+
# O(n log n) where n = len(freqs)
|
12 |
+
heapq.heapify(freq_index)
|
13 |
+
return freq_index
|
14 |
+
|
15 |
+
|
16 |
+
def huffman_tree(heap):
|
17 |
+
'''Returns the Huffman tree given a min-heap of indices and frequencies.'''
|
18 |
+
# Add a counter in tuples for tiebreaking
|
19 |
+
t = len(heap)
|
20 |
+
# Runs for n iterations where n = len(heap)
|
21 |
+
while len(heap) > 1:
|
22 |
+
# Remove the smallest two nodes. O(log n)
|
23 |
+
freq1, i1, ind1 = heapq.heappop(heap)
|
24 |
+
freq2, i2, ind2 = heapq.heappop(heap)
|
25 |
+
# Create a parent node for these two nodes
|
26 |
+
parent_freq = freq1 + freq2
|
27 |
+
# The left child is the one with the lowest frequency
|
28 |
+
parent_ind = (ind1, ind2)
|
29 |
+
# Insert this parent node. O(log n)
|
30 |
+
heapq.heappush(heap, (parent_freq, t, parent_ind))
|
31 |
+
t += 1
|
32 |
+
code_tree = heap[0][2]
|
33 |
+
# Total runtime O(n log n).
|
34 |
+
return code_tree
|
35 |
+
|
36 |
+
|
37 |
+
def tv_huffman(code_tree, p):
|
38 |
+
'''
|
39 |
+
Returns the total variation and cross entropy (in bits) between a
|
40 |
+
distribution over tokens and the distribution induced by a Huffman
|
41 |
+
coding of (a subset of) the tokens.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
code_tree : tuple.
|
45 |
+
Huffman codes as represented by a binary tree. It might miss some
|
46 |
+
tokens.
|
47 |
+
p : array of size of the vocabulary.
|
48 |
+
The distribution over all tokens.
|
49 |
+
'''
|
50 |
+
tot_l1 = 0
|
51 |
+
# The tokens absent in the Huffman codes have probability 0
|
52 |
+
absence = np.ones_like(p)
|
53 |
+
tot_ce = 0
|
54 |
+
# Iterate leaves of the code tree. O(n)
|
55 |
+
stack = []
|
56 |
+
# Push the root and its depth onto the stack
|
57 |
+
stack.append((code_tree, 0))
|
58 |
+
while len(stack) > 0:
|
59 |
+
node, depth = stack.pop()
|
60 |
+
if type(node) is tuple:
|
61 |
+
# Expand the children
|
62 |
+
left_child, right_child = node
|
63 |
+
# Push the children and their depths onto the stack
|
64 |
+
stack.append((left_child, depth + 1))
|
65 |
+
stack.append((right_child, depth + 1))
|
66 |
+
else:
|
67 |
+
# A leaf node
|
68 |
+
ind = node
|
69 |
+
tot_l1 += abs(p[ind] - 2 ** (-depth))
|
70 |
+
absence[ind] = 0
|
71 |
+
# The KL divergence of true distribution || Huffman distribution
|
72 |
+
tot_ce += p[ind] * depth + p[ind] * np.log2(p[ind])
|
73 |
+
# Returns total variation
|
74 |
+
return 0.5 * (tot_l1 + np.sum(absence * p)), tot_ce
|
75 |
+
|
76 |
+
|
77 |
+
def total_variation(p, q):
|
78 |
+
'''Returns the total variation of two distributions over a finite set.'''
|
79 |
+
# We use 1-norm to compute total variation.
|
80 |
+
# d_TV(p, q) := sup_{A \in sigma} |p(A) - q(A)|
|
81 |
+
# = 1/2 * sum_{x \in X} |p(x) - q(x)| = 1/2 * ||p - q||_1
|
82 |
+
return 0.5 * np.sum(np.abs(p - q))
|
83 |
+
|
84 |
+
|
85 |
+
def invert_code_tree(code_tree):
|
86 |
+
'''Build a map from letters to codes'''
|
87 |
+
code = dict()
|
88 |
+
stack = []
|
89 |
+
stack.append((code_tree, ''))
|
90 |
+
while len(stack) > 0:
|
91 |
+
node, code_prefix = stack.pop()
|
92 |
+
if type(node) is tuple:
|
93 |
+
left, right = node
|
94 |
+
stack.append((left, code_prefix + '0'))
|
95 |
+
stack.append((right, code_prefix + '1'))
|
96 |
+
else:
|
97 |
+
code[node] = code_prefix
|
98 |
+
return code
|
99 |
+
|
100 |
+
|
101 |
+
def encode(code_tree, string):
|
102 |
+
'''Encode a string with a given Huffman coding.'''
|
103 |
+
code = invert_code_tree(code_tree)
|
104 |
+
encoded = ''
|
105 |
+
for letter in string:
|
106 |
+
encoded += code[letter]
|
107 |
+
return encoded
|
108 |
+
|
109 |
+
|
110 |
+
def decode(code_tree, encoded):
|
111 |
+
'''Decode an Huffman-encoded string.'''
|
112 |
+
decoded = []
|
113 |
+
state = code_tree
|
114 |
+
codes = [code for code in encoded]
|
115 |
+
# Terminate when there are no more codes and decoder state is resetted
|
116 |
+
while not (len(codes) == 0 and type(state) is tuple):
|
117 |
+
if type(state) is tuple:
|
118 |
+
# An internal node
|
119 |
+
left, right = state
|
120 |
+
try:
|
121 |
+
code = codes.pop(0)
|
122 |
+
except IndexError:
|
123 |
+
raise Exception('Decoder should stop at the end of the encoded string. The string may not be encoded by the specified Huffman coding.')
|
124 |
+
if code == 'l':
|
125 |
+
# Go left
|
126 |
+
state = left
|
127 |
+
else:
|
128 |
+
# Go right
|
129 |
+
state = right
|
130 |
+
else:
|
131 |
+
# A leaf node, decode a letter
|
132 |
+
decoded.append(state)
|
133 |
+
# Reset decoder state
|
134 |
+
state = code_tree
|
135 |
+
return decoded
|
136 |
+
|
137 |
+
|
138 |
+
def tree_depth(tree):
|
139 |
+
'''Returns the depth of a tree.'''
|
140 |
+
if type(tree) is tuple:
|
141 |
+
left, right = tree
|
142 |
+
return 1 + max(tree_depth(left), tree_depth(right))
|
143 |
+
else:
|
144 |
+
return 0
|
145 |
+
|
146 |
+
def tree_rank(tree):
|
147 |
+
'''Returns the rank of a tree.'''
|
148 |
+
if type(tree) is tuple:
|
149 |
+
left, right = tree
|
150 |
+
lr = tree_rank(left)
|
151 |
+
rr = tree_rank(right)
|
152 |
+
if lr == rr:
|
153 |
+
return lr + 1
|
154 |
+
else:
|
155 |
+
return max(lr, rr)
|
156 |
+
else:
|
157 |
+
return 0
|
158 |
+
|
159 |
+
|
160 |
+
if __name__ == '__main__':
|
161 |
+
# v = 256 ** 2
|
162 |
+
v = 5
|
163 |
+
p = np.random.dirichlet([1] * v)
|
164 |
+
print(sum(p))
|
165 |
+
# p = [0.7, 0.1, 0.05, 0.1, 0.05]
|
166 |
+
p = [0.99] + [.01 / 4] * 4
|
167 |
+
# heap = build_min_heap(p, [0, 1, 2, 4])
|
168 |
+
heap = build_min_heap(p)
|
169 |
+
# print(heap)
|
170 |
+
|
171 |
+
tree = huffman_tree(heap)
|
172 |
+
print(tree)
|
173 |
+
print(tv_huffman(tree, p))
|
174 |
+
# print(invert_code_tree(tree))
|
175 |
+
|
176 |
+
string = np.random.choice(v, 10, p=p)
|
177 |
+
# string = [0, 0, 2, 4, 1, 0, 2, 2]
|
178 |
+
print(list(string))
|
179 |
+
codes = encode(tree, string)
|
180 |
+
print(codes)
|
181 |
+
print(decode(tree, codes))
|