import heapq
import numpy as np


def build_min_heap(freqs, inds=None):
    '''Returns a min-heap of (frequency, counter, token_index) tuples.

    Args:
        freqs: indexable of token frequencies (or probabilities).
        inds: optional iterable of token indices to include (list, range or
            numpy array). Defaults to every index of ``freqs``.

    The middle element is a unique counter used only for tie-breaking, so
    heapq never has to compare token indices (or subtrees) directly.
    '''
    # Explicit None check: `inds or ...` raises on multi-element numpy arrays
    # and would silently treat an empty selection as "all tokens".
    if inds is None:
        inds = range(len(freqs))
    freq_index = [(freqs[ind], i, ind) for i, ind in enumerate(inds)]
    # heapify is O(n) where n = len(freq_index).
    heapq.heapify(freq_index)
    return freq_index


def huffman_tree(heap):
    '''Returns the Huffman tree given a min-heap of indices and frequencies.

    Args:
        heap: list of (frequency, counter, node) tuples satisfying the heap
            invariant, e.g. as produced by ``build_min_heap``. It is consumed
            (mutated) in place.

    Returns:
        A nested 2-tuple ``(left, right)`` whose leaves are token indices.
        With a single token, the "tree" is just that token index.

    Raises:
        ValueError: if ``heap`` is empty.
    '''
    if not heap:
        raise ValueError('Cannot build a Huffman tree from an empty heap.')
    # New parent nodes need counters distinct from every existing one so that
    # ties on frequency never fall through to comparing nodes.
    t = max(entry[1] for entry in heap) + 1
    # n - 1 merges, each O(log n): O(n log n) overall.
    while len(heap) > 1:
        # Remove the two lowest-frequency nodes.
        freq1, _, node1 = heapq.heappop(heap)
        freq2, _, node2 = heapq.heappop(heap)
        # The left child is the one with the lower frequency.
        heapq.heappush(heap, (freq1 + freq2, t, (node1, node2)))
        t += 1
    return heap[0][2]


def tv_huffman(code_tree, p):
    '''Returns the total variation and KL divergence (in bits) between a
    distribution over tokens and the distribution induced by a Huffman
    coding of (a subset of) the tokens.

    The code induces q(x) = 2 ** -depth(x) for tokens in the tree and
    q(x) = 0 for tokens missing from it.

    Args:
        code_tree: tuple. Huffman codes as represented by a binary tree. It
            might miss some tokens.
        p: array of size of the vocabulary. The distribution over all tokens.

    Returns:
        (tv, kl) where tv = d_TV(p, q) and kl is the sum over coded tokens x
        of p(x) * depth(x) + p(x) * log2 p(x), i.e. D(p || q) restricted to
        the coded tokens. Terms with p(x) == 0 contribute 0 (0 log 0 := 0).
        Absent tokens are not included in kl.
    '''
    tot_l1 = 0
    tot_kl = 0
    # Tokens absent from the code tree have probability 0 under q.
    absence = np.ones_like(p)
    # Iterative DFS over the leaves of the code tree: O(n).
    stack = [(code_tree, 0)]
    while stack:
        node, depth = stack.pop()
        if isinstance(node, tuple):
            left_child, right_child = node
            stack.append((left_child, depth + 1))
            stack.append((right_child, depth + 1))
            continue
        # A leaf node: a token whose codeword has length `depth`.
        p_x = p[node]
        tot_l1 += abs(p_x - 2 ** (-depth))
        absence[node] = 0
        # Skip p(x) == 0: np.log2(0) is -inf and 0 * -inf would be nan.
        if p_x > 0:
            tot_kl += p_x * depth + p_x * np.log2(p_x)
    # For absent tokens |p(x) - q(x)| = p(x).
    return 0.5 * (tot_l1 + np.sum(absence * p)), tot_kl


def total_variation(p, q):
    '''Returns the total variation of two distributions over a finite set.

    Accepts numpy arrays or plain sequences of equal length.
    '''
    # We use 1-norm to compute total variation.
    # d_TV(p, q) := sup_{A \in sigma} |p(A) - q(A)|
    #             = 1/2 * sum_{x \in X} |p(x) - q(x)| = 1/2 * ||p - q||_1
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return 0.5 * np.sum(np.abs(p - q))


def invert_code_tree(code_tree):
    '''Build a map from letters (leaves) to codes.

    A left edge contributes '0' and a right edge '1'. A tree consisting of a
    single leaf maps that leaf to the empty code ''.
    '''
    code = dict()
    stack = [(code_tree, '')]
    while stack:
        node, code_prefix = stack.pop()
        if isinstance(node, tuple):
            left, right = node
            stack.append((left, code_prefix + '0'))
            stack.append((right, code_prefix + '1'))
        else:
            code[node] = code_prefix
    return code


def encode(code_tree, string):
    '''Encode a sequence of letters with a given Huffman coding.

    Raises:
        KeyError: if a letter is not a leaf of ``code_tree``.
    '''
    code = invert_code_tree(code_tree)
    # join avoids quadratic repeated string concatenation.
    return ''.join(code[letter] for letter in string)


def decode(code_tree, encoded):
    '''Decode a Huffman-encoded string of '0'/'1' characters.

    '0' walks to the left child and '1' to the right child, matching
    ``invert_code_tree``/``encode``.

    Args:
        code_tree: tuple tree as returned by ``huffman_tree``.
        encoded: iterable of '0'/'1' characters.

    Returns:
        list of decoded letters (leaves of ``code_tree``).

    Raises:
        ValueError: if ``encoded`` contains a character other than '0'/'1',
            stops in the middle of a codeword, or ``code_tree`` is a single
            leaf (every codeword is empty, so the message length cannot be
            recovered).
    '''
    if not isinstance(code_tree, tuple):
        raise ValueError('Cannot decode with a single-symbol code tree: '
                         'every codeword is empty.')
    decoded = []
    state = code_tree
    for bit in encoded:
        # `state` is always an internal node here: it is reset to the root
        # as soon as a leaf is reached.
        left, right = state
        if bit == '0':
            state = left
        elif bit == '1':
            state = right
        else:
            raise ValueError(
                f'Invalid character {bit!r} in encoded string; '
                'expected "0" or "1".')
        if not isinstance(state, tuple):
            # A leaf node: emit the letter and reset the decoder state.
            decoded.append(state)
            state = code_tree
    # Ending anywhere but the root means the last codeword was truncated.
    if state is not code_tree:
        raise ValueError('Decoder should stop at the end of the encoded string. '
                         'The string may not be encoded by the specified Huffman coding.')
    return decoded


def tree_depth(tree):
    '''Returns the depth of a tree (0 for a single leaf).'''
    if isinstance(tree, tuple):
        left, right = tree
        return 1 + max(tree_depth(left), tree_depth(right))
    return 0


def tree_rank(tree):
    '''Returns the rank (Horton-Strahler number) of a tree.

    A leaf has rank 0. An internal node has rank r + 1 if both children
    have rank r, and otherwise the larger of its children's ranks.
    '''
    if isinstance(tree, tuple):
        left, right = tree
        lr = tree_rank(left)
        rr = tree_rank(right)
        if lr == rr:
            return lr + 1
        return max(lr, rr)
    return 0


if __name__ == '__main__':
    # A highly skewed distribution: one frequent token and four rare ones.
    p = [0.99] + [.01 / 4] * 4
    v = len(p)
    heap = build_min_heap(p)
    tree = huffman_tree(heap)
    print(tree)
    print(tv_huffman(tree, p))
    string = np.random.choice(v, 10, p=p)
    print(list(string))
    codes = encode(tree, string)
    print(codes)
    print(decode(tree, codes))