# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Arithmetic coder."""

import io
import math
import random
import typing as tp

import torch

from ..binary import BitPacker, BitUnpacker

def build_stable_quantized_cdf(
    pdf: torch.Tensor,
    total_range_bits: int,
    roundoff: float = 1e-8,
    min_range: int = 2,
    check: bool = True,
) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf up to that level to remove differences coming
            from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior when a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
    """
    pdf = pdf.detach()
    if roundoff:
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2 ** total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    if check:
        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
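

# Illustrative sketch only: how `build_stable_quantized_cdf` carves `[0, 2 ** total_range_bits - 1]`
# into one chunk per symbol, each at least `min_range` wide and otherwise roughly proportional
# to the pdf. The helper name `_demo_quantized_cdf` is introduced here purely for illustration
# and is never called by the rest of the module.
def _demo_quantized_cdf():
    pdf = torch.tensor([0.75, 0.20, 0.05])
    q_cdf = build_stable_quantized_cdf(pdf, total_range_bits=8)
    # With 8 bits the total range is 256, so the chunk widths should come out close to
    # 0.75 * 256, 0.20 * 256 and 0.05 * 256 (plus the `min_range` floor of 2).
    widths = [q_cdf[0].item()] + (q_cdf[1:] - q_cdf[:-1]).tolist()
    print("quantized cdf:", q_cdf.tolist(), "chunk widths:", widths)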


class ArithmeticCoder:
    """ArithmeticCoder,
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to `[0, 2 ** B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes take fewer bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we mostly encode likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)`, possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the more efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written.
        total_range_bits (int): the width of the rescaled range described above is
            `2 ** total_range_bits`. Any time the current range width falls under this limit,
            new bits will be injected to rescale the initial range.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        self.low: int = 0
        self.high: int = 0
        self.max_bit: int = -1
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the same bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
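        # For example, with low = 0b1011 0... and high = 0b1011 1..., the shared prefix
        # 1011 is pushed to the packer bit by bit and subtracted from both bounds,
        # which keeps the integers we manipulate small.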
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                assert self.high >= self.low, (self.high, self.low, self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        while self.delta < 2 ** self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

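        # Map the symbol's slice of the quantized cdf, which lives on
        # [0, 2 ** total_range_bits), onto the current range of width `delta` by scaling
        # with delta / 2 ** total_range_bits; the ceil/floor pair keeps neighbouring
        # chunks disjoint.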
        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
        effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        outs = self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit
        return outs

    def flush(self):
        """Flush the remaining information to the stream."""
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()
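

# Illustrative sketch only: a small check of the efficiency claim in the class docstring
# above, namely that encoding symbols drawn from a skewed distribution costs roughly the
# entropy of that distribution, far less than log2(cardinality) bits per symbol. The helper
# name `_demo_bit_cost` and the chosen distribution are ours for illustration; the function
# is never called by the rest of the module.
def _demo_bit_cost():
    torch.manual_seed(0)
    pdf = torch.softmax(4 * torch.randn(256), dim=0)
    fo = io.BytesIO()
    coder = ArithmeticCoder(fo)
    q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
    symbols = torch.multinomial(pdf, 200, replacement=True)
    for symbol in symbols.tolist():
        coder.push(symbol, q_cdf)
    coder.flush()
    entropy = -(pdf * pdf.clamp_min(1e-12).log2()).sum().item()
    print(f"wrote {8 * fo.tell() / 200:.2f} bits/symbol, entropy is {entropy:.2f} bits")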


class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdfs as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having the same common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols, and a binary search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        self.total_range_bits = total_range_bits
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging.
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                self.current -= b1 << self.max_bit
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.
        """
        while self.delta < 2 ** self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Binary search is not just for coding interviews :)
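            # We are looking for the unique `mid` whose rescaled chunk
            # `[self.low + effective_low, self.low + effective_high]` contains `self.current`;
            # the chunks are contiguous and sorted, so comparing `self.current` with the
            # bounds tells us which half of the symbol range to recurse into.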
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
            effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return (mid, low, high, self.current)
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))
        return sym
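

# Illustrative sketch only: a minimal round trip with a fixed message, mirroring the
# randomized `test` below. The decoder recovers the message as long as it is fed exactly
# the same quantized cdfs, in the same order, as the encoder. The helper name
# `_demo_round_trip` is ours for illustration and is never called by the rest of the module.
def _demo_round_trip():
    pdf = torch.tensor([0.5, 0.25, 0.125, 0.125])
    message = [0, 0, 1, 3, 0, 2, 1, 0]
    fo = io.BytesIO()
    encoder = ArithmeticCoder(fo)
    q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
    for symbol in message:
        encoder.push(symbol, q_cdf)
    encoder.flush()
    fo.seek(0)
    decoder = ArithmeticDecoder(fo)
    decoded = [decoder.pull(q_cdf) for _ in message]
    assert decoded == message, (decoded, message)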


def test():
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        cardinality = random.randrange(4000)
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for step in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        assert decoder.pull(torch.zeros(1)) is None


if __name__ == "__main__":
    test()