File size: 2,454 Bytes
ea847ad 12babad ea847ad 12babad ea847ad 12babad ea847ad 12babad c6b22f5 ea847ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from __future__ import annotations
import os
import signal
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment
from realfake.callbacks import ConsoleLogger, FeatureExtractorFreezeUnfreeze
from realfake.models import RealFakeParams
from realfake.utils import get_checkpoints_dir, find_latest_checkpoint
def get_existing_checkpoint(job_id: str | None = None) -> tuple[Path, Path | None]:
    """Resolve the checkpoints directory for this run and locate its latest checkpoint.

    Args:
        job_id: SLURM job id. When ``None`` (local run) a fresh timestamped
            directory is used; otherwise a stable per-job directory keyed by
            the id, so a requeued SLURM job finds its own checkpoints again.

    Returns:
        A ``(checkpoints_dir, existing_checkpoint)`` pair, where
        ``existing_checkpoint`` is ``None`` when no checkpoint exists yet.
    """
    if job_id is None:
        checkpoints_dir = get_checkpoints_dir(timestamp=True)
    else:
        checkpoints_dir = get_checkpoints_dir(timestamp=False) / job_id
    # Ensure the directory exists before callers write files (e.g. params.json) into it.
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    existing_checkpoint = find_latest_checkpoint(checkpoints_dir)
    return checkpoints_dir, existing_checkpoint
def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
    """Build a PyTorch Lightning trainer configured for a local or SLURM run.

    Side effects: creates/uses the run's checkpoints directory, writes the
    run parameters to ``params.json`` inside it, and optionally adjusts
    torch's float32 matmul precision.

    Args:
        args: Full training configuration (accelerator, epochs, model name, ...).

    Returns:
        A configured ``pl.Trainer`` ready for ``fit``.
    """
    job_id = os.environ.get("SLURM_JOB_ID")
    checkpoints_dir, existing_checkpoint = get_existing_checkpoint(job_id)

    if job_id is None:
        print("SLURM job id is not found, running locally.")
    if existing_checkpoint is None:
        print("No existing checkpoint found, starting from scratch.")

    if args.accelerator.override_float32_matmul:
        torch.set_float32_matmul_precision(args.accelerator.float32_matmul)

    # Persist the run configuration next to its checkpoints for reproducibility.
    with (checkpoints_dir / "params.json").open("w") as fp:
        fp.write(args.json())

    # Keep the single best val_acc checkpoint plus the last one.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        mode="max",
        save_last=True,
        save_top_k=1,
        dirpath=checkpoints_dir,
        filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
    )

    trainer_params = {
        "accelerator": args.accelerator.name,
        "devices": args.accelerator.devices,
        "precision": args.accelerator.precision,
        "max_epochs": args.epochs,
        "num_nodes": 1,
        "num_sanity_val_steps": 0,
        "enable_progress_bar": args.progress_bar,
        "callbacks": [
            ConsoleLogger(),
            checkpoint_callback,
            FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=args.freeze_epochs),
        ],
        # NOTE(review): the `resume_from_checkpoint` constructor argument was
        # removed in Lightning 2.x in favour of `trainer.fit(ckpt_path=...)` —
        # confirm the pinned pytorch_lightning version still accepts it.
        "resume_from_checkpoint": existing_checkpoint,
    }

    if job_id is not None:
        # Running under SLURM: disable the progress bar (log files), requeue
        # on SIGHUP before preemption, and use the configured cluster strategy.
        trainer_params["enable_progress_bar"] = False
        trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
        trainer_params["strategy"] = args.accelerator.strategy

    return pl.Trainer(**trainer_params)
|