from __future__ import annotations

import os
import signal

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment

from realfake.callbacks import ConsoleLogger, FeatureExtractorFreezeUnfreeze
from realfake.models import RealFakeParams
from realfake.utils import get_checkpoints_dir, find_latest_checkpoint


def get_existing_checkpoint(job_id: str | None = None) -> tuple:
    """Resolve the checkpoints directory and the latest checkpoint inside it.

    Local runs (no SLURM job id) get a fresh timestamped directory; SLURM runs
    reuse a directory named after the job id so that requeued jobs can resume.
    """
    if job_id is None:
        checkpoints_dir = get_checkpoints_dir(timestamp=True)
    else:
        checkpoints_dir = get_checkpoints_dir(timestamp=False) / job_id
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    existing_checkpoint = find_latest_checkpoint(checkpoints_dir)
    return checkpoints_dir, existing_checkpoint


def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
    job_id = os.environ.get("SLURM_JOB_ID")
    checkpoints_dir, existing_checkpoint = get_existing_checkpoint(job_id)

    if job_id is None:
        print("SLURM job id is not found, running locally.")
    if existing_checkpoint is None:
        print("No existing checkpoint found, starting from scratch.")

    if args.accelerator.override_float32_matmul:
        torch.set_float32_matmul_precision(args.accelerator.float32_matmul)

    # Persist the run configuration next to the checkpoints.
    with (checkpoints_dir / "params.json").open("w") as fp:
        fp.write(args.json())

    trainer_params = dict(
        accelerator=args.accelerator.name,
        devices=args.accelerator.devices,
        precision=args.accelerator.precision,
        max_epochs=args.epochs,
        num_nodes=1,
        num_sanity_val_steps=0,
        enable_progress_bar=args.progress_bar,
        callbacks=[
            ConsoleLogger(),
            # Keep the best validation-accuracy checkpoint plus the last epoch.
            ModelCheckpoint(
                monitor="val_acc",
                mode="max",
                save_last=True,
                save_top_k=1,
                dirpath=checkpoints_dir,
                filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
            ),
            FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=args.freeze_epochs),
        ],
        # Note: this Trainer argument was removed in Lightning 2.x, where the
        # checkpoint path is passed to `trainer.fit(ckpt_path=...)` instead.
        resume_from_checkpoint=existing_checkpoint,
    )

    if job_id is not None:
        # Running under SLURM: disable the progress bar and let Lightning
        # requeue the job on SIGHUP so training resumes from the last checkpoint.
        trainer_params["enable_progress_bar"] = False
        trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
        trainer_params["strategy"] = args.accelerator.strategy

    return pl.Trainer(**trainer_params)
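

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module).
# A training entrypoint would typically build the params object, create the
# trainer, and call `fit`. The `RealFakeClassifier` and `RealFakeDataModule`
# names below are hypothetical placeholders -- this file only defines
# `get_existing_checkpoint` and `prepare_trainer`.
#
#     from realfake.models import RealFakeParams
#
#     args = RealFakeParams.parse_file("params.json")  # assumes a pydantic-style params class
#     model = RealFakeClassifier(args)                 # hypothetical LightningModule
#     dm = RealFakeDataModule(args)                    # hypothetical LightningDataModule
#
#     trainer = prepare_trainer(args)
#     trainer.fit(model, datamodule=dm)  # picks up `resume_from_checkpoint` if one was found
# ---------------------------------------------------------------------------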