| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import os |
| | import re |
| | import shutil |
| | import sys |
| |
|
| |
|
# Numbered checkpoints named ``checkpoint<N>.pt``; group 1 captures N.
pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt")

# Numbered checkpoints named ``checkpoint_<A>_<B>.pt``; group 1 captures the
# trailing number B only (A is matched but not captured).
pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt")

# Every checkpoint filename this script considers at all: the two numbered
# forms above plus ``checkpoint_<word>.pt`` where <word> is lowercase letters
# (e.g. ``checkpoint_best.pt`` / ``checkpoint_last.pt``).
pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt")
| |
|
| |
|
def parse_checkpoints(files):
    """Pick out the numbered checkpoint files from *files*.

    Each name is tried against ``pt_regexp_epoch_based`` first and, failing
    that, against ``pt_regexp_update_based``. Names matching neither pattern
    (including ``checkpoint_best.pt`` / ``checkpoint_last.pt``) are skipped.

    Args:
        files: iterable of bare file names (no directory component).

    Returns:
        A list of ``(number, filename)`` tuples in input order, where
        ``number`` is the integer captured by group 1 of whichever pattern
        matched.
    """
    numbered_patterns = (pt_regexp_epoch_based, pt_regexp_update_based)
    found = []
    for name in files:
        for pattern in numbered_patterns:
            match = pattern.fullmatch(name)
            if match is None:
                continue
            found.append((int(match.group(1)), match.group(0)))
            # First matching pattern wins; do not try the remaining ones.
            break
    return found
| |
|
| |
|
def last_n_checkpoints(files, n):
    """Return the names of the *n* highest-numbered checkpoints in *files*.

    Args:
        files: iterable of bare file names.
        n: maximum number of names to return (used as a slice bound).

    Returns:
        File names ordered from the highest checkpoint number downwards;
        fewer than *n* names when fewer numbered checkpoints exist.
    """
    ranked = parse_checkpoints(files)
    # Tuples compare by number first, then by name, so a descending sort
    # puts the newest checkpoint at the front.
    ranked.sort(reverse=True)
    return [name for _number, name in ranked[:n]]
| |
|
| |
|
def every_n_checkpoints(files, n):
    """Return the names of every *n*-th numbered checkpoint in *files*.

    The stride is counted backwards from the highest-numbered checkpoint, so
    that one is always part of the selection.

    Args:
        files: iterable of bare file names.
        n: stride between kept checkpoints; used as a slice step, so it must
            be non-zero.

    Returns:
        The selected file names in ascending checkpoint-number order.
    """
    ascending = sorted(parse_checkpoints(files))
    # A negative step walks from the last (highest) entry towards the first.
    selected = ascending[::-n]
    selected.sort()
    return [name for _number, name in selected]
| |
|
| |
|
def _replace_symlink_with_copy(link_path):
    """Replace the symlink at *link_path* with a regular copy of its target.

    The target is first copied to a staging file next to the link and then
    moved over the link with ``os.replace``. If the copy fails (dangling
    link, full disk, ...) the original symlink is left untouched instead of
    being removed up front, and the partial staging file is cleaned up.

    Raises:
        OSError: if the target cannot be read or the copy cannot be written.
    """
    target = os.path.realpath(link_path)
    staging_path = link_path + ".tmp"
    print("cp {} {}".format(target, link_path))
    try:
        shutil.copyfile(target, staging_path)
        # Renaming onto a symlink replaces the link itself, not its target.
        os.replace(staging_path, link_path)
    finally:
        # Only present here if the copy or the rename failed part-way.
        if os.path.lexists(staging_path):
            os.remove(staging_path)


def main():
    """Command-line entry point: interactively prune checkpoint files.

    Walks every directory given on the command line, classifies each file
    whose name matches ``pt_regexp`` as "preserve", "preserve and dereference
    symlink" or "delete", prints the full plan, and carries it out only after
    the user answers ``y`` at the prompt.

    A checkpoint is preserved when it is ``checkpoint_last.pt`` /
    ``checkpoint_best.pt`` (unless ``--delete-last`` / ``--delete-best``), or
    when it is selected by ``--save-last`` / ``--save-every``. Nothing is
    preserved inside directories whose name starts with ``test_`` unless
    ``--preserve-test`` is given.

    Exits with status 0 when there is nothing to do or the user answers
    ``n``; exits with a non-zero status if stdin ends before an answer is
    given.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Recursively delete checkpoint files from `root_dir`, "
            "but preserve checkpoint_best.pt and checkpoint_last.pt"
        )
    )
    parser.add_argument("root_dirs", nargs="*")
    parser.add_argument(
        "--save-last", type=int, default=0, help="number of last checkpoints to save"
    )
    parser.add_argument(
        "--save-every", type=int, default=0, help="interval of checkpoints to save"
    )
    parser.add_argument(
        "--preserve-test",
        action="store_true",
        help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)",
    )
    parser.add_argument(
        "--delete-best", action="store_true", help="delete checkpoint_best.pt"
    )
    parser.add_argument(
        "--delete-last", action="store_true", help="delete checkpoint_last.pt"
    )
    parser.add_argument(
        "--no-dereference", action="store_true", help="don't dereference symlinks"
    )
    args = parser.parse_args()

    files_to_desymlink = []
    files_to_preserve = []
    files_to_delete = []
    for root_dir in args.root_dirs:
        for root, _subdirs, files in os.walk(root_dir):
            # Names selected by --save-last / --save-every in this directory;
            # a set keeps the per-file membership test cheap.
            to_save = set()
            if args.save_last > 0:
                to_save.update(last_n_checkpoints(files, args.save_last))
            if args.save_every > 0:
                to_save.update(every_n_checkpoints(files, args.save_every))

            # normpath drops a trailing separator so that "test_x/" is
            # recognised as a test directory just like "test_x".
            dir_name = os.path.basename(os.path.normpath(root))
            may_preserve = args.preserve_test or not dir_name.startswith("test_")

            for file in files:
                if not pt_regexp.fullmatch(file):
                    continue
                full_path = os.path.join(root, file)
                keep = may_preserve and (
                    (file == "checkpoint_last.pt" and not args.delete_last)
                    or (file == "checkpoint_best.pt" and not args.delete_best)
                    or file in to_save
                )
                if not keep:
                    files_to_delete.append(full_path)
                elif os.path.islink(full_path) and not args.no_dereference:
                    files_to_desymlink.append(full_path)
                else:
                    files_to_preserve.append(full_path)

    if not files_to_desymlink and not files_to_delete:
        print("Nothing to do.")
        sys.exit(0)

    files_to_desymlink.sort()
    files_to_preserve.sort()
    files_to_delete.sort()

    print("Operations to perform (in order):")
    for file in files_to_desymlink:
        print(" - preserve (and dereference symlink): " + file)
    for file in files_to_preserve:
        print(" - preserve: " + file)
    for file in files_to_delete:
        print(" - delete: " + file)

    # Keep asking until the answer is an unambiguous yes or no.
    while True:
        try:
            resp = input("Continue? (Y/N): ")
        except EOFError:
            # No interactive answer available: never delete without consent.
            sys.exit("Aborted: no confirmation received.")
        answer = resp.strip().lower()
        if answer == "y":
            break
        if answer == "n":
            sys.exit(0)

    print("Executing...")
    # Symlinks are materialised before anything is deleted, because a link's
    # target may itself be scheduled for deletion.
    for file in files_to_desymlink:
        _replace_symlink_with_copy(file)
    for file in files_to_delete:
        print("rm " + file)
        os.remove(file)
| |
|
| |
|
# Run the cleanup only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|