# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import unittest import torch from fairseq.data import LanguagePairDataset, TokenBlockDataset from fairseq.data.concat_dataset import ConcatDataset from tests.test_train import mock_dict class TestConcatDataset(unittest.TestCase): def setUp(self): d = mock_dict() tokens_1 = torch.LongTensor([1]).view(1, -1) tokens_ds1 = TokenBlockDataset( tokens_1, sizes=[tokens_1.size(-1)], block_size=1, pad=0, eos=1, include_targets=False, ) self.dataset_1 = LanguagePairDataset( tokens_ds1, tokens_ds1.sizes, d, shuffle=False ) tokens_2 = torch.LongTensor([2]).view(1, -1) tokens_ds2 = TokenBlockDataset( tokens_2, sizes=[tokens_2.size(-1)], block_size=1, pad=0, eos=1, include_targets=False, ) self.dataset_2 = LanguagePairDataset( tokens_ds2, tokens_ds2.sizes, d, shuffle=False ) def test_concat_dataset_basics(self): d = ConcatDataset([self.dataset_1, self.dataset_2]) assert len(d) == 2 assert d[0]["source"][0] == 1 assert d[1]["source"][0] == 2 d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[1, 2]) assert len(d) == 3 assert d[0]["source"][0] == 1 assert d[1]["source"][0] == 2 assert d[2]["source"][0] == 2 d = ConcatDataset([self.dataset_1, self.dataset_2], sample_ratios=[2, 1]) assert len(d) == 3 assert d[0]["source"][0] == 1 assert d[1]["source"][0] == 1 assert d[2]["source"][0] == 2