# realfake/train.py
from __future__ import annotations
import os
import signal
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment
from realfake.callbacks import ConsoleLogger, FeatureExtractorFreezeUnfreeze
from realfake.models import RealFakeParams
from realfake.utils import get_checkpoints_dir, find_latest_checkpoint


def get_existing_checkpoint(job_id: str | None = None) -> tuple:
    """Resolve the checkpoints directory and the latest checkpoint in it, if any.

    SLURM runs get a per-job directory so that a requeued job resumes from its own
    checkpoints; local runs get a fresh timestamped directory.
    """
    if job_id is None:
        checkpoints_dir = get_checkpoints_dir(timestamp=True)
    else:
        checkpoints_dir = get_checkpoints_dir(timestamp=False)/job_id
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    existing_checkpoint = find_latest_checkpoint(checkpoints_dir)
    return checkpoints_dir, existing_checkpoint


def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
    """Builds a Lightning trainer from `args`, resuming from the latest checkpoint if one exists."""
    job_id = os.environ.get("SLURM_JOB_ID")
    checkpoints_dir, existing_checkpoint = get_existing_checkpoint(job_id)

    if job_id is None:
        print("SLURM job id not found, running locally.")
    if existing_checkpoint is None:
        print("No existing checkpoint found, starting from scratch.")

    if args.accelerator.override_float32_matmul:
        # Trade a bit of float32 matmul precision for throughput on recent GPUs.
        torch.set_float32_matmul_precision(args.accelerator.float32_matmul)

    # Snapshot the run parameters next to the checkpoints for reproducibility.
    with (checkpoints_dir/"params.json").open("w") as fp:
        fp.write(args.json())

    trainer_params = dict(
        accelerator=args.accelerator.name,
        devices=args.accelerator.devices,
        precision=args.accelerator.precision,
        max_epochs=args.epochs,
        num_nodes=1,
        num_sanity_val_steps=0,
        enable_progress_bar=args.progress_bar,
        callbacks=[
            ConsoleLogger(),
            ModelCheckpoint(
                monitor="val_acc",
                mode="max",
                save_last=True,
                save_top_k=1,
                dirpath=checkpoints_dir,
                # %-formatting keeps the literal braces for ModelCheckpoint to fill in.
                filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
            ),
            FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=args.freeze_epochs),
        ],
        # NOTE: `resume_from_checkpoint` is the pre-2.0 Lightning Trainer API; newer
        # releases expect the checkpoint path as `ckpt_path` in `trainer.fit`.
        resume_from_checkpoint=existing_checkpoint,
    )

    if job_id is not None:
        # Under SLURM: disable the progress bar and let Lightning requeue the job on SIGHUP.
        trainer_params["enable_progress_bar"] = False
        trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
        trainer_params["strategy"] = args.accelerator.strategy

    return pl.Trainer(**trainer_params)
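

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): load the run parameters,
    # build the trainer, and fit a model. It assumes `RealFakeParams` is a pydantic
    # model (suggested by `args.json()` above); the Lightning module class is
    # repository-specific, so it appears only as a hypothetical name in comments.
    args = RealFakeParams.parse_file("params.json")  # assumed pydantic-style loading
    trainer = prepare_trainer(args)
    # model = RealFakeClassifier(args)  # hypothetical module class; replace with the real one
    # trainer.fit(model)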