# Copyright 2022 The T5X Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for t5x.checkpoint_utils.""" import os import traceback from absl.testing import absltest from t5x import checkpoint_utils from tensorflow.io import gfile TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") class CheckpointsUtilsTest(absltest.TestCase): def setUp(self): super().setUp() self.checkpoints_dir = self.create_tempdir() self.ckpt_dir_path = self.checkpoints_dir.full_path self.pinned_ckpt_file = os.path.join(self.ckpt_dir_path, "PINNED") self.checkpoints_dir.create_file("checkpoint") # Create a `train_ds` file representing the dataset checkpoint. train_ds_basename = "train_ds-00000-of-00001" self.train_ds_file = os.path.join(self.ckpt_dir_path, train_ds_basename) self.checkpoints_dir.create_file(train_ds_basename) def test_always_keep_checkpoint_file(self): self.assertEqual( "/path/to/ckpt/dir/PINNED", checkpoint_utils.pinned_checkpoint_filepath("/path/to/ckpt/dir")) def test_is_pinned_checkpoint_false_by_default(self): # Ensure regular checkpoint without PINNED file. self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) # Validate checkpoints are not pinned by default. self.assertFalse(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path)) def test_is_pinned_checkpoint(self): # Ensure the checkpoint directory as pinned. pinned_ckpt_testdata = os.path.join(TESTDATA, "pinned_ckpt_dir") pinned_file = os.path.join(pinned_ckpt_testdata, "PINNED") self.assertTrue(gfile.exists(pinned_file)) # Test and validate. self.assertTrue(checkpoint_utils.is_pinned_checkpoint(pinned_ckpt_testdata)) def test_is_pinned_missing_ckpt(self): self.assertFalse( checkpoint_utils.is_pinned_checkpoint( os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist"))) def test_pin_checkpoint(self): # Ensure directory isn't already pinned. self.assertFalse(gfile.exists(self.pinned_ckpt_file)) # Test. checkpoint_utils.pin_checkpoint(self.ckpt_dir_path) # Validate. self.assertTrue(gfile.exists(self.pinned_ckpt_file)) with open(self.pinned_ckpt_file) as f: self.assertEqual("1", f.read()) def test_pin_checkpoint_txt(self): checkpoint_utils.pin_checkpoint(self.ckpt_dir_path, "TEXT_IN_PINNED") self.assertTrue(os.path.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) with open(self.pinned_ckpt_file) as f: self.assertEqual("TEXT_IN_PINNED", f.read()) def test_unpin_checkpoint(self): # Mark the checkpoint directory as pinned. self.checkpoints_dir.create_file("PINNED") self.assertTrue(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path)) # Test. checkpoint_utils.unpin_checkpoint(self.ckpt_dir_path) # Validate the "PINNED" checkpoint file got removed. self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) def test_unpin_checkpoint_does_not_exist(self): missing_ckpt_path = os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist") self.assertFalse(gfile.exists(missing_ckpt_path)) # Test. Assert does not raise error. try: checkpoint_utils.unpin_checkpoint(missing_ckpt_path) except IOError: # TODO(b/172262005): Remove traceback.format_exc() from the error message. self.fail("Unpin checkpoint failed with: %s" % traceback.format_exc()) def test_remove_checkpoint_dir(self): # Ensure the checkpoint directory is setup. assert gfile.exists(self.ckpt_dir_path) # Test. checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path) # Validate the checkpoint directory got removed. self.assertFalse(gfile.exists(self.ckpt_dir_path)) def test_remove_checkpoint_dir_pinned(self): # Mark the checkpoint directory as pinned so it does not get removed. self.checkpoints_dir.create_file("PINNED") # Test. checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path) # Validate the checkpoint directory still exists. self.assertTrue(gfile.exists(self.ckpt_dir_path)) def test_remove_dataset_checkpoint(self): # Ensure the checkpoint directory is setup. assert gfile.exists(self.ckpt_dir_path) # Test. checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds") # Validate the checkpoint directory got removed. self.assertFalse(gfile.exists(self.train_ds_file)) self.assertTrue(gfile.exists(self.ckpt_dir_path)) def test_remove_dataset_checkpoint_pinned(self): # Mark the checkpoint directory as pinned so it does not get removed. self.checkpoints_dir.create_file("PINNED") # Test. checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds") # Validate the checkpoint directory still exists. self.assertTrue(gfile.exists(self.train_ds_file)) self.assertTrue(gfile.exists(self.ckpt_dir_path)) if __name__ == "__main__": absltest.main()