# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from fairseq.data import MonolingualDataset
from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig
from tests import utils as test_utils

class TestLMContextWindow(unittest.TestCase):
    """Tests for ``LanguageModelingTask.eval_lm_dataloader`` with a context window."""

    def test_eval_dataloader(self):
        """Iterate an eval dataloader built with ``context_window=2``.

        The dataset holds three token sequences of lengths 4, 4 and 2, served
        one per batch (``batch_size=1``) with ``tokens_per_sample=4``. The
        expected batches are:

        * batch 0: ``src_tokens`` is the first sequence right-padded with the
          pad index (1) to length 6; ``target`` equals ``src_tokens``.
        * batch 1 and batch 2: ``src_tokens`` starts with the last two tokens
          of the previous sequence, followed by the current sequence. The
          matching ``target`` positions hold the pad index, so those context
          tokens are not scored.
        """
        dictionary = test_utils.dummy_dictionary(10)
        self.assertEqual(len(dictionary), 14)  # 4 extra special symbols
        # The expected tensors below use 1 as the pad index.
        self.assertEqual(dictionary.pad(), 1)

        dataset = test_utils.TestDataset([
            torch.tensor([4, 5, 6, 7], dtype=torch.long),
            torch.tensor([8, 9, 10, 11], dtype=torch.long),
            torch.tensor([12, 13], dtype=torch.long),
        ])
        dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary)

        config = LanguageModelingConfig(tokens_per_sample=4)
        task = LanguageModelingTask(config, dictionary)

        eval_dataloader = task.eval_lm_dataloader(
            dataset=dataset,
            batch_size=1,
            context_window=2,
        )

        # (expected src_tokens, expected target) for the first row of each batch.
        expected_batches = [
            ([4, 5, 6, 7, 1, 1], [4, 5, 6, 7, 1, 1]),
            ([6, 7, 8, 9, 10, 11], [1, 1, 8, 9, 10, 11]),
            ([10, 11, 12, 13], [1, 1, 12, 13]),
        ]
        for batch_idx, (expected_src, expected_tgt) in enumerate(expected_batches):
            # subTest reports which batch failed. Later batches are still
            # checked, because next() advances the iterator on every pass.
            with self.subTest(batch=batch_idx):
                batch = next(eval_dataloader)
                self.assertEqual(
                    batch["net_input"]["src_tokens"][0].tolist(), expected_src
                )
                self.assertEqual(batch["target"][0].tolist(), expected_tgt)

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()