| |
| """ |
| Correct implementation of enumerative entropy coding as described in Han et al. (2008). |
| This version is fully self-contained, embedding all necessary data into the stream. |
| """ |
|
|
| import numpy as np |
| from typing import List, Dict, Tuple, Optional |
| from collections import Counter |
| import math |
|
|
class ExpGolombCoder:
    """Order-0 Exp-Golomb codec operating on strings of '0'/'1' characters.

    A value ``n >= 0`` is written as the binary form of ``n + 1`` preceded by
    as many '0' characters as that binary form has digits, minus one.
    """

    @staticmethod
    def encode(n: int) -> str:
        """Return the Exp-Golomb code word for *n* as a bit string.

        Raises:
            ValueError: if *n* is negative.
        """
        if n < 0:
            raise ValueError("Exp-Golomb is for non-negative integers.")
        payload = bin(n + 1)[2:]
        # A payload of w digits is preceded by w - 1 zeros: 2w - 1 characters.
        return payload.zfill(2 * len(payload) - 1)

    @staticmethod
    def decode(bits: str, start_pos: int = 0) -> Tuple[int, int]:
        """Read one Exp-Golomb code word from *bits* beginning at *start_pos*.

        Returns:
            ``(value, next_pos)`` where *next_pos* is the index of the first
            character after the code word.

        Raises:
            ValueError: if the string ends before the code word is complete.
        """
        total = len(bits)
        cursor = start_pos
        # Skip the run of zeros; its length fixes the width of the payload.
        while cursor < total and bits[cursor] == '0':
            cursor += 1

        if cursor >= total:
            raise ValueError("Incomplete exp-Golomb code: no '1' bit found.")

        width = (cursor - start_pos) + 1
        end = cursor + width
        if end > total:
            raise ValueError("Incomplete exp-Golomb code: not enough bits for value.")

        return int(bits[cursor:end], 2) - 1, end
|
|
|
|
class OptimizedBinomialTable:
    """Memoised binomial coefficients C(n, k).

    Values come from ``math.comb`` and are Python integers, so they never
    overflow. Each computed coefficient is stored in a dictionary keyed by
    ``(n, k)`` after folding ``k`` onto the smaller of ``k`` and ``n - k``.
    """

    def __init__(self):
        # Maps (n, min(k, n - k)) -> C(n, k); trivial cases are never stored.
        self._cache = {}

    def get(self, n: int, k: int) -> int:
        """Return C(n, k), or 0 when *k* lies outside ``0..n``."""
        if not 0 <= k <= n:
            return 0
        if k in (0, n):
            return 1

        # C(n, k) == C(n, n - k): use one cache slot for both.
        folded = (n, min(k, n - k))
        try:
            return self._cache[folded]
        except KeyError:
            value = math.comb(*folded)
            self._cache[folded] = value
            return value

    def __getitem__(self, n: int):
        """Return a row proxy so callers can write ``table[n][k]``."""
        return BinomialRow(self, n)
|
|
|
|
class BinomialRow:
    """Row proxy that makes ``table[n][k]`` resolve to ``table.get(n, k)``."""

    def __init__(self, table: "OptimizedBinomialTable", n: int):
        # Keep the owning table and the fixed row number for later lookups.
        self.n = n
        self.table = table

    def __getitem__(self, k: int) -> int:
        """Return the coefficient in column *k* of this row."""
        owner, row = self.table, self.n
        return owner.get(row, k)
|
|
|
|
class EnumerativeEncoder:
    """
    An enumerative entropy coder aligned with the algorithm described in
    "Entropy Coding Using Equiprobable Partitioning" by Han et al. (2008).

    The output stream is self-contained. Its layout, in order, is:

    1. ``n``  - sequence length (Exp-Golomb).
    2. ``K``  - number of distinct symbols (Exp-Golomb).
    3. The ``K`` symbols, ordered by ascending occurrence count (Exp-Golomb
       each, so symbols must be non-negative integers).
    4. The occurrence counts of the first ``K - 1`` symbols (Exp-Golomb each);
       the last count is implied by ``n``.
    5. For each of the first ``K - 1`` symbols: one flag bit ('1' when the
       complement set is coded) followed by the Exp-Golomb coded rank of the
       symbol's positions among the slots not yet claimed by earlier symbols.
    6. Zero bits up to the next byte boundary.
    """

    def __init__(self):
        self.binom_table = OptimizedBinomialTable()

    def _rank(self, n: int, k: int, positions: List[int]) -> int:
        """Return the rank of a combination in the combinatorial number system.

        Args:
            n: size of the ground set (unused; kept to mirror ``_unrank``).
            k: size of the combination (unused; kept to mirror ``_unrank``).
            positions: strictly increasing 0-based positions.

        Returns:
            ``sum(C(positions[i], i + 1))`` over all ``i``.
        """
        binom = self.binom_table.get
        return sum(binom(pos, i + 1) for i, pos in enumerate(positions))

    def _unrank(self, n: int, k: int, index: int) -> List[int]:
        """Invert ``_rank``: return the *k* increasing positions below *n*.

        Positions are recovered from the largest down. For each one, a binary
        search finds the greatest ``p`` with ``C(p, i + 1) <= index``.
        """
        binom = self.binom_table.get
        positions = []
        v_high = n - 1
        for i in range(k - 1, -1, -1):
            # The i-th smallest position can never be below i.
            v_low = i
            while v_low < v_high:
                mid = (v_low + v_high + 1) // 2
                if binom(mid, i + 1) <= index:
                    v_low = mid
                else:
                    v_high = mid - 1

            positions.append(v_low)
            index -= binom(v_low, i + 1)
            # Remaining positions must be strictly smaller.
            v_high = v_low - 1

        positions.reverse()
        return positions

    def encode(self, data: List[int]) -> bytes:
        """Encode a sequence of non-negative integers into a byte string.

        Args:
            data: the symbols to encode. Any iterable is accepted; it is
                copied into a list first.

        Returns:
            The encoded stream, or ``b""`` for empty input.

        Raises:
            ValueError: propagated from ``ExpGolombCoder.encode`` when a
                symbol is negative.
        """
        data = list(data)
        if not data:
            return bytes()

        n = len(data)
        symbol_counts = Counter(data)

        # Rarest symbols first; the most frequent one goes last and needs no
        # position information because it fills whatever slots are left.
        sorted_symbols = sorted(symbol_counts, key=symbol_counts.get)
        K = len(sorted_symbols)
        coded_symbols = sorted_symbols[:-1]

        # Collect the pieces and join once instead of growing a string.
        chunks = [ExpGolombCoder.encode(n), ExpGolombCoder.encode(K)]
        chunks.extend(ExpGolombCoder.encode(symbol) for symbol in sorted_symbols)
        chunks.extend(
            ExpGolombCoder.encode(symbol_counts[symbol]) for symbol in coded_symbols
        )

        available_indices = list(range(n))
        for symbol in coded_symbols:
            k = symbol_counts[symbol]
            current_n = len(available_indices)

            # One pass splits the free slots into those holding this symbol
            # and the rest, in both relative and absolute coordinates.
            symbol_positions = []
            complement_positions = []
            remaining_indices = []
            for rel_pos, abs_pos in enumerate(available_indices):
                if data[abs_pos] == symbol:
                    symbol_positions.append(rel_pos)
                else:
                    complement_positions.append(rel_pos)
                    remaining_indices.append(abs_pos)

            # Code whichever of the set / its complement is smaller.
            use_complement = 2 * k > current_n
            if use_complement:
                index = self._rank(current_n, current_n - k, complement_positions)
            else:
                index = self._rank(current_n, k, symbol_positions)

            chunks.append('1' if use_complement else '0')
            chunks.append(ExpGolombCoder.encode(index))
            available_indices = remaining_indices

        bits = ''.join(chunks)
        bits += '0' * (-len(bits) % 8)
        return int(bits, 2).to_bytes(len(bits) // 8, 'big')

    def decode(self, encoded_bytes: bytes) -> List[int]:
        """Decode a stream produced by ``encode``.

        Args:
            encoded_bytes: the encoded stream.

        Returns:
            The original sequence as a list, or ``[]`` for empty input.

        Raises:
            ValueError: if the stream is truncated or internally inconsistent
                (empty alphabet, counts exceeding the length, or a rank that
                is out of range for its combination).
        """
        if not encoded_bytes:
            return []

        bits = ''.join(format(byte, '08b') for byte in encoded_bytes)
        total_bits = len(bits)
        pos = 0

        n, pos = ExpGolombCoder.decode(bits, pos)
        K, pos = ExpGolombCoder.decode(bits, pos)
        if K == 0:
            raise ValueError("Corrupt stream: alphabet is empty.")

        sorted_symbols = []
        for _ in range(K):
            symbol, pos = ExpGolombCoder.decode(bits, pos)
            sorted_symbols.append(symbol)

        # Counts are stored positionally, one per symbol except the last.
        coded_counts = []
        for _ in range(K - 1):
            count, pos = ExpGolombCoder.decode(bits, pos)
            coded_counts.append(count)
        if sum(coded_counts) > n:
            raise ValueError("Corrupt stream: symbol counts exceed sequence length.")

        last_symbol = sorted_symbols[-1]
        result = [None] * n
        available_indices = list(range(n))

        for symbol, k in zip(sorted_symbols, coded_counts):
            if k == 0:
                # The encoder writes nothing for an absent symbol.
                continue

            current_n = len(available_indices)

            if pos >= total_bits:
                raise ValueError("Incomplete stream: missing complement flag.")
            use_complement = bits[pos] == '1'
            pos += 1

            index, pos = ExpGolombCoder.decode(bits, pos)

            subset_size = current_n - k if use_complement else k
            if index >= self.binom_table.get(current_n, subset_size):
                raise ValueError("Corrupt stream: combination rank out of range.")

            # A set gives O(1) membership while walking the free slots.
            chosen = set(self._unrank(current_n, subset_size, index))

            remaining_indices = []
            for rel_pos, abs_pos in enumerate(available_indices):
                # With the complement flag set, `chosen` lists the slots that
                # do NOT hold this symbol.
                if (rel_pos in chosen) != use_complement:
                    result[abs_pos] = symbol
                else:
                    remaining_indices.append(abs_pos)
            available_indices = remaining_indices

        # Every slot still free belongs to the most frequent symbol.
        for abs_pos in available_indices:
            result[abs_pos] = last_symbol

        return result
|
|