|
|
|
|
|
|
|
|
|
|
|
import os |
|
import random |
|
import string |
|
import typing as tp |
|
import unittest |
|
from collections import Counter |
|
from tempfile import NamedTemporaryFile, TemporaryDirectory |
|
|
|
from fairseq.data import Dictionary, indexed_dataset |
|
from fairseq.data.huffman import ( |
|
HuffmanCodeBuilder, |
|
HuffmanCoder, |
|
HuffmanMMapIndexedDataset, |
|
HuffmanMMapIndexedDatasetBuilder, |
|
) |
|
|
|
# Symbol alphabet shared by all tests: a-z, A-Z, 0-9.
POPULATION = string.ascii_letters + string.digits


def make_sentence() -> tp.List[str]:
    """Sample a random sentence of 10-50 symbols drawn from POPULATION.

    Later symbols in POPULATION are proportionally more likely, giving a
    skewed frequency distribution for the Huffman code to exploit.
    """
    n_symbols = random.randint(10, 50)
    symbol_weights = range(1, len(POPULATION) + 1)
    return random.choices(POPULATION, symbol_weights, k=n_symbols)
|
|
|
|
|
def make_data(length=1000) -> tp.List[tp.List[str]]:
    """Build a corpus of *length* random sentences.

    Two extra sentences covering every letter and every digit are appended,
    guaranteeing the full POPULATION alphabet appears at least once.
    """
    corpus = [make_sentence() for _ in range(length)]
    corpus.append(list(string.ascii_letters))
    corpus.append(list(string.digits))
    return corpus
|
|
|
|
|
def make_counts(data: tp.List[tp.List[str]]) -> Counter:
    """Count how often each symbol occurs across all sentences in *data*."""
    counts: Counter = Counter()
    for sentence in data:
        counts.update(sentence)
    return counts
|
|
|
|
|
def make_code_builder(data: tp.List[tp.List[str]]) -> HuffmanCodeBuilder:
    """Return a HuffmanCodeBuilder fed with every symbol in *data*."""
    code_builder = HuffmanCodeBuilder()
    for tokens in data:
        code_builder.add_symbols(*tokens)
    return code_builder
|
|
|
|
|
class TestCodeBuilder(unittest.TestCase):
    """Tests for HuffmanCodeBuilder: counting, merging, and file round-trips."""

    def test_code_builder_can_count(self):
        """The builder's symbol table matches an independently computed Counter."""
        sentences = make_data()
        expected_counts = make_counts(sentences)
        code_builder = make_code_builder(sentences)

        self.assertEqual(code_builder.symbols, expected_counts)

    def test_code_builder_can_add(self):
        """Adding two builders sums their per-symbol counts."""
        sentences = make_data()
        expected_counts = make_counts(sentences)
        code_builder = make_code_builder(sentences)

        merged = code_builder + code_builder

        self.assertEqual(merged.symbols, expected_counts + expected_counts)

    def test_code_builder_can_io(self):
        """A builder written to disk and read back keeps the same symbol counts."""
        sentences = make_data()
        code_builder = make_code_builder(sentences)

        with NamedTemporaryFile() as tmp_fp:
            code_builder.to_file(tmp_fp.name)
            reloaded = HuffmanCodeBuilder.from_file(tmp_fp.name)

            self.assertEqual(code_builder.symbols, reloaded.symbols)
|
|
|
|
|
class TestCoder(unittest.TestCase):
    """Tests for HuffmanCoder serialization and encode/decode round-trips."""

    def test_coder_can_io(self):
        """A coder written to disk and reloaded compares equal to the original."""
        sentences = make_data()
        coder = make_code_builder(sentences).build_code()

        with NamedTemporaryFile() as tmp_fp:
            coder.to_file(tmp_fp.name)
            reloaded = HuffmanCoder.from_file(tmp_fp.name)

            self.assertEqual(coder, reloaded)

    def test_coder_can_encode_decode(self):
        """Encoding then decoding recovers the symbols, even for fresh data."""
        sentences = make_data()
        coder = make_code_builder(sentences).build_code()

        # make_data always includes every letter and digit, so the coder also
        # covers any symbol that can appear in a second, unseen sample.
        for corpus in (sentences, make_data()):
            encoded = [coder.encode(sentence) for sentence in corpus]
            decoded = [
                [node.symbol for node in coder.decode(blob)] for blob in encoded
            ]
            self.assertEqual(decoded, corpus)
|
|
|
|
|
def build_dataset(prefix, data, coder):
    """Write every sentence in *data* to a Huffman mmap dataset at *prefix*."""
    with HuffmanMMapIndexedDatasetBuilder(prefix, coder) as dataset_builder:
        for sentence in data:
            dataset_builder.add_item(sentence)
|
|
|
|
|
def sizes(data):
    """Return the length of each sentence in *data*, in order."""
    return list(map(len, data))
|
|
|
|
|
class TestHuffmanDataset(unittest.TestCase):
    """End-to-end tests for the Huffman-coded mmap indexed dataset."""

    def test_huffman_can_encode_decode(self):
        """Data written through the builder round-trips through the dataset."""
        sentences = make_data()
        coder = make_code_builder(sentences).build_code()

        with TemporaryDirectory() as dirname:
            prefix = os.path.join(dirname, "test1")
            build_dataset(prefix, sentences, coder)
            dataset = HuffmanMMapIndexedDataset(prefix)

            self.assertEqual(len(dataset), len(sentences))
            round_tripped = [
                list(dataset.get_symbols(idx)) for idx in range(len(dataset))
            ]
            self.assertEqual(round_tripped, sentences)
            # dataset.sizes holds numpy scalars; unwrap before comparing.
            self.assertEqual([s.item() for s in dataset.sizes], sizes(sentences))

    def test_huffman_compresses(self):
        """The Huffman data file is smaller than a plain mmap dataset of the same data."""
        sentences = make_data()
        coder = make_code_builder(sentences).build_code()

        with TemporaryDirectory() as dirname:
            huffman_prefix = os.path.join(dirname, "huffman")
            build_dataset(huffman_prefix, sentences, coder)

            # Build a reference (uncompressed) mmap dataset over the same data.
            mmap_prefix = os.path.join(dirname, "mmap")
            mmap_builder = indexed_dataset.make_builder(
                indexed_dataset.data_file_path(mmap_prefix),
                "mmap",
                vocab_size=len(POPULATION),
            )
            dictionary = Dictionary()
            for symbol in POPULATION:
                dictionary.add_symbol(symbol)
            dictionary.finalize()
            for sentence in sentences:
                mmap_builder.add_item(dictionary.encode_line(" ".join(sentence)))
            mmap_builder.finalize(indexed_dataset.index_file_path(mmap_prefix))

            huffman_bytes = os.stat(
                indexed_dataset.data_file_path(huffman_prefix)
            ).st_size
            mmap_bytes = os.stat(
                indexed_dataset.data_file_path(mmap_prefix)
            ).st_size
            self.assertLess(huffman_bytes, mmap_bytes)

    def test_huffman_can_append(self):
        """Appending two datasets yields their concatenation, in order."""
        first_data = make_data()
        # The coder is built from first_data only; make_data always covers the
        # full alphabet, so it can also encode the second sample below.
        coder = make_code_builder(first_data).build_code()

        with TemporaryDirectory() as dirname:
            first_prefix = os.path.join(dirname, "test1")
            build_dataset(first_prefix, first_data, coder)

            second_data = make_data()
            second_prefix = os.path.join(dirname, "test2")
            build_dataset(second_prefix, second_data, coder)

            merged_prefix = os.path.join(dirname, "test3")
            with HuffmanMMapIndexedDatasetBuilder(
                merged_prefix, coder
            ) as merged_builder:
                merged_builder.append(first_prefix)
                merged_builder.append(second_prefix)

            dataset = HuffmanMMapIndexedDataset(merged_prefix)

            self.assertEqual(len(dataset), len(first_data) + len(second_data))

            self.assertEqual(
                [list(dataset.get_symbols(i)) for i in range(len(first_data))],
                first_data,
            )
            self.assertEqual(
                [
                    list(dataset.get_symbols(i))
                    for i in range(len(first_data), len(dataset))
                ],
                second_data,
            )

            lengths = [s.item() for s in dataset.sizes]
            self.assertEqual(lengths[: len(first_data)], sizes(first_data))
            self.assertEqual(lengths[len(first_data) :], sizes(second_data))
|
|
|
|
|
# Allow running this test module directly (e.g. `python <this file>`).
if __name__ == "__main__":

    unittest.main()
|
|