| import json |
| import base64 |
| from offline_utils import packed_bytes_to_pseudo |
|
|
def pack_compressed_spans(data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """
    Convert consecutive compressed values into larger integers.

    Args:
        data: List of integers where 0-255 are raw bytes, compression_offset+ are compressed bytes
        bits_per_compressed: Number of bits to use for each packed value
        compression_bit_threshold: Number of bits each compressed value actually uses
        compression_offset: Offset that marks start of compressed values (default 256)

    Returns:
        List with consecutive compressed spans packed into larger integers
    """
    # Empty input short-circuits before any configuration validation.
    if not data:
        return []

    assert compression_bit_threshold % bits_per_compressed == 0, "compression_bit_threshold must be divisible by bits_per_compressed"

    # Fixed masks/widths derived once from the configuration.
    packing_mask = (1 << bits_per_compressed) - 1
    compression_mask = (1 << compression_bit_threshold) - 1
    # Threshold rounded up to whole bytes: compressed input arrives byte-aligned.
    padded_threshold = ((compression_bit_threshold + 7) // 8) * 8
    padded_mask = (1 << padded_threshold) - 1
    pad_bits = padded_threshold - compression_bit_threshold

    output = []
    pos = 0
    total = len(data)

    while pos < total:
        if data[pos] < compression_offset:
            # Raw byte: passes through untouched.
            output.append(data[pos])
            pos += 1
            continue

        # Gather the maximal run of compressed values starting here.
        run_start = pos
        while pos < total and data[pos] >= compression_offset:
            pos += 1
        span_bytes = [v - compression_offset for v in data[run_start:pos]]

        repacked = []
        acc = 0
        acc_bits = 0
        for byte_val in span_bytes:
            # Feed one byte at a time into the MSB-first accumulator.
            acc = (acc << 8) | byte_val
            acc_bits += 8

            # Drain every complete byte-padded value from the accumulator.
            while acc_bits >= padded_threshold:
                tail_bits = acc_bits - padded_threshold
                padded_val = (acc >> tail_bits) & padded_mask
                acc &= (1 << tail_bits) - 1
                acc_bits -= padded_threshold

                # Strip the alignment padding to recover the real value.
                real_val = (padded_val >> pad_bits) & compression_mask

                # Re-slice the real value into bits_per_compressed chunks.
                remaining = compression_bit_threshold
                while remaining >= bits_per_compressed:
                    remaining -= bits_per_compressed
                    chunk = (real_val >> remaining) & packing_mask
                    repacked.append(chunk + compression_offset)
                    real_val &= (1 << remaining) - 1

        assert acc_bits == 0, "bits_in_buffer must be 0 after processing compressed span"
        assert remaining == 0, "pack_bits must be 0 after packing"

        output.extend(repacked)

    return output
|
|
def unpack_compressed_spans(packed_data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """
    Reverse operation: unpack larger integers back to consecutive compressed bytes.

    Args:
        packed_data: List with packed compressed spans
        bits_per_compressed: Number of bits used for packing
        compression_bit_threshold: Number of bits each compressed value actually uses
        compression_offset: Offset used for compressed values

    Returns:
        Original format with consecutive compressed bytes
    """
    # Widths derived once from the configuration.
    padded_threshold = ((compression_bit_threshold + 7) // 8) * 8
    pad_bits = padded_threshold - compression_bit_threshold
    span_byte_count = padded_threshold // 8
    value_mask = (1 << compression_bit_threshold) - 1

    output = []
    pos = 0
    total = len(packed_data)

    while pos < total:
        if packed_data[pos] < compression_offset:
            # Raw byte: passes through untouched.
            output.append(packed_data[pos])
            pos += 1
            continue

        # Gather the maximal run of packed values starting here.
        run_start = pos
        while pos < total and packed_data[pos] >= compression_offset:
            pos += 1
        chunks = [v - compression_offset for v in packed_data[run_start:pos]]

        restored = []
        acc = 0
        acc_bits = 0
        for chunk in chunks:
            # Feed each packed chunk into the MSB-first accumulator.
            acc = (acc << bits_per_compressed) | chunk
            acc_bits += bits_per_compressed

            # Drain every complete value, re-padding it back to whole bytes.
            while acc_bits >= compression_bit_threshold:
                tail_bits = acc_bits - compression_bit_threshold
                value = (acc >> tail_bits) & value_mask
                acc &= (1 << tail_bits) - 1
                acc_bits -= compression_bit_threshold

                # Re-attach the alignment padding, then emit MSB-first bytes.
                padded_val = value << pad_bits
                for byte_idx in range(span_byte_count):
                    shift = (span_byte_count - 1 - byte_idx) * 8
                    restored.append(((padded_val >> shift) & 0xFF) + compression_offset)

        assert acc_bits == 0, "bits_in_buffer must be 0 after unpacking compressed span"

        output.extend(restored)

    return output
|
|
def run_test_case(test_name: str, data: list, bits_per_compressed: int, compression_bit_threshold: int):
    """Run a single pack/unpack round-trip test case with validation.

    Args:
        test_name: Human-readable label printed with the results.
        data: Mixed list of raw bytes (0-255) and compressed values (>= 256).
        bits_per_compressed: Bit width used for each packed value.
        compression_bit_threshold: Bit width of each original compressed value.

    Returns:
        True when unpacking the packed data reproduces the input exactly,
        False on mismatch or on any exception raised by the helpers.
    """
    print(f"🧪 {test_name}")
    print(f"   Original: {data}")
    print(f"   Config: bits_per_compressed={bits_per_compressed}, compression_bit_threshold={compression_bit_threshold}")

    try:
        packed = pack_compressed_spans(data, bits_per_compressed, compression_bit_threshold)
        print(f"   Packed: {packed}")

        unpacked = unpack_compressed_spans(packed, bits_per_compressed, compression_bit_threshold)
        print(f"   Unpacked: {unpacked}")

        success = data == unpacked
        # NOTE: was broken across two physical lines in the original source
        # (a syntax error as-written); reconstructed as a single f-string.
        print(f"   Result: {'✅ PASS' if success else '❌ FAIL'}")

        # Compression ratio over the values that were actually packed.
        # 256 matches the default compression_offset of the pack helpers.
        original_compressed = len([x for x in data if x >= 256])
        packed_compressed = len([x for x in packed if x >= 256])
        if original_compressed > 0:
            ratio = original_compressed / packed_compressed if packed_compressed > 0 else 0
            print(f"   Stats: {original_compressed} -> {packed_compressed} compressed values ({ratio:.2f}x)")

        return success

    except Exception as e:
        print(f"   Result: ❌ ERROR: {e}")
        return False
|
|
|
|
def test_packing_comprehensive():
    """Comprehensive round-trip test suite for the packing functions.

    Returns:
        Tuple (passed, total) counting how many test cases succeeded.
    """
    from m1_compression import utils
    import random

    def random_bytes_generator(n: int, bit_threshold: int):
        # Build a mixed stream: compressed values (offset by 256) interleaved
        # with single random raw bytes, with a random length in [n//2, n].
        ret = []
        length = random.randint(n // 2, n)
        for _ in range(length):
            # join() instead of += avoids quadratic string building.
            bits = "".join("0" if random.random() < 0.5 else "1" for _ in range(bit_threshold))
            compressed_bytes, _ = utils.bits_to_bytes_padding_to_threshold(bits, bit_threshold)
            ret.extend([c + 256 for c in list(compressed_bytes)])
            ret.append(random.randint(0, 255))
        return ret

    print("=" * 60)
    print("🚀 COMPREHENSIVE PACKING TESTS")
    print("=" * 60)

    test_results = []

    test_results.append(run_test_case(
        "Basic 16-bit packing (no padding)",
        random_bytes_generator(100, 16),
        bits_per_compressed=16,
        compression_bit_threshold=16
    ))
    print()

    test_results.append(run_test_case(
        "12-bit values with 4-bit padding",
        random_bytes_generator(100, 12),
        bits_per_compressed=12,
        compression_bit_threshold=12
    ))
    print()

    test_results.append(run_test_case(
        "20-bit values with 4-bit padding",
        random_bytes_generator(100, 20),
        bits_per_compressed=20,
        compression_bit_threshold=20
    ))
    print()

    test_results.append(run_test_case(
        "Single compressed byte",
        [100, 256, 200],
        bits_per_compressed=8,
        compression_bit_threshold=8
    ))
    print()

    test_results.append(run_test_case(
        "No compressed bytes",
        [100, 200, 50, 150],
        bits_per_compressed=16,
        compression_bit_threshold=16
    ))
    print()

    test_results.append(run_test_case(
        "All compressed bytes",
        [256, 257, 258, 259, 260, 261],
        bits_per_compressed=8,
        compression_bit_threshold=8
    ))
    print()

    test_results.append(run_test_case(
        "24-bit to 12-bit packing (2:1 ratio)",
        random_bytes_generator(100, 24),
        bits_per_compressed=12,
        compression_bit_threshold=24
    ))
    print()

    passed = sum(test_results)
    total = len(test_results)
    print("=" * 60)
    print(f"📊 TEST SUMMARY: {passed}/{total} tests passed")
    print("=" * 60)

    if passed == total:
        print("🎉 All tests passed! The implementation is working correctly.")
    else:
        print("⚠️ Some tests failed. Please review the implementation.")

    return passed, total
|
|
def test_real_data():
    """Round-trip packing on real compressed records read from disk.

    Reads up to six JSONL records, decodes the base64 payload stored under
    `key`, converts it to the pseudo-byte representation, and runs the
    standard pack/unpack round-trip on it.
    """
    print("=" * 40)
    print("🔧 REAL DATA TESTS")
    print("=" * 40)

    key = "m1_ac_ow20_escapefb-False_iterative-True"

    # The bit threshold is encoded in the key as the "ow<bits>" segment.
    # (Loop-invariant, so parsed once up front; the original source had the
    # startswith() argument missing — reconstructed as "ow" from the slice
    # `k[len("ow"):]` below.)
    bit_threshold = None
    for k in key.split("_"):
        if k.startswith("ow"):
            bit_threshold = int(k[len("ow"):])
            break
    assert bit_threshold is not None
    print(f"Bit threshold: {bit_threshold}")

    with open("output_compress/m1.chunk.0_out_0_out_0_writer_0.jsonl", "r") as f:
        for i, line in enumerate(f):
            data = json.loads(line)

            # Decode the record's payload into the pseudo-byte list format
            # expected by the packing helpers.
            original_bytes_array = packed_bytes_to_pseudo(base64.b64decode(data[key]))

            run_test_case(
                f"Packing {bit_threshold}-bit values",
                original_bytes_array,
                10,
                bit_threshold
            )

            # Only exercise the first few records.
            if i > 4:
                break
|
|
|
|
def test_error_conditions():
    """Test error conditions and edge cases."""
    print("\n🔧 ERROR CONDITION TESTS")
    print("=" * 40)

    # bits_per_compressed must evenly divide compression_bit_threshold;
    # pack_compressed_spans enforces this with an assert.
    try:
        pack_compressed_spans([256, 257], 10, 15)
        print("❌ Should have failed on invalid bit alignment")
    except AssertionError:
        print("✅ Correctly caught invalid bit alignment")

    # Empty input short-circuits before any validation and returns [].
    result = pack_compressed_spans([], 16, 16)
    print(f"✅ Empty data handling: {result == []}")

    print()
|
|
|
|
if __name__ == "__main__":
    # Synthetic round-trip suite first.
    passed, total = test_packing_comprehensive()

    if passed == total:
        print("🎉 ALL TESTS COMPLETED SUCCESSFULLY!")
    else:
        print("💥 SOME TESTS FAILED!")

    # Then real records from disk, and finally the error-path checks.
    test_real_data()

    test_error_conditions()
|
|