# Byte-lingua-code / m1_compression / enumerative_coder_simple.py
# 2ira's picture
# offline_compression_graph_code
# 72c0672 verified
from typing import List, Tuple
import itertools
import math
def _pack_rank_ids(buf: List[int], rank_bitlength: int) -> List[int]:
per_b = 8 // rank_bitlength
mask = (1 << rank_bitlength) - 1
out_b = []
it = iter(buf)
while True:
chunk = list(itertools.islice(it, per_b))
if not chunk:
break
byte_val = 0
for p, idx in enumerate(chunk):
byte_val |= (idx & mask) << (p * rank_bitlength)
out_b.append(byte_val)
return out_b
def _unpack_rank_ids(payload: List[int], run_len: int, rank_bitlength: int):
mask = (1 << rank_bitlength) - 1
byte_iter = iter(payload)
cur_byte = next(byte_iter)
filled = 8
for _ in range(run_len):
if filled == 0:
cur_byte = next(byte_iter)
filled = 8
rank_id = cur_byte & mask
cur_byte >>= rank_bitlength
filled -= rank_bitlength
yield rank_id
class SimpleAdaptiveRankCodec:
    """Window codec mixing raw tokens, run-length events and packed rank runs.

    An encoded stream is a flat list of ints made of three kinds of events:

    * a value below ``raw_byte_offset`` (256) is a literal token;
    * ``sentinel_rle, run, tok`` stands for ``tok`` repeated ``run`` times;
    * ``sentinel_rank_run, count, b0, b1, ...`` carries ``count`` rank ids
      packed ``ranks_per_byte`` to a value by ``_pack_rank_ids``.

    NOTE(review): tokens are presumably byte values (0-255) and the two
    sentinels distinct values >= 256; neither is validated here.
    """

    def __init__(
        self,
        top_k: int = 4,
        tau: float = 0.5,
        min_run: int = 3,
        max_run: int = 255,
        sentinel_rle: int = 256,
        sentinel_rank_run: int = 257,
    ):
        """Store the codec parameters and derive the rank packing width.

        Args:
            top_k: ranks strictly below this value are rank-coded; any other
                rank makes the token escape to a raw value.
            tau: minimum ``repeat_probs`` value for a position to extend a run.
            min_run: shortest run that is emitted as a run-length event.
            max_run: cap applied to run lengths by
                ``encoding_to_pseudo_bytes`` only; ``encode_window`` does not
                cap runs.
            sentinel_rle: stream value that introduces a run-length event.
            sentinel_rank_run: stream value that introduces a rank run.

        Raises:
            AssertionError: if the bit width derived from ``top_k`` does not
                divide 8.
        """
        self.top_k = top_k
        self.tau = tau
        self.min_run = min_run
        self.max_run = max_run
        self.raw_byte_offset = 256
        # Bits needed to store ranks 0 .. top_k - 1, never fewer than one.
        self.rank_bitlength = max(1, (top_k - 1).bit_length())
        assert self.rank_bitlength <= 8 and 8 % self.rank_bitlength == 0, (
            f"rank_bitlength must be between 1 and 8 and must divide 8, "
            f"got {self.rank_bitlength} (top_k: {top_k})"
        )
        self.ranks_per_byte = 8 // self.rank_bitlength
        self.sentinel_rle = sentinel_rle
        self.sentinel_rank_run = sentinel_rank_run

    def encode_window(
        self,
        tokens: List[int],
        repeat_probs: List[float],
        ranks: List[int],
    ) -> List[int]:
        """Return a list of ints: raw bytes 0-255 and sentinel events ≥256.

        The first token is always emitted raw. Every later position is tried
        in this order: run-length event, rank id appended to the pending rank
        run, raw escape. An empty window encodes to an empty list.

        Args:
            tokens: the window to encode.
            repeat_probs: per-position value compared against ``tau`` when
                extending a run.
            ranks: per-position rank of the token; values below ``top_k`` are
                rank-coded.

        Returns:
            The encoded event stream.

        Raises:
            AssertionError: if the three inputs differ in length.
        """
        assert len(tokens) == len(repeat_probs) == len(ranks)
        if not tokens:
            # Nothing to seed the stream with; decode_window(.., 0, ..)
            # returns [] for this stream, so the round trip still holds.
            return []

        rank_buf: List[int] = []
        out: List[int] = [tokens[0]]
        i, n = 1, len(tokens)

        def flush_rank_buf():
            # Emit the pending rank ids as one rank-run event, if any.
            if not rank_buf:
                return
            out.append(self.sentinel_rank_run)
            out.append(len(rank_buf))
            out.extend(_pack_rank_ids(rank_buf, self.rank_bitlength))
            rank_buf.clear()

        while i < n:
            tok = tokens[i]

            # --- RLE probe ------------------------------------------------
            # Extend the run while the following token is identical and the
            # repeat probability at that following position reaches tau.
            run = 1
            while (
                i + run < n
                and tokens[i + run] == tok
                and repeat_probs[i + run] >= self.tau
            ):
                run += 1

            if run >= self.min_run:
                # A pending rank run must be written first to keep order.
                flush_rank_buf()
                out.extend([self.sentinel_rle, run, tok])
                i += run
                continue

            if ranks[i] < self.top_k:
                rank_buf.append(ranks[i])
            else:
                # the current token is not in top-K,
                # so we escape to a raw byte
                flush_rank_buf()
                out.append(tok)
            i += 1

        flush_rank_buf()
        return out

    def encoding_to_pseudo_bytes(self, enc: list[int]) -> list[int]:
        """Flatten an encoded stream into "pseudo-bytes".

        Raw values are copied, a run-length event becomes
        ``[2 * raw_byte_offset - min(run, max_run), raw]``, and each packed
        value of a rank run is shifted up by ``raw_byte_offset``; the rank-run
        count itself is not written.

        Args:
            enc: stream produced by ``encode_window``.

        Returns:
            The pseudo-byte sequence.

        Raises:
            ValueError: on a value that is neither raw nor a known sentinel.
        """
        # NOTE: this function is not expected to be lossless, that is,
        # we cannot reconstruct the original encoding from the pseudo-bytes
        out: list[int] = []
        i = 0
        while i < len(enc):
            tok = enc[i]
            i += 1
            if tok < self.raw_byte_offset:
                out.append(tok)
            elif tok == self.sentinel_rle:
                run = enc[i]
                raw = enc[i + 1]
                i += 2
                run = min(run, self.max_run)
                # Run lengths count down from 2 * raw_byte_offset (512), so
                # longer runs map to smaller markers.
                out.extend([self.raw_byte_offset + self.raw_byte_offset - run, raw])
            elif tok == self.sentinel_rank_run:
                length = enc[i]
                i += 1
                n_bytes = math.ceil(length / self.ranks_per_byte)
                for _ in range(n_bytes):
                    out.append(enc[i] + self.raw_byte_offset)
                    i += 1
            else:
                raise ValueError(f"unknown token {tok}")
        return out

    def pseudo_bytes_to_encoding(self, pb: list[int], original_encoding: list[int]) -> list[int]:
        """Inverse of ``encoding_to_pseudo_bytes``; not implemented.

        Raises:
            NotImplementedError: always.
        """
        # NOTE: we do not expect the encoding-to-pseudo-bytes conversion to be lossless,
        # so we need to pass the original encoding to reconstruct the original encoding
        # this function is just for sanity check
        raise NotImplementedError("Not implemented")

    def decode_window(
        self,
        stream: List[int],
        original_len: int,
        topk_symbols: List[List[int]],
    ) -> List[int]:
        """
        Rebuild ``original_len`` tokens from an encoded ``stream``.

        `topk_symbols[pos][idx]` must give the byte value (0-255) that
        corresponds to rank `idx` at position `pos`, e.g. recomputed from
        the helper LM during decoding.

        Raises:
            ValueError: on a stream value that is neither raw nor a known
                sentinel.
        """
        out: List[int] = []
        # position in input stream
        i = 0
        # position in output tokens
        pos = 0
        while pos < original_len:
            tok = stream[i]
            i += 1
            # Same raw/sentinel boundary that encoding_to_pseudo_bytes uses.
            if tok < self.raw_byte_offset:
                out.append(tok)
                pos += 1
            elif tok == self.sentinel_rle:
                run_len = stream[i]
                raw = stream[i + 1]
                i += 2
                out.extend([raw] * run_len)
                pos += run_len
            elif tok == self.sentinel_rank_run:
                run_len = stream[i]
                i += 1
                bytes_needed = math.ceil(run_len / self.ranks_per_byte)
                payload = stream[i: i + bytes_needed]
                i += bytes_needed
                for rank_id in _unpack_rank_ids(payload, run_len, self.rank_bitlength):
                    # Ranks are resolved against the table of the position
                    # being reconstructed.
                    out.append(topk_symbols[pos][rank_id])
                    pos += 1
            else:
                raise ValueError(f"Unknown sentinel {tok}")
        return out[:original_len]
if __name__ == "__main__":
    # Smoke test: encode a random window and check that decoding restores it.
    import random

    import torch

    random.seed(0)
    # All random data below comes from torch, whose generator random.seed()
    # does not touch; seed it as well so the demo is reproducible.
    torch.manual_seed(0)

    T, K = 384, 13  # demonstrate non-power-of-two K
    tokens = torch.randint(0, 32, (T,)).tolist()
    repeat_probs = torch.rand(T).tolist()
    ranks = torch.randint(0, K + 5, (T,)).tolist()  # some ranks ≥K → raw
    ranks = [r if r < K else K for r in ranks]
    # fake LM top-K table for demo: every rank at position t maps back to
    # tokens[t], so any rank-coded position decodes to the right token
    topk = [[tokens[t]] * K for t in range(T)]
    codec = SimpleAdaptiveRankCodec(top_k=K, tau=0.00)
    enc = codec.encode_window(tokens, repeat_probs, ranks)
    dec = codec.decode_window(enc, T, topk)
    print(f"raw={T} encoded={len(enc)} ratio={len(enc)/T:.2f}")
    assert dec == tokens
    print("✓ window-enc-dec round-trip passes")