# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import collections
import unittest

import numpy as np
from fairseq.data import ListDataset, ResamplingDataset


class TestResamplingDataset(unittest.TestCase):
    """Check the size, ordering and sampling frequencies of ResamplingDataset.

    The checks use ``unittest`` assertion methods rather than bare ``assert``
    statements so that they still run under ``python -O`` and report the
    offending values on failure.
    """

    # Maximum allowed absolute gap between the observed and the expected
    # sampling frequency of any single item (2%).
    _MAX_DISTRIBUTION_DIFF = 0.02

    # Number of epochs each test samples over.
    _NUM_EPOCHS = 1000

    def setUp(self):
        """Build a four-item string dataset whose sizes are the string lengths."""
        self.strings = ["ab", "c", "def", "ghij"]
        self.weights = [4.0, 2.0, 7.0, 1.5]
        self.size_ratio = 2
        self.dataset = ListDataset(
            self.strings, np.array([len(s) for s in self.strings])
        )

    def _test_common(self, resampling_dataset, iters):
        """Run ``iters`` epochs over ``resampling_dataset`` and gather statistics.

        For every epoch this checks that ``ordered_indices()`` has one entry
        per dataset element, that indexing the same position twice yields the
        same item, and that ``size(i)`` equals the length of item ``i``.

        Args:
            resampling_dataset: the dataset under test, wrapping
                ``self.dataset`` with ``self.weights`` and ``self.size_ratio``.
            iters: number of epochs to sample.

        Returns:
            A dict with two keys:

            * ``"ordered_by_size"``: ``True`` if, in every epoch, sizes never
              decreased while walking ``ordered_indices()``.
            * ``"max_distribution_diff"``: the largest absolute difference,
              over all items, between the observed sampling frequency and
              ``weight / sum(weights)``.
        """
        self.assertEqual(len(self.dataset), len(self.strings))
        self.assertEqual(len(self.strings), len(self.weights))
        self.assertEqual(
            len(resampling_dataset), self.size_ratio * len(self.strings)
        )

        results = {"ordered_by_size": True, "max_distribution_diff": 0.0}

        freqs = collections.Counter()

        for epoch_num in range(iters):
            resampling_dataset.set_epoch(epoch_num)

            indices = resampling_dataset.ordered_indices()
            self.assertEqual(len(indices), len(resampling_dataset))

            # Reset per epoch: ordering is only checked within an epoch.
            prev_size = -1

            for i in indices:
                cur_size = resampling_dataset.size(i)
                item = resampling_dataset[i]

                # Make sure indices map to same sequences within an epoch;
                # this needs a second, independent lookup of the same index.
                self.assertEqual(item, resampling_dataset[i])

                # Make sure length of sequence is correct
                self.assertEqual(cur_size, len(item))

                freqs[item] += 1

                if prev_size > cur_size:
                    results["ordered_by_size"] = False
                prev_size = cur_size

        # Every original item must have been sampled at least once, and
        # nothing outside the original items may appear.
        self.assertEqual(set(freqs), set(self.strings))

        total_samples = sum(freqs.values())
        total_weight = sum(self.weights)
        for s, weight in zip(self.strings, self.weights):
            freq = freqs[s] / total_samples
            expected_freq = weight / total_weight
            results["max_distribution_diff"] = max(
                results["max_distribution_diff"], abs(expected_freq - freq)
            )

        return results

    def test_resampling_dataset_batch_by_size_false(self):
        """With ``batch_by_size=False`` indices are not sorted by size."""
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=False,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=self._NUM_EPOCHS)

        # For batch_by_size = False, the batches should be returned in
        # arbitrary order of size.
        self.assertFalse(results["ordered_by_size"])

        # Allow tolerance in distribution error of 2%.
        self.assertLess(
            results["max_distribution_diff"], self._MAX_DISTRIBUTION_DIFF
        )

    def test_resampling_dataset_batch_by_size_true(self):
        """With ``batch_by_size=True`` indices come in non-decreasing size."""
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=True,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=self._NUM_EPOCHS)

        # For batch_by_size = True, the batches should be returned in
        # increasing order of size.
        self.assertTrue(results["ordered_by_size"])

        # Allow tolerance in distribution error of 2%.
        self.assertLess(
            results["max_distribution_diff"], self._MAX_DISTRIBUTION_DIFF
        )


if __name__ == "__main__":
    unittest.main()