import unittest

import numpy as np

from fairseq.data.data_utils_fast import batch_by_size_fn, batch_by_size_vec


class TestBatchBySize(unittest.TestCase):
    @classmethod
    def batch_by_size_baseline(
        cls,
        indices,
        num_tokens_vec,
        max_tokens,
        max_sentences,
        bsz_mult,
    ):
        """Simple, reliable and slow implementation of batch by size."""
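        # Greedily grow the candidate batch one index at a time: its cost is
        # max(num_tokens in batch) * number of sentences. A max_tokens or
        # max_sentences value of 0 disables that limit. When adding one more
        # sentence would overflow (or the indices are exhausted), trim the
        # batch to a multiple of bsz_mult (only if it is larger than
        # bsz_mult) and emit it; the loop then resumes after the emitted batch.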
        batches = []
        start = 0
        while start < len(indices):
            for end in range(start + 1, len(indices) + 1):
                max_val = max(num_tokens_vec[pos] for pos in range(start, end))
                sent_count = end - start
                num_tokens = max_val * sent_count
                overflow = num_tokens > max_tokens > 0 or sent_count > max_sentences > 0
                terminate = overflow or end == len(indices)
                if overflow:
                    sent_count -= 1
                if terminate:
                    if sent_count > bsz_mult:
                        sent_count = sent_count - sent_count % bsz_mult
                    batches.append(indices[start : start + sent_count])
                    start = start + sent_count
                    break
        return batches

    @classmethod
    def _get_error_message(
        cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
    ):
        return f"""Reference batch_by_size implementation should produce
            the same output as the baseline method.
            Params:
                max_sentences={max_sentences},
                max_tokens={max_tokens},
                bsz_mult={bsz_mult},
                num_tokens_vec={num_tokens_vec},
                expected_batches={validation},
                returned_batches={results}"""

    def _compare_results(
        self,
        indices_len,
        batch_by_size_impl,
        max_sentences,
        max_tokens,
        bsz_mult,
        num_tokens_vec,
    ):
        indices = np.array(list(range(indices_len)))
        validation = self.batch_by_size_baseline(
            indices,
            num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            bsz_mult=bsz_mult,
        )
        results = batch_by_size_impl(
            indices,
            num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            bsz_mult=bsz_mult,
        )
        error_msg = self._get_error_message(
            max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
        )
        self.assertEqual(len(validation), len(results), error_msg)
        for first, second in zip(validation, results):
            self.assertTrue(np.array_equal(first, second), error_msg)

    def _run_compare_with_baseline_sweep(self, batch_by_size_impl):
        """Compare a reference batch_by_size implementation with
        batch_by_size_baseline across a dense grid of hyperparameter values."""
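        # The sweep covers even and odd numbers of indices, "no limit"
        # settings (max_sentences == 0 or max_tokens == 0), bsz_mult values
        # larger than any possible batch, and several random token-count
        # vectors for each parameter combination.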
        MAX_MAX_TOKENS = 10
        NUM_TOKENS_VECS_COUNT = 5
        for indices_len in [10, 11]:
            for max_sentences in range(0, indices_len + 2):
                for max_tokens in range(0, MAX_MAX_TOKENS):
                    for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2):
                        for _ in range(NUM_TOKENS_VECS_COUNT):
                            num_tokens_vec = np.random.randint(
                                0, max_tokens + 1, size=indices_len
                            )
                            self._compare_results(
                                indices_len,
                                batch_by_size_impl,
                                max_sentences,
                                max_tokens,
                                bsz_mult,
                                num_tokens_vec,
                            )


class TestBatchBySizeVec(TestBatchBySize):
    def test_compare_with_baseline(self):
        self._run_compare_with_baseline_sweep(batch_by_size_vec)


class TestBatchBySizeFn(TestBatchBySize):
    def test_compare_with_baseline(self):
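        # batch_by_size_fn takes a per-index num_tokens callable instead of a
        # vector of token counts, so wrap it to match the vector-based
        # signature used by the sweep.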
        def batch_by_size_fn_wrapper(
            indices,
            num_tokens_vec,
            max_tokens,
            max_sentences,
            bsz_mult,
        ):
            def num_tokens_fn(idx):
                return num_tokens_vec[idx]

            return batch_by_size_fn(
                indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult
            )

        self._run_compare_with_baseline_sweep(batch_by_size_fn_wrapper)


if __name__ == "__main__":
    unittest.main()