File size: 5,567 Bytes
178b66b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import heapq

import numpy as np


def build_min_heap(freqs, inds=None):
    '''Returns a min-heap of (frequency, counter, token_index) triples.

    Args:
        freqs : sequence indexable by token index.
            Frequency (or probability) of each token.
        inds : iterable of token indices, optional.
            The subset of tokens to include. Defaults to all tokens,
            ``range(len(freqs))``. An empty iterable yields an empty heap.

    The counter is the position of the token within ``inds``. It is unique,
    so ties on frequency are broken by it and token indices are never
    compared directly.
    '''
    # Explicit None check: `inds or ...` would treat an empty list, or a
    # one-element numpy array such as np.array([0]), as "use all tokens",
    # and raises for multi-element numpy arrays.
    if inds is None:
        inds = range(len(freqs))
    freq_index = [(freqs[ind], i, ind) for i, ind in enumerate(inds)]
    # heapify runs in linear time in the number of entries.
    heapq.heapify(freq_index)
    return freq_index


def huffman_tree(heap):
    '''Returns the Huffman tree given a min-heap of (frequency, counter, index).

    The heap is consumed in place: on return it holds a single entry, the root.

    Args:
        heap : list, heap-ordered.
            Entries ``(frequency, counter, token_index)`` as produced by
            ``build_min_heap``. Counters must be unique so that frequency ties
            never fall through to comparing token indices or subtrees.

    Returns:
        A nested 2-tuple ``(left, right)`` for each internal node, with token
        indices at the leaves. If the heap has a single entry, the "tree" is
        just that token index (a depth-0, i.e. empty, codeword).

    Raises:
        ValueError: if ``heap`` is empty.
    '''
    if not heap:
        raise ValueError('Cannot build a Huffman tree from an empty heap.')
    # Counters for merged nodes start above every existing counter so they
    # stay unique even if the caller's counters are not exactly 0..n-1.
    t = max(entry[1] for entry in heap) + 1
    # n - 1 merges, each O(log n): O(n log n) overall.
    while len(heap) > 1:
        # Remove the two lowest-frequency nodes.
        freq1, _, node1 = heapq.heappop(heap)
        freq2, _, node2 = heapq.heappop(heap)
        # Merge them under a parent; the lower-frequency node is the left child.
        heapq.heappush(heap, (freq1 + freq2, t, (node1, node2)))
        t += 1
    return heap[0][2]


def tv_huffman(code_tree, p):
    '''
    Returns the total variation and the KL divergence (in bits) between a
    distribution over tokens and the distribution induced by a Huffman
    coding of (a subset of) the tokens.

    The Huffman-induced distribution assigns probability 2 ** -depth to each
    leaf (token) of the code tree and 0 to tokens absent from the tree.

    Args:
        code_tree : tuple.
            Huffman codes as represented by a binary tree. It might miss some
            tokens.
        p : array of size of the vocabulary.
            The distribution over all tokens.

    Returns:
        (tv, kl), where:
            tv is the total variation distance, including the mass p places
            on tokens absent from the tree.
            kl is sum_x p(x) * (depth(x) + log2 p(x)), summed over the tokens
            in the tree only. Tokens with p(x) = 0 contribute 0 (0 log 0 := 0).
            Absent tokens are ignored here even if p(x) > 0; the true KL
            divergence would then be infinite.
    '''
    tot_l1 = 0
    tot_kl = 0
    # Starts at 1 for every token; cleared for each token found in the tree,
    # so what remains marks tokens the code assigns probability 0.
    absence = np.ones_like(p)
    # Iterative traversal of the leaves with their depths. O(n)
    stack = [(code_tree, 0)]
    while stack:
        node, depth = stack.pop()
        if isinstance(node, tuple):
            left_child, right_child = node
            stack.append((left_child, depth + 1))
            stack.append((right_child, depth + 1))
        else:
            ind = node
            tot_l1 += abs(p[ind] - 2 ** (-depth))
            absence[ind] = 0
            # Skip zero-probability tokens: 0 * log2(0) would be nan.
            if p[ind] > 0:
                tot_kl += p[ind] * depth + p[ind] * np.log2(p[ind])
    # Absent tokens contribute |p(x) - 0| to the L1 distance.
    return 0.5 * (tot_l1 + np.sum(absence * p)), tot_kl


def total_variation(p, q):
    '''Returns the total variation of two distributions over a finite set.

    Args:
        p, q : array-likes of the same length (lists or numpy arrays) giving
            the probability of each element of the set.
    '''
    # We use the 1-norm to compute total variation:
    # d_TV(p, q) := sup_{A \in sigma} |p(A) - q(A)|
    #             = 1/2 * sum_{x \in X} |p(x) - q(x)| = 1/2 * ||p - q||_1
    # np.asarray lets plain Python lists work too (list - list is a TypeError).
    return 0.5 * np.sum(np.abs(np.asarray(p) - np.asarray(q)))


def invert_code_tree(code_tree):
    '''Build a map from letters (the tree's leaves) to their codewords.

    Descending to a left child appends '0' to the codeword, and to a right
    child appends '1'. A tree that is a single leaf maps that leaf to the
    empty string.
    '''
    codewords = {}
    pending = [(code_tree, '')]
    while pending:
        subtree, path = pending.pop()
        if type(subtree) is not tuple:
            # Leaf: the accumulated path is its codeword.
            codewords[subtree] = path
            continue
        left, right = subtree
        pending.extend(((left, path + '0'), (right, path + '1')))
    return codewords


def encode(code_tree, string):
    '''Encode a string with a given Huffman coding.

    Args:
        code_tree : Huffman tree (nested 2-tuples with letters at the leaves).
        string : iterable of letters, each a leaf of ``code_tree``.

    Returns:
        The concatenation of the letters' codewords, as computed by
        ``invert_code_tree`` ('0' for left, '1' for right). Note that a
        single-leaf tree has an empty codeword, so every input encodes to ''.

    Raises:
        KeyError: if a letter is not a leaf of ``code_tree``.
    '''
    code = invert_code_tree(code_tree)
    # One join is linear; repeated `+=` on a str can be quadratic.
    return ''.join(code[letter] for letter in string)


def decode(code_tree, encoded):
    '''Decode an Huffman-encoded string.

    Args:
        code_tree : tuple.
            Huffman tree. '0' selects the left child and '1' the right child,
            matching the codewords built by ``invert_code_tree``/``encode``.
        encoded : str of '0'/'1' characters.

    Returns:
        List of decoded letters (leaf values), in order.

    Raises:
        ValueError: if the tree is a single leaf (its codeword is empty, so
            the number of letters cannot be recovered), if ``encoded``
            contains a character other than '0' or '1', or if it ends in the
            middle of a codeword.
    '''
    if not isinstance(code_tree, tuple):
        raise ValueError('A single-leaf code tree has an empty codeword; '
                         'the encoded string cannot be decoded.')
    decoded = []
    state = code_tree
    # Iterate the bits directly; list.pop(0) per bit would be O(n^2).
    for bit in encoded:
        left, right = state
        if bit == '0':
            state = left
        elif bit == '1':
            state = right
        else:
            raise ValueError(
                'Invalid code character {!r}; expected "0" or "1".'.format(bit))
        if not isinstance(state, tuple):
            # Reached a leaf: emit the letter and restart from the root.
            decoded.append(state)
            state = code_tree
    # Only a complete sequence of codewords leaves the decoder at the root.
    if state is not code_tree:
        raise ValueError(
            'Decoder should stop at the end of the encoded string. The string '
            'may not be encoded by the specified Huffman coding.')
    return decoded


def tree_depth(tree):
    '''Returns the depth of a tree (0 for a single leaf).

    The tree is nested 2-tuples with non-tuple leaves. The depth is the
    length of the longest root-to-leaf path. Traversal is iterative because
    Huffman trees can be very unbalanced (chain-like), which would exceed
    Python's recursion limit.
    '''
    deepest = 0
    stack = [(tree, 0)]
    while stack:
        node, depth = stack.pop()
        if isinstance(node, tuple):
            left, right = node
            stack.append((left, depth + 1))
            stack.append((right, depth + 1))
        else:
            deepest = max(deepest, depth)
    return deepest

def tree_rank(tree):
    '''Returns the rank of a tree (0 for a single leaf).

    An internal node whose children have equal rank r has rank r + 1;
    otherwise it takes the larger child rank. The traversal is an iterative
    post-order so that deep, unbalanced trees do not hit Python's recursion
    limit.
    '''
    # Ranks of already-finished subtrees, in left-to-right completion order.
    ranks = []
    # (node, children_done): a node is revisited once both children are ranked.
    stack = [(tree, False)]
    while stack:
        node, children_done = stack.pop()
        if not isinstance(node, tuple):
            ranks.append(0)
        elif not children_done:
            left, right = node
            stack.append((node, True))
            # Push right first so the left subtree is finished first.
            stack.append((right, False))
            stack.append((left, False))
        else:
            rr = ranks.pop()
            lr = ranks.pop()
            ranks.append(lr + 1 if lr == rr else max(lr, rr))
    return ranks[0]


if __name__ == '__main__':
    # Demo on a heavily skewed 5-token distribution. Build the Huffman code,
    # compare its induced distribution (2 ** -codeword_length) with p, then
    # encode and decode a random sample.
    v = 5
    p = [0.99] + [.01 / 4] * 4
    heap = build_min_heap(p)
    tree = huffman_tree(heap)
    print(tree)
    print(tv_huffman(tree, p))

    # tolist() yields plain ints, so the sample prints as [0, 0, ...]
    # rather than as numpy scalar reprs.
    string = np.random.choice(v, 10, p=p).tolist()
    print(string)
    codes = encode(tree, string)
    print(codes)
    print(decode(tree, codes))