Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import contextlib | |
import logging | |
import os | |
import tempfile | |
import unittest | |
from io import StringIO | |
from unittest.mock import patch | |
from fairseq import checkpoint_utils | |
from omegaconf import OmegaConf | |
from tests.utils import ( | |
create_dummy_data, | |
preprocess_translation_data, | |
train_translation_model, | |
) | |
class TestCheckpointUtils(unittest.TestCase): | |
def setUp(self): | |
logging.disable(logging.CRITICAL) | |
def tearDown(self): | |
logging.disable(logging.NOTSET) | |
def _train_transformer(self, seed, extra_args=None): | |
if extra_args is None: | |
extra_args = [] | |
with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir: | |
create_dummy_data(data_dir) | |
preprocess_translation_data(data_dir) | |
train_translation_model( | |
data_dir, | |
"transformer_iwslt_de_en", | |
[ | |
"--encoder-layers", | |
"3", | |
"--decoder-layers", | |
"3", | |
"--encoder-embed-dim", | |
"8", | |
"--decoder-embed-dim", | |
"8", | |
"--seed", | |
str(seed), | |
] | |
+ extra_args, | |
) | |
yield os.path.join(data_dir, "checkpoint_last.pt") | |
def test_load_model_ensemble_and_task(self): | |
# with contextlib.redirect_stdout(StringIO()): | |
with self._train_transformer(seed=123) as model1: | |
with self._train_transformer(seed=456) as model2: | |
ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( | |
filenames=[model1, model2] | |
) | |
self.assertEqual(len(ensemble), 2) | |
# after Transformer has been migrated to Hydra, this will probably | |
# become cfg.common.seed | |
self.assertEqual(ensemble[0].args.seed, 123) | |
self.assertEqual(ensemble[1].args.seed, 456) | |
# the task from the first model should be returned | |
self.assertTrue("seed123" in task.cfg.data) | |
# last cfg is saved | |
self.assertEqual(cfg.common.seed, 456) | |
def test_prune_state_dict(self): | |
with contextlib.redirect_stdout(StringIO()): | |
extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"] | |
with self._train_transformer(seed=1, extra_args=extra_args) as model: | |
ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( | |
filenames=[model], | |
arg_overrides={ | |
"encoder_layers_to_keep": "0,2", | |
"decoder_layers_to_keep": "1", | |
}, | |
) | |
self.assertEqual(len(ensemble), 1) | |
self.assertEqual(len(ensemble[0].encoder.layers), 2) | |
self.assertEqual(len(ensemble[0].decoder.layers), 1) | |
def test_torch_persistent_save_async(self): | |
state_dict = {} | |
filename = "async_checkpoint.pt" | |
with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: | |
with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save: | |
checkpoint_utils.torch_persistent_save( | |
state_dict, filename, async_write=True | |
) | |
mock_opena.assert_called_with(filename, "wb") | |
mock_save.assert_called() | |
if __name__ == "__main__": | |
unittest.main() | |