# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""

import io
import json
import struct
import typing as tp


# The format is the `ECDC` magic code, followed by an uint8 indicating the
# protocol version (0), then the header size as uint32 (network byte order,
# see the struct format below).
# The header is then provided as json and should contain all the required
# information for decoding. A raw stream of bytes is then provided
# and should be interpretable using the json header.
_encodec_header_struct = struct.Struct('!4sBI')
_ENCODEC_MAGIC = b'ECDC'


def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    """Write the ECDC header (magic, version, size, json metadata) to `fo`.

    Args:
        fo (IO[bytes]): file-object to write the header to, it is flushed afterwards.
        metadata (Any): json-serializable object stored as the UTF-8 encoded header.
    """
    meta_dumped = json.dumps(metadata).encode('utf-8')
    version = 0
    header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped))
    fo.write(header)
    fo.write(meta_dumped)
    fo.flush()


def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
    """Read exactly `size` bytes from `fo`, looping over short reads.

    Args:
        fo (IO[bytes]): file-object to read from.
        size (int): number of bytes to return.

    Returns:
        bytes: exactly `size` bytes.

    Raises:
        EOFError: if the stream ends before `size` bytes could be read.
    """
    chunks: tp.List[bytes] = []
    # Track what is still missing separately from what was requested: comparing
    # the accumulated length against a shrinking target would stop too early
    # whenever `fo.read` returns fewer bytes than asked for.
    remaining = size
    while remaining > 0:
        new_buf = fo.read(remaining)
        if not new_buf:
            raise EOFError("Impossible to read enough data from the stream, "
                           f"{remaining} bytes remaining.")
        chunks.append(new_buf)
        remaining -= len(new_buf)
    return b"".join(chunks)


def read_ecdc_header(fo: tp.IO[bytes]):
    """Read and validate the ECDC header from `fo`, returning the decoded json metadata.

    Args:
        fo (IO[bytes]): file-object positioned at the start of the header.

    Raises:
        ValueError: if the magic code does not match or the version is not 0.
        EOFError: if the stream ends before the full header could be read.
    """
    header_bytes = _read_exactly(fo, _encodec_header_struct.size)
    magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
    if magic != _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if version != 0:
        raise ValueError("Version not supported.")
    meta_bytes = _read_exactly(fo, meta_size)
    return json.loads(meta_bytes.decode('utf-8'))


class BitPacker:
    """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
    Note that for some bandwidth (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self._current_value = 0
        self._current_bits = 0
        self.bits = bits
        self.fo = fo

    def push(self, value: int):
        """Push a new value to the stream. This will immediately
        write as many uint8 as possible to the underlying file-object.

        `value` is not range-checked here: it is expected to fit in `self.bits` bits.
        """
        # New values are stacked above the bits still pending from previous pushes.
        self._current_value += (value << self._current_bits)
        self._current_bits += self.bits
        while self._current_bits >= 8:
            lower_8bits = self._current_value & 0xff
            self._current_bits -= 8
            self._current_value >>= 8
            self.fo.write(bytes([lower_8bits]))

    def flush(self):
        """Flushes the remaining partial uint8, call this at the end
        of the stream to encode."""
        if self._current_bits:
            # The missing high bits of the last byte are left at zero.
            self.fo.write(bytes([self._current_value]))
            self._current_value = 0
            self._current_bits = 0
        self.fo.flush()


class BitUnpacker:
    """BitUnpacker does the opposite of `BitPacker`.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to read the bytes from.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.bits = bits
        self.fo = fo
        self._mask = (1 << bits) - 1
        self._current_value = 0
        self._current_bits = 0

    def pull(self) -> tp.Optional[int]:
        """
        Pull a single value from the stream, potentially reading some
        extra bytes from the underlying file-object.
        Returns `None` when reaching the end of the stream.
        """
        # Read one byte at a time so that no more of the stream is consumed
        # than what is needed to decode the current value.
        while self._current_bits < self.bits:
            buf = self.fo.read(1)
            if not buf:
                return None
            character = buf[0]
            self._current_value += character << self._current_bits
            self._current_bits += 8

        out = self._current_value & self._mask
        self._current_value >>= self.bits
        self._current_bits -= self.bits
        return out


def test():
    """Round-trip random tokens of random bit width through `BitPacker` and `BitUnpacker`."""
    import torch
    torch.manual_seed(1234)
    for _ in range(4):
        length: int = torch.randint(10, 2_000, (1, )).item()
        bits: int = torch.randint(1, 16, (1, )).item()
        tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist()
        rebuilt: tp.List[int] = []
        buf = io.BytesIO()
        packer = BitPacker(bits, buf)
        for token in tokens:
            packer.push(token)
        packer.flush()
        buf.seek(0)
        unpacker = BitUnpacker(bits, buf)
        while True:
            value = unpacker.pull()
            if value is None:
                break
            rebuilt.append(value)
        assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
        # The flushing mechanism might lead to "ghost" values at the end of the stream.
        assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), len(tokens), bits)
        for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
            assert a == b, (idx, a, b)


if __name__ == '__main__':
    test()