from __future__ import annotations

import os
import signal
from pathlib import Path

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment

from realfake.callbacks import ConsoleLogger, FeatureExtractorFreezeUnfreeze
from realfake.models import RealFakeParams
from realfake.utils import get_checkpoints_dir, find_latest_checkpoint


def get_existing_checkpoint(job_id: str | None = None) -> tuple[Path, Path | None]:
    """Resolve the checkpoints directory and the latest checkpoint in it, if any."""
    if job_id is None:
        # Local run: write into a fresh timestamped directory.
        checkpoints_dir = get_checkpoints_dir(timestamp=True)
    else:
        # SLURM run: one directory per job id, so a requeued job resumes in place.
        checkpoints_dir = get_checkpoints_dir(timestamp=False) / job_id
        checkpoints_dir.mkdir(parents=True, exist_ok=True)
    existing_checkpoint = find_latest_checkpoint(checkpoints_dir)
    return checkpoints_dir, existing_checkpoint
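# Illustrative return values (the paths below are hypothetical; the actual
# layout is whatever get_checkpoints_dir in realfake.utils produces):
#
#   get_existing_checkpoint()        -> (Path(".../2023-06-01_10-30"), None)
#   get_existing_checkpoint("12345") -> (Path(".../12345"), Path(".../12345/last.ckpt"))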


def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
    job_id = os.environ.get("SLURM_JOB_ID")
    checkpoints_dir, existing_checkpoint = get_existing_checkpoint(job_id)

    if job_id is None:
        print("SLURM job id not found, running locally.")

    if existing_checkpoint is None:
        print("No existing checkpoint found, starting from scratch.")

    if args.accelerator.override_float32_matmul:
        # Trades float32 matmul precision for speed on recent GPUs;
        # valid values are "highest", "high", and "medium".
        torch.set_float32_matmul_precision(args.accelerator.float32_matmul)

    # Snapshot the run parameters next to the checkpoints for reproducibility.
    with (checkpoints_dir / "params.json").open("w") as fp:
        fp.write(args.json())

    trainer_params = dict(
        accelerator=args.accelerator.name,
        devices=args.accelerator.devices,
        precision=args.accelerator.precision,
        max_epochs=args.epochs,
        num_nodes=1,
        num_sanity_val_steps=0,
        enable_progress_bar=args.progress_bar,
        callbacks=[
            ConsoleLogger(),
            # Keep the single best checkpoint by validation accuracy, plus last.ckpt.
            ModelCheckpoint(
                monitor="val_acc",
                mode="max",
                save_last=True,
                save_top_k=1,
                dirpath=checkpoints_dir,
                filename=f"{args.model_name}-{{epoch:02d}}-{{val_acc:.4f}}",
            ),
            FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=args.freeze_epochs),
        ],
        # Lightning 1.x API; in Lightning 2.x, pass ckpt_path to trainer.fit instead.
        resume_from_checkpoint=existing_checkpoint,
    )

    if job_id is not None:
        # Progress bars clutter SLURM log files; rely on ConsoleLogger instead.
        trainer_params["enable_progress_bar"] = False
        # Treat SIGHUP as the requeue signal so the run checkpoints and resubmits itself.
        trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
        trainer_params["strategy"] = args.accelerator.strategy

    return pl.Trainer(**trainer_params)
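

# Minimal usage sketch. `model` and `dm` are hypothetical names for the project's
# LightningModule and datamodule; parse_file assumes RealFakeParams is a pydantic
# v1 model, consistent with the args.json() call above:
#
#   args = RealFakeParams.parse_file("params.json")
#   trainer = prepare_trainer(args)
#   trainer.fit(model, datamodule=dm)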