Spaces:
Build error
Build error
# Copyright 2022 The T5X Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Checkpoint helper functions for managing checkpoints. | |
Supports marking checkpoints as pinned to exclude them from the checkpointer | |
removal process. | |
""" | |
import os | |
from absl import logging | |
from tensorflow.io import gfile | |
# PINNED file in the checkpoint directory indicates that the checkpoint should | |
# not be removed during the automatic pruning of old checkpoints. | |
_PINNED_CHECKPOINT_FILENAME = 'PINNED' | |
def pinned_checkpoint_filepath(ckpt_dir: str) -> str: | |
"""Full path of the pinned checkpoint file.""" | |
return os.path.join(ckpt_dir, _PINNED_CHECKPOINT_FILENAME) | |
def is_pinned_checkpoint(ckpt_dir: str) -> bool: | |
"""Returns whether the checkpoint is pinned, and should NOT be removed.""" | |
pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) | |
if gfile.exists(pinned_ckpt_file): | |
return True | |
return False | |
def pin_checkpoint(ckpt_dir: str, txt: str = '1') -> None: | |
"""Pin a checkpoint so it does not get deleted by the normal pruning process. | |
Creates a PINNED file in the checkpoint directory to indicate the checkpoint | |
should be excluded from the deletion of old checkpoints. | |
Args: | |
ckpt_dir: The checkpoint step dir that is to be always kept. | |
txt: Text to be written into the checkpoints ALWAYS_KEEP me file. | |
""" | |
pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) | |
with gfile.GFile(pinned_ckpt_file, 'w') as f: | |
logging.debug('Write %s file : %s.', pinned_ckpt_file, txt) | |
f.write(txt) | |
def unpin_checkpoint(ckpt_dir: str) -> None: | |
"""Removes the pinned status of the checkpoint so it is open for deletion.""" | |
if not is_pinned_checkpoint(ckpt_dir): | |
logging.debug('%s is not PINNED. Nothing to do here.', ckpt_dir) | |
return | |
try: | |
pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) | |
logging.debug('Remove %s file.', pinned_ckpt_file) | |
gfile.rmtree(pinned_ckpt_file) | |
except IOError: | |
logging.exception('Failed to unpin %s', ckpt_dir) | |
def remove_checkpoint_dir(ckpt_dir: str) -> None: | |
"""Removes the checkpoint dir if it is not pinned.""" | |
if not is_pinned_checkpoint(ckpt_dir): | |
logging.info('Deleting checkpoint: %s', ckpt_dir) | |
gfile.rmtree(ckpt_dir) | |
else: | |
logging.info('Keeping pinned checkpoint: %s', ckpt_dir) | |
def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None: | |
"""Removes dataset checkpoints if the checkpoint is not pinned.""" | |
if not is_pinned_checkpoint(ckpt_dir): | |
train_ds_pattern = os.path.join(ckpt_dir, train_ds_prefix + '*') | |
logging.info('Deleting dataset checkpoint: %s', train_ds_pattern) | |
for file in gfile.glob(train_ds_pattern): | |
gfile.remove(file) | |
else: | |
logging.info('Keeping pinned checkpoint: %s', ckpt_dir) | |