Spaces:
Running on Zero
Running on Zero
| """ | |
| Arithmetic coder for neural text compression. | |
| Uses high-precision integer arithmetic (32-bit range) with proper | |
| renormalization and underflow handling. The encoder and decoder are | |
| perfectly symmetric — given the same sequence of CDFs, the decoder | |
| recovers the exact symbol sequence the encoder consumed. | |
| """ | |
class ArithmeticEncoder:
    """Encode symbols into a compressed bitstream using arithmetic coding.

    The coder keeps a 32-bit integer interval ``[low, high]`` and narrows it
    once per symbol. Bits are packed into a bytearray on the fly instead of
    being stored as individual Python ints, cutting memory from
    O(n_bits * 28 bytes) to O(n_bits / 8).

    Typical use: call :meth:`encode_symbol` once per symbol, then
    :meth:`finish` exactly once to obtain the compressed bytes.
    """

    PRECISION = 32
    FULL = 1 << PRECISION  # 2^32
    HALF = 1 << (PRECISION - 1)  # 2^31
    QUARTER = 1 << (PRECISION - 2)  # 2^30
    MAX_RANGE = FULL - 1  # 0xFFFFFFFF

    def __init__(self):
        # Current coding interval, inclusive on both ends.
        self.low = 0
        self.high = self.MAX_RANGE
        # Number of deferred "underflow" bits whose value is decided by the
        # next bit that is actually emitted.
        self.pending_bits = 0
        # Completed output bytes, plus the partially filled byte being built.
        self._buf = bytearray()
        self._cur_byte = 0
        self._bits_in_cur = 0
        # Every bit passed to _write_bit, including final padding.
        self._total_bits = 0

    def _write_bit(self, bit: int):
        """Pack a single bit (MSB first) into the output bytearray."""
        self._cur_byte = (self._cur_byte << 1) | bit
        self._bits_in_cur += 1
        self._total_bits += 1
        if self._bits_in_cur == 8:
            self._buf.append(self._cur_byte)
            self._cur_byte = 0
            self._bits_in_cur = 0

    def _output_bit(self, bit: int):
        """Emit ``bit`` followed by all pending underflow bits."""
        self._write_bit(bit)
        # Pending bits are always the opposite of the bit just emitted.
        for _ in range(self.pending_bits):
            self._write_bit(1 - bit)
        self.pending_bits = 0

    def encode_symbol(self, cdf, symbol_index: int):
        """Encode a single symbol given its CDF.

        Args:
            cdf: Cumulative distribution function. Supports both list[int]
                and torch.Tensor (indexed with []). Length = num_symbols + 1.
                cdf[0] = 0, cdf[-1] = total. A total of at most ``QUARTER``
                (2^30) guarantees that every symbol with a non-zero
                frequency receives a non-empty interval.
            symbol_index: Index of the symbol to encode (0-based).

        Raises:
            ValueError: If the symbol's interval would be empty, i.e. the
                symbol has zero probability, the CDF entries are reversed,
                or ``total`` is too large for the coder precision. The
                encoder state is left untouched in that case.
        """
        total = int(cdf[-1])
        rng = self.high - self.low + 1
        sym_lo = int(cdf[symbol_index])
        sym_hi = int(cdf[symbol_index + 1])

        # Compute the narrowed interval first so a bad symbol cannot corrupt
        # the coder state.
        new_high = self.low + (rng * sym_hi) // total - 1
        new_low = self.low + (rng * sym_lo) // total
        if new_high < new_low:
            raise ValueError(
                f"symbol {symbol_index} maps to an empty coding interval "
                f"(cdf range [{sym_lo}, {sym_hi}) of total {total}); "
                "it cannot be encoded"
            )
        self.high = new_high
        self.low = new_low

        # Renormalize: shift out bits that are already determined.
        while True:
            if self.high < self.HALF:
                # Both ends in the lower half - output 0.
                self._output_bit(0)
                self.low = self.low << 1
                self.high = (self.high << 1) | 1
            elif self.low >= self.HALF:
                # Both ends in the upper half - output 1.
                self._output_bit(1)
                self.low = (self.low - self.HALF) << 1
                self.high = ((self.high - self.HALF) << 1) | 1
            elif self.low >= self.QUARTER and self.high < 3 * self.QUARTER:
                # Underflow: the interval straddles the midpoint but is
                # narrow. Defer the bit and expand around the middle.
                self.pending_bits += 1
                self.low = (self.low - self.QUARTER) << 1
                self.high = ((self.high - self.QUARTER) << 1) | 1
            else:
                break
            # Keep values within PRECISION bits.
            self.low &= self.MAX_RANGE
            self.high &= self.MAX_RANGE

    def finish(self) -> bytes:
        """Finalize encoding and return the compressed data as bytes.

        Emits the terminating bits that pin a value inside the final
        interval, then zero-pads to a byte boundary. Intended to be called
        once; every call appends further termination bits.
        """
        # Two more bits (one of them via the pending counter) select the
        # quarter of the range that lies inside [low, high].
        self.pending_bits += 1
        if self.low < self.QUARTER:
            self._output_bit(0)
        else:
            self._output_bit(1)
        # Pad to a byte boundary.
        while self._bits_in_cur != 0:
            self._write_bit(0)
        return bytes(self._buf)

    def get_bit_count(self) -> int:
        """Return number of bits written so far (approximate).

        Counts bits already written plus pending underflow bits that have
        not been emitted yet.
        """
        return self._total_bits + self.pending_bits
class ArithmeticDecoder:
    """Recover symbols from a bitstream produced by the arithmetic encoder.

    Bits are pulled from the compressed bytes one at a time, on demand,
    rather than expanding every byte into 8 Python ints upfront.
    """

    PRECISION = 32
    FULL = 1 << PRECISION
    HALF = 1 << (PRECISION - 1)
    QUARTER = 1 << (PRECISION - 2)
    MAX_RANGE = FULL - 1

    def __init__(self, data: bytes):
        # Input cursor: byte position plus the unread bits of the current byte.
        self._data = data
        self._byte_pos = 0
        self._bit_buf = 0
        self._bits_left = 0

        # Coding interval, inclusive on both ends.
        self.low = 0
        self.high = self.MAX_RANGE

        # Prime the code value with the first PRECISION bits of the stream.
        self.value = 0
        for _ in range(self.PRECISION):
            self.value = (self.value << 1) | self._read_bit()

    def _read_bit(self) -> int:
        """Return the next bit of the stream (MSB first), or 0 past the end."""
        if not self._bits_left:
            if self._byte_pos >= len(self._data):
                # Input exhausted: the stream continues with implicit zeros.
                return 0
            self._bit_buf = self._data[self._byte_pos]
            self._byte_pos += 1
            self._bits_left = 8
        self._bits_left -= 1
        return (self._bit_buf >> self._bits_left) & 1

    def decode_symbol(self, cdf) -> int:
        """Decode one symbol using the supplied CDF.

        Args:
            cdf: Cumulative distribution in the same format the encoder
                takes: list[int] or torch.Tensor of length
                num_symbols + 1, with cdf[0] = 0 and cdf[-1] = total.

        Returns:
            The 0-based index of the decoded symbol.
        """
        total = int(cdf[-1])
        span = self.high - self.low + 1

        # Map the code value back onto the cumulative-frequency scale.
        target = ((self.value - self.low + 1) * total - 1) // span

        # Binary search for the first symbol whose upper CDF bound exceeds
        # the target.
        first, last = 0, len(cdf) - 2
        while first <= last:
            probe = first + (last - first) // 2
            if target < int(cdf[probe + 1]):
                last = probe - 1
            else:
                first = probe + 1
        symbol = first

        cum_lo = int(cdf[symbol])
        cum_hi = int(cdf[symbol + 1])

        # Narrow the interval with the same arithmetic the encoder uses.
        base = self.low
        self.high = base + (span * cum_hi) // total - 1
        self.low = base + (span * cum_lo) // total

        # Renormalize in lockstep with the encoder. All three cases subtract
        # an offset, double, and shift one fresh bit into the code value.
        half = self.HALF
        quarter = self.QUARTER
        mask = self.MAX_RANGE
        while True:
            if self.high < half:
                offset = 0  # interval entirely in the lower half
            elif self.low >= half:
                offset = half  # interval entirely in the upper half
            elif self.low >= quarter and self.high < 3 * quarter:
                offset = quarter  # underflow: straddles the midpoint
            else:
                break
            self.low = ((self.low - offset) << 1) & mask
            self.high = (((self.high - offset) << 1) | 1) & mask
            self.value = (((self.value - offset) << 1) | self._read_bit()) & mask

        return symbol