Spaces:
Runtime error
Runtime error
File size: 5,365 Bytes
12bfd03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
import io
import json
import struct
import typing as tp
# The format starts with the `ECDC` magic code, followed by a uint8 giving the
# protocol version (0), then the size of the header as a uint32.
# The header is then provided as JSON and should contain all the
# information required for decoding. A raw stream of bytes follows
# and should be interpretable using the JSON header.
# Fixed-size preamble in network byte order ('!'): 4-byte magic ('4s'),
# uint8 protocol version ('B'), uint32 length of the JSON metadata ('I').
_encodec_header_struct = struct.Struct('!4sBI')
# Magic code expected at the very start of an ECDC stream.
_ENCODEC_MAGIC = b'ECDC'
def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    """Write the ECDC preamble followed by the JSON-encoded metadata.

    The fixed-size preamble (magic code, protocol version 0, byte length of
    the metadata) is written first, then the UTF-8 encoded JSON payload,
    and finally the file-object is flushed.

    Args:
        fo (IO[bytes]): binary file-object to write the header to.
        metadata (Any): object serialized with `json.dumps`.
    """
    payload = json.dumps(metadata).encode('utf-8')
    preamble = _encodec_header_struct.pack(_ENCODEC_MAGIC, 0, len(payload))
    # Preamble first, payload second: the reader relies on this order.
    for piece in (preamble, payload):
        fo.write(piece)
    fo.flush()
def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
buf = b""
while len(buf) < size:
new_buf = fo.read(size)
if not new_buf:
raise EOFError("Impossible to read enough data from the stream, "
f"{size} bytes remaining.")
buf += new_buf
size -= len(new_buf)
return buf
def read_ecdc_header(fo: tp.IO[bytes]):
    """Read and validate the ECDC preamble, then return the decoded metadata.

    Args:
        fo (IO[bytes]): binary file-object positioned at the start of the
            ECDC stream.

    Returns:
        The object obtained by parsing the JSON metadata.

    Raises:
        ValueError: if the magic code does not match or the protocol
            version is not 0.
        EOFError: propagated from `_read_exactly` when the stream is too short.
    """
    preamble = _read_exactly(fo, _encodec_header_struct.size)
    tag, protocol, payload_len = _encodec_header_struct.unpack(preamble)
    if tag != _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if version_mismatch := (protocol != 0):
        raise ValueError("Version not supported.")
    # The preamble tells us how many bytes of JSON follow.
    payload = _read_exactly(fo, payload_len)
    metadata = json.loads(payload.decode('utf-8'))
    return metadata
class BitPacker:
    """Pack fixed-width integers (e.g. 10 bits each) into a byte stream.

    Values are accumulated least-significant-bits first; every time at least
    8 bits are pending, the lowest byte is written to the file-object.
    Note that for some bandwidth (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.
    """

    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.bits = bits
        self.fo = fo
        # Pending bits not yet written, and how many of them there are.
        self._current_value = 0
        self._current_bits = 0

    def push(self, value: int):
        """Append `value` to the stream, writing out every complete
        uint8 to the underlying file-object right away."""
        self._current_value += value << self._current_bits
        self._current_bits += self.bits
        while self._current_bits >= 8:
            # Split off the lowest byte; the rest stays pending.
            self._current_value, low_byte = divmod(self._current_value, 256)
            self._current_bits -= 8
            self.fo.write(bytes([low_byte]))

    def flush(self):
        """Write the trailing partial uint8 (if any) and flush the
        file-object; call this once all values have been pushed."""
        if self._current_bits != 0:
            self.fo.write(bytes([self._current_value]))
            self._current_value = 0
            self._current_bits = 0
        self.fo.flush()
class BitUnpacker:
    """Inverse of `BitPacker`: decode fixed-width integers from a byte stream.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to read the bytes from.
    """

    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.fo = fo
        self.bits = bits
        # Mask keeping only the lowest `bits` bits of the pending value.
        self._mask = (1 << bits) - 1
        self._current_bits = 0
        self._current_value = 0

    def pull(self) -> tp.Optional[int]:
        """
        Decode the next value, reading extra bytes from the underlying
        file-object one at a time as needed.
        Returns `None` once the stream is exhausted.
        """
        width = self.bits
        while self._current_bits < width:
            byte = self.fo.read(1)
            if not byte:
                # Not enough data left to form a complete value.
                return None
            self._current_value += byte[0] << self._current_bits
            self._current_bits += 8
        decoded = self._current_value & self._mask
        self._current_value >>= width
        self._current_bits -= width
        return decoded
def test():
    """Round-trip random tokens through `BitPacker` and `BitUnpacker`.

    For a few random (length, bit width) pairs, the tokens are packed into
    an in-memory buffer, unpacked again and compared with the originals.
    """
    import torch
    torch.manual_seed(1234)
    for _ in range(4):
        length: int = torch.randint(10, 2_000, (1, )).item()
        bits: int = torch.randint(1, 16, (1, )).item()
        tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist()

        stream = io.BytesIO()
        packer = BitPacker(bits, stream)
        for token in tokens:
            packer.push(token)
        packer.flush()

        stream.seek(0)
        unpacker = BitUnpacker(bits, stream)
        # Keep pulling until the unpacker signals the end with `None`.
        rebuilt: tp.List[int] = list(iter(unpacker.pull, None))

        n_rebuilt, n_tokens = len(rebuilt), len(tokens)
        assert n_rebuilt >= n_tokens, (n_rebuilt, n_tokens)
        # The flushing mechanism might lead to "ghost" values at the end of the stream.
        assert n_rebuilt <= n_tokens + 8 // bits, (n_rebuilt, n_tokens, bits)
        for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
            assert a == b, (idx, a, b)
if __name__ == "__main__":
    # Run the pack/unpack round-trip self-check when executed as a script.
    test()
|