dai commited on
Commit
178b66b
β€’
1 Parent(s): b3fbbe5

first release

Browse files
Files changed (3) hide show
  1. README.md +4 -4
  2. app.py +359 -0
  3. huffman.py +181 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Stego Lm
3
- emoji: 🐨
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.28.0
8
  app_file: app.py
9
  pinned: false
10
  license: openrail
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Stego LM
3
+ emoji: πŸ”’πŸ‘€πŸ™ˆ
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.47.1
8
  app_file: app.py
9
  pinned: false
10
  license: openrail
11
  ---
12
 
13
+ Hide the hiding.
app.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # An demo of linguistic steganography with patient-Huffman algorithm.
3
+ # We use symmetric key cryptography to en/decrypt.
4
+ #
5
+ # Reference:
6
+ # Dai FZ, Cai Z. Towards Near-imperceptible Steganographic Text. ACL 2019.
7
+
8
+ import nacl.secret
9
+ import nacl.utils
10
+ from transformers import GPT2TokenizerFast, GPT2LMHeadModel
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch as th
14
+
15
+ from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree
16
+
17
+ # model_name = 'gpt2-xl'
18
+ # XXX Use GPT-2-small for less compute
19
+ model_name = 'gpt2'
20
+ lm = GPT2LMHeadModel.from_pretrained(model_name)
21
+ tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
22
+
23
+ def bits_to_recover(max_plaintext_length):
24
+ return (max_plaintext_length + 40) * 8
25
+
26
+ def p_next_token(prefix, cache=None, allow_eos=True):
27
+ t_prefix = th.as_tensor(prefix)
28
+ with th.no_grad():
29
+ if cache:
30
+ # Incremental decoding. Input one token at a time with cache.
31
+ lm_out = lm.forward(input_ids=t_prefix[-1:], use_cache=True, past_key_values=cache)
32
+ else:
33
+ lm_out = lm.forward(input_ids=t_prefix, use_cache=True)
34
+ if allow_eos:
35
+ # Assume EOS is the last token in the vocabulary.
36
+ p_next_token = lm_out.logits[-1].softmax(dim=-1)
37
+ else:
38
+ p_next_token = lm_out.logits[-1, :-1].softmax(dim=-1)
39
+ return p_next_token.numpy(), lm_out.past_key_values
40
+
41
+ def embed_bits(coin_flips, prefix, tv_threshold=0.1, max_sequence_length=400):
42
+ '''We use a sequence of coin flips to control the generation of token
43
+ indices from a language model. This returns _a sequence_ as defined by
44
+ the language model, e.g. sentence, paragraph.'''
45
+ # ind = tokenizer.bos_token_id
46
+ # prefix = [ind]
47
+
48
+ hidden_prefix_ind = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
49
+ n_hidden_prefix_ind = len(hidden_prefix_ind)
50
+ done_hiding = False
51
+ p, kv = p_next_token(hidden_prefix_ind, allow_eos=done_hiding)
52
+ n_skips = 0
53
+ n_bits_encoded = 0
54
+ n_tail_fill = 0
55
+ ind = None
56
+ prefix_inds = []
57
+ # Terminate the generation after we generate the EOS token
58
+ # XXX to save computation, we terminate as soon as all bits are hidden.
59
+ while not done_hiding and n_hidden_prefix_ind + len(prefix_inds) < max_sequence_length and ind != tokenizer.eos_token_id:
60
+ # There is still some cipher text to hide
61
+ if coin_flips:
62
+ # Build Huffman codes for the conditional distribution
63
+ heap = build_min_heap(p)
64
+ hc = huffman_tree(heap)
65
+ # print(hc)
66
+ # Check if the total variation is low enough
67
+ # print(len(prefix_inds) - 1, tv_huffman(hc, p))
68
+ # print(tv_huffman(hc, p)[0], tv_threshold)
69
+ if tv_huffman(hc, p)[0] < tv_threshold:
70
+ # Huffman-decode the cipher text into a token
71
+ # Consume the cipher text until a token is generated
72
+ decoder_state = hc
73
+ while type(decoder_state) is tuple:
74
+ left, right = decoder_state
75
+ try:
76
+ bit = coin_flips.pop(0)
77
+ n_bits_encoded += 1
78
+ except IndexError:
79
+ # No more cipher text. Pad with random bits
80
+ bit = np.random.choice(2)
81
+ n_tail_fill += 1
82
+ # 0 => left, 1 => right
83
+ decoder_state = left if bit == '0' else right
84
+ # Decoder settles in a leaf node
85
+ ind = decoder_state
86
+ prefix_inds.append(ind)
87
+ yield prefix_inds
88
+ done_hiding = not bool(coin_flips)
89
+ p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
90
+ continue
91
+ # Forward sample according to LM normally
92
+ n_skips += 1 if coin_flips else 0
93
+ ind = np.random.choice(tokenizer.vocab_size if done_hiding else tokenizer.vocab_size - 1, p=p)
94
+ prefix_inds.append(ind)
95
+ yield prefix_inds
96
+ p, kv = p_next_token(hidden_prefix_ind + prefix_inds, kv, done_hiding)
97
+ # Drop the EOS index
98
+ print(prefix_inds)
99
+ print(len(prefix_inds), n_skips, n_bits_encoded, n_tail_fill)
100
+ if prefix_inds[-1] == tokenizer.eos_token_id:
101
+ prefix_inds = prefix_inds[:-1]
102
+ yield prefix_inds
103
+
104
+ def recover_bits(token_inds, tv_threshold, bits_to_recover, prefix):
105
+ remaining_bits = bits_to_recover
106
+ hidden_prefix_inds = [tokenizer.bos_token_id] + tokenizer.encode(prefix)
107
+ p, kv = p_next_token(hidden_prefix_inds, allow_eos=False)
108
+ cipher_text = []
109
+ # Terminate the generation after we have consumed all indices or
110
+ # have extracted all bits
111
+ while token_inds and 0 < remaining_bits:
112
+ # Build Huffman codes for the conditional distribution
113
+ heap = build_min_heap(p)
114
+ hc = huffman_tree(heap)
115
+ # Check if the total variation is low enough
116
+ if tv_huffman(hc, p)[0] < tv_threshold:
117
+ # We have controlled this step. Some bits are hidden.
118
+ code = invert_code_tree(hc)
119
+ # Look up the Huffman code for the token.
120
+ ind = token_inds.pop(0)
121
+ # Convert the Huffman code into bits
122
+ # left => 0, right => 1
123
+ cipher_text_fragment = code[ind]
124
+ # Truncate possible trailing paddings
125
+ cipher_text += cipher_text_fragment[:remaining_bits]
126
+ remaining_bits -= len(cipher_text_fragment)
127
+ yield cipher_text
128
+ # print(remaining_bits)
129
+ hidden_prefix_inds.append(ind)
130
+ p, kv = p_next_token(hidden_prefix_inds, cache=kv, allow_eos=False)
131
+ else:
132
+ # We did not control this step. Skip.
133
+ hidden_prefix_inds.append(token_inds.pop(0))
134
+ p, kv = p_next_token(hidden_prefix_inds, cache=kv, allow_eos=False)
135
+ print(cipher_text, len(cipher_text), bits_to_recover)
136
+ yield cipher_text
137
+
138
+
139
+ with gr.Blocks() as demo:
140
+ gr.Markdown('''
141
+ # Linguistic steganography demo with ``patient-Huffman`` algorithm
142
+ Instead of sending secrets in plaintext or in ciphertext, we can "hide the hiding" by embedding the encrypted secret in a natural looking message.
143
+
144
+ ## Usage for message sender
145
+ 1. Type a short message. Click Encrypt to generate the ciphertext (encrypted text).
146
+ 2. Click Hide to generate the stegotext/covertext.
147
+
148
+ ## Usage for message receiver
149
+ 1. Copy-paste the received stegotext/covertext into the stegotext box. Click Recover to extract the hidden ciphertext.
150
+ 2. Click Decrypt to decipher the original message.
151
+ ''')
152
+
153
+ with gr.Accordion(
154
+ 'Secrets shared between sender and receiver',
155
+ open=False,
156
+ ):
157
+ # Shared secrets and parameters.
158
+ gr.Markdown('''
159
+ - The proposed stegosystem is agnostic to the choice of cryptosystem. We use the symmetric key encryption implemented in `pyNaCl` library.
160
+ - An encryption key is randomly generated, you can refresh the page to get a different one.
161
+ - The _choice_ of language model is a shared secret. Due to computation resource constraints, we use GPT-2 as an example.
162
+ - The communicating parties can share a prefix to further control the stegotext to appear more appropriate for the channel, e.g., blog posts, social media messages. Take extra care of the whitespaces.
163
+ - Imperceptibility threshold controls how much the distribution of stegotexts is allowed to deviate from the language model. Lower imperceptibility threshold produces longer stegotext.
164
+
165
+
166
+ Reference: Dai FZ, Cai Z. [Towards Near-imperceptible Steganographic Text](https://arxiv.org/abs/1907.06679). ACL 2019.
167
+ ''')
168
+ state = gr.State()
169
+ with gr.Row():
170
+ tb_shared_key = gr.Textbox(
171
+ label='encryption key (hex)',
172
+ value=lambda : nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE).hex(),
173
+ interactive=True,
174
+ scale=1,
175
+ lines=3,
176
+ )
177
+ # dp_shared_lm = gr.Dropdown(
178
+ # label='language model',
179
+ # choices=[
180
+ # 'GPT-2',
181
+ # # 'GPT-3',
182
+ # ],
183
+ # value='GPT-2',
184
+ # )
185
+ s_shared_imp = gr.Slider(
186
+ label='imperceptibility threshold',
187
+ minimum=0,
188
+ maximum=1,
189
+ value=0.4,
190
+ scale=1,
191
+ )
192
+ s_shared_max_plaintext_len = gr.Slider(
193
+ label='max plaintext length',
194
+ minimum=4,
195
+ maximum=32,
196
+ step=1,
197
+ value=18,
198
+ scale=1,
199
+ )
200
+ with gr.Column(scale=1):
201
+ tb_shared_prefix = gr.Textbox(
202
+ label='prefix',
203
+ value='',
204
+ )
205
+ gr.Examples(
206
+ [
207
+ 'best dessert recipe: ',
208
+ 'def solve(x):',
209
+ 'breaking news ',
210
+ 'πŸ€—πŸ”’',
211
+ ],
212
+ tb_shared_prefix,
213
+ cache_examples=False,
214
+ )
215
+
216
+ with gr.Row():
217
+ with gr.Box():
218
+ with gr.Column():
219
+ # Sender
220
+ gr.Markdown('## Sender')
221
+
222
+ # Plain text
223
+ tb_sender_plaintext = gr.Textbox(
224
+ label='plaintext',
225
+ value='gold in top drawer',
226
+ )
227
+ btn_encrypt = gr.Button('πŸ”’ Encrypt')
228
+
229
+ # Encrypt
230
+ # Cipher text
231
+ tb_sender_ciphertext = gr.Textbox(
232
+ label='ciphertext (hex)',
233
+ )
234
+ btn_hide = gr.Button('🫣 Hide', interactive=False)
235
+
236
+ # Hide
237
+ # Cover text
238
+ tb_sender_stegotext = gr.Textbox(
239
+ label='stegotext',
240
+ )
241
+
242
+ with gr.Box():
243
+ with gr.Column():
244
+ # Receiver
245
+ gr.Markdown('## Receiver')
246
+ # Cover text
247
+ tb_receiver_stegotext = gr.Textbox(
248
+ label='stegotext',
249
+ )
250
+ btn_recover = gr.Button('πŸ”Ž Recover')
251
+ # Cipher text
252
+ tb_receiver_ciphertext = gr.Textbox(
253
+ label='recovered ciphertext (hex)',
254
+ )
255
+ btn_decrypt = gr.Button('πŸ”“ Decrypt', interactive=True)
256
+ # Plain text
257
+ tb_receiver_plaintext = gr.Textbox(
258
+ label='deciphered plaintext',
259
+ )
260
+
261
+ gr.Markdown('''
262
+ ## Known issues
263
+ 1. The ciphertext recovered by the receiver might not match the original ciphertext. This is due to LLM tokenization mismatch. This is a fundamental challenge and for now, just Encrypt again (to use a different nonce) and go through the sender's process again.
264
+ 2. The stegotext looks incoherent. GPT-2 small is used for the demo and its fluency is quite limited. A stronger LLM will alleviate this problem. A smaller imperceptibility threshold should also help.
265
+ ''')
266
+
267
+ # Link the UI to handlers
268
+ def encrypt(saved_state, key_in_hex, plaintext, max_plaintext_length):
269
+ shared_key = bytes.fromhex(key_in_hex)
270
+ # print(saved_state)
271
+ if saved_state is None:
272
+ # Create the secret boxes if they have not been created.
273
+ sender_box = nacl.secret.SecretBox(shared_key)
274
+ receiver_box = nacl.secret.SecretBox(shared_key)
275
+ saved_state = sender_box, receiver_box
276
+ else:
277
+ sender_box, receiver_box = saved_state
278
+ print('Encode:', bytes(plaintext, 'utf8'), len(bytes(plaintext, 'utf8')))
279
+ utf8_encoded_plaintext = bytes(plaintext, 'utf8')
280
+ if len(utf8_encoded_plaintext) > max_plaintext_length:
281
+ raise gr.Error('Plaintext is too long. Try a shorter one or increase the max plaintext length.')
282
+ else:
283
+ # Pad the plaintext to the maximum length.
284
+ utf8_encoded_plaintext += bytes(' ' * (max_plaintext_length - len(utf8_encoded_plaintext)), encoding='utf8')
285
+ ciphertext = sender_box.encrypt(utf8_encoded_plaintext)
286
+ print('Encrypt:', plaintext, len(plaintext), ciphertext, len(ciphertext), len(ciphertext.hex()))
287
+ return [
288
+ saved_state,
289
+ ciphertext.hex(),
290
+ gr.Button.update(interactive=True),
291
+ ]
292
+
293
+ def decrypt(saved_state, ciphertext, key_in_hex):
294
+ shared_key = bytes.fromhex(key_in_hex)
295
+ if saved_state is None:
296
+ # Create the secret boxes if they have not been created.
297
+ sender_box = nacl.secret.SecretBox(shared_key)
298
+ receiver_box = nacl.secret.SecretBox(shared_key)
299
+ saved_state = sender_box, receiver_box
300
+ else:
301
+ sender_box, receiver_box = saved_state
302
+ try:
303
+ utf8_encoded_plaintext = receiver_box.decrypt(bytes.fromhex(ciphertext))
304
+ print('Decrypt:', ciphertext, len(ciphertext), utf8_encoded_plaintext, len(utf8_encoded_plaintext))
305
+ return [
306
+ saved_state,
307
+ utf8_encoded_plaintext.decode('utf8'),
308
+ ]
309
+ except:
310
+ raise gr.Error('Decryption failed. Likely due to tokenization mismatch. Try Encrypting again.')
311
+
312
+ def hide(ciphertext, tv_threshold, shared_prefix):
313
+ # Convert hex to bits
314
+ ba = bytes.fromhex(ciphertext)
315
+ bits = [b for h in ba for b in f'{h:08b}']
316
+ print('Hide:', ciphertext, bits, len(bits))
317
+ embed_gen = embed_bits(bits, shared_prefix, tv_threshold, lm.config.n_ctx // 2)
318
+ for inds in embed_gen:
319
+ yield tokenizer.decode(inds)
320
+
321
+ def recover(stegotext, tv_threshold, max_plaintext_length, shared_prefix):
322
+ inds = tokenizer.encode(stegotext)
323
+ print('Recover:', stegotext, inds, len(inds))
324
+ n_bits_to_recover = bits_to_recover(max_plaintext_length)
325
+ recover_gen = recover_bits(inds, tv_threshold, n_bits_to_recover, shared_prefix)
326
+ for bits in recover_gen:
327
+ yield ''.join(bits)
328
+ ba = bytearray()
329
+ # Convert bits to bytearray
330
+ for i in range(0, len(bits), 8):
331
+ ba.append(int(''.join(bits[i:i+8]), 2))
332
+ yield ba.hex()
333
+
334
+ btn_encrypt.click(
335
+ encrypt,
336
+ [state, tb_shared_key, tb_sender_plaintext, s_shared_max_plaintext_len],
337
+ [state, tb_sender_ciphertext, btn_hide],
338
+ )
339
+ btn_hide.click(
340
+ hide,
341
+ [tb_sender_ciphertext, s_shared_imp, tb_shared_prefix],
342
+ [tb_sender_stegotext],
343
+ )
344
+ btn_recover.click(
345
+ recover,
346
+ [tb_receiver_stegotext, s_shared_imp, s_shared_max_plaintext_len, tb_shared_prefix],
347
+ [tb_receiver_ciphertext],
348
+ )
349
+ btn_decrypt.click(
350
+ decrypt,
351
+ [state, tb_receiver_ciphertext, tb_shared_key],
352
+ [state, tb_receiver_plaintext],
353
+ )
354
+
355
+
356
+ if __name__ == '__main__':
357
+ demo.queue(concurrency_count=10)
358
+ demo.launch()
359
+ # demo.launch(share=True)
huffman.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import heapq
2
+
3
+ import numpy as np
4
+
5
+
6
+ def build_min_heap(freqs, inds=None):
7
+ '''Returns a min-heap of (frequency, token_index).'''
8
+ inds = inds or range(len(freqs))
9
+ # Add a counter in tuples for tiebreaking
10
+ freq_index = [(freqs[ind], i, ind) for i, ind in enumerate(inds)]
11
+ # O(n log n) where n = len(freqs)
12
+ heapq.heapify(freq_index)
13
+ return freq_index
14
+
15
+
16
+ def huffman_tree(heap):
17
+ '''Returns the Huffman tree given a min-heap of indices and frequencies.'''
18
+ # Add a counter in tuples for tiebreaking
19
+ t = len(heap)
20
+ # Runs for n iterations where n = len(heap)
21
+ while len(heap) > 1:
22
+ # Remove the smallest two nodes. O(log n)
23
+ freq1, i1, ind1 = heapq.heappop(heap)
24
+ freq2, i2, ind2 = heapq.heappop(heap)
25
+ # Create a parent node for these two nodes
26
+ parent_freq = freq1 + freq2
27
+ # The left child is the one with the lowest frequency
28
+ parent_ind = (ind1, ind2)
29
+ # Insert this parent node. O(log n)
30
+ heapq.heappush(heap, (parent_freq, t, parent_ind))
31
+ t += 1
32
+ code_tree = heap[0][2]
33
+ # Total runtime O(n log n).
34
+ return code_tree
35
+
36
+
37
+ def tv_huffman(code_tree, p):
38
+ '''
39
+ Returns the total variation and cross entropy (in bits) between a
40
+ distribution over tokens and the distribution induced by a Huffman
41
+ coding of (a subset of) the tokens.
42
+
43
+ Args:
44
+ code_tree : tuple.
45
+ Huffman codes as represented by a binary tree. It might miss some
46
+ tokens.
47
+ p : array of size of the vocabulary.
48
+ The distribution over all tokens.
49
+ '''
50
+ tot_l1 = 0
51
+ # The tokens absent in the Huffman codes have probability 0
52
+ absence = np.ones_like(p)
53
+ tot_ce = 0
54
+ # Iterate leaves of the code tree. O(n)
55
+ stack = []
56
+ # Push the root and its depth onto the stack
57
+ stack.append((code_tree, 0))
58
+ while len(stack) > 0:
59
+ node, depth = stack.pop()
60
+ if type(node) is tuple:
61
+ # Expand the children
62
+ left_child, right_child = node
63
+ # Push the children and their depths onto the stack
64
+ stack.append((left_child, depth + 1))
65
+ stack.append((right_child, depth + 1))
66
+ else:
67
+ # A leaf node
68
+ ind = node
69
+ tot_l1 += abs(p[ind] - 2 ** (-depth))
70
+ absence[ind] = 0
71
+ # The KL divergence of true distribution || Huffman distribution
72
+ tot_ce += p[ind] * depth + p[ind] * np.log2(p[ind])
73
+ # Returns total variation
74
+ return 0.5 * (tot_l1 + np.sum(absence * p)), tot_ce
75
+
76
+
77
+ def total_variation(p, q):
78
+ '''Returns the total variation of two distributions over a finite set.'''
79
+ # We use 1-norm to compute total variation.
80
+ # d_TV(p, q) := sup_{A \in sigma} |p(A) - q(A)|
81
+ # = 1/2 * sum_{x \in X} |p(x) - q(x)| = 1/2 * ||p - q||_1
82
+ return 0.5 * np.sum(np.abs(p - q))
83
+
84
+
85
+ def invert_code_tree(code_tree):
86
+ '''Build a map from letters to codes'''
87
+ code = dict()
88
+ stack = []
89
+ stack.append((code_tree, ''))
90
+ while len(stack) > 0:
91
+ node, code_prefix = stack.pop()
92
+ if type(node) is tuple:
93
+ left, right = node
94
+ stack.append((left, code_prefix + '0'))
95
+ stack.append((right, code_prefix + '1'))
96
+ else:
97
+ code[node] = code_prefix
98
+ return code
99
+
100
+
101
+ def encode(code_tree, string):
102
+ '''Encode a string with a given Huffman coding.'''
103
+ code = invert_code_tree(code_tree)
104
+ encoded = ''
105
+ for letter in string:
106
+ encoded += code[letter]
107
+ return encoded
108
+
109
+
110
+ def decode(code_tree, encoded):
111
+ '''Decode an Huffman-encoded string.'''
112
+ decoded = []
113
+ state = code_tree
114
+ codes = [code for code in encoded]
115
+ # Terminate when there are no more codes and decoder state is resetted
116
+ while not (len(codes) == 0 and type(state) is tuple):
117
+ if type(state) is tuple:
118
+ # An internal node
119
+ left, right = state
120
+ try:
121
+ code = codes.pop(0)
122
+ except IndexError:
123
+ raise Exception('Decoder should stop at the end of the encoded string. The string may not be encoded by the specified Huffman coding.')
124
+ if code == 'l':
125
+ # Go left
126
+ state = left
127
+ else:
128
+ # Go right
129
+ state = right
130
+ else:
131
+ # A leaf node, decode a letter
132
+ decoded.append(state)
133
+ # Reset decoder state
134
+ state = code_tree
135
+ return decoded
136
+
137
+
138
+ def tree_depth(tree):
139
+ '''Returns the depth of a tree.'''
140
+ if type(tree) is tuple:
141
+ left, right = tree
142
+ return 1 + max(tree_depth(left), tree_depth(right))
143
+ else:
144
+ return 0
145
+
146
+ def tree_rank(tree):
147
+ '''Returns the rank of a tree.'''
148
+ if type(tree) is tuple:
149
+ left, right = tree
150
+ lr = tree_rank(left)
151
+ rr = tree_rank(right)
152
+ if lr == rr:
153
+ return lr + 1
154
+ else:
155
+ return max(lr, rr)
156
+ else:
157
+ return 0
158
+
159
+
160
+ if __name__ == '__main__':
161
+ # v = 256 ** 2
162
+ v = 5
163
+ p = np.random.dirichlet([1] * v)
164
+ print(sum(p))
165
+ # p = [0.7, 0.1, 0.05, 0.1, 0.05]
166
+ p = [0.99] + [.01 / 4] * 4
167
+ # heap = build_min_heap(p, [0, 1, 2, 4])
168
+ heap = build_min_heap(p)
169
+ # print(heap)
170
+
171
+ tree = huffman_tree(heap)
172
+ print(tree)
173
+ print(tv_huffman(tree, p))
174
+ # print(invert_code_tree(tree))
175
+
176
+ string = np.random.choice(v, 10, p=p)
177
+ # string = [0, 0, 2, 4, 1, 0, 2, 2]
178
+ print(list(string))
179
+ codes = encode(tree, string)
180
+ print(codes)
181
+ print(decode(tree, codes))