File size: 2,454 Bytes
ea847ad
 
 
 
 
 
 
 
 
12babad
ea847ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12babad
ea847ad
 
 
 
 
 
 
 
 
 
12babad
ea847ad
 
 
 
 
12babad
c6b22f5
ea847ad
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import annotations
import os
import signal

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment

from realfake.callbacks import ConsoleLogger, FeatureExtractorFreezeUnfreeze
from realfake.models import RealFakeParams
from realfake.utils import get_checkpoints_dir, find_latest_checkpoint


def get_existing_checkpoint(job_id: str | None = None) -> tuple:
    """Resolve the checkpoints directory and locate any saved checkpoint in it.

    When *job_id* is given (e.g. a SLURM job id), a stable per-job directory
    is reused so a requeued job can pick up its own checkpoints; otherwise a
    fresh timestamped directory is used.

    Returns a ``(checkpoints_dir, existing_checkpoint)`` pair; the second
    item is ``None`` when no checkpoint has been written yet.
    """
    if job_id is None:
        target_dir = get_checkpoints_dir(timestamp=True)
    else:
        target_dir = get_checkpoints_dir(timestamp=False)/job_id
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir, find_latest_checkpoint(target_dir)


def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
    """Build a ``pl.Trainer`` configured from *args*.

    Running under SLURM (detected via the ``SLURM_JOB_ID`` environment
    variable) reuses a per-job checkpoint directory, enables the SLURM
    requeue plugin, sets the distributed strategy, and disables the
    progress bar. Locally, a fresh timestamped checkpoint directory is used.
    """
    job_id = os.environ.get("SLURM_JOB_ID")
    on_slurm = job_id is not None
    checkpoints_dir, existing_checkpoint = get_existing_checkpoint(job_id)

    if not on_slurm:
        print("SLURM job id is not found, running locally.")
    if existing_checkpoint is None:
        print("No existing checkpoint found, starting from scratch.")

    accel = args.accelerator
    if accel.override_float32_matmul:
        torch.set_float32_matmul_precision(accel.float32_matmul)

    # Snapshot the run parameters next to the checkpoints for reproducibility.
    with (checkpoints_dir/"params.json").open("w") as fp:
        fp.write(args.json())

    # Keep the best model by validation accuracy, plus a rolling "last".
    checkpoint_cb = ModelCheckpoint(
        monitor="val_acc",
        mode="max",
        save_last=True,
        save_top_k=1,
        dirpath=checkpoints_dir,
        filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
    )

    # NOTE(review): `resume_from_checkpoint` is deprecated in recent
    # pytorch_lightning releases (use `trainer.fit(..., ckpt_path=...)`);
    # kept here to match the installed PL version — confirm before upgrading.
    trainer_params = {
        "accelerator": accel.name,
        "devices": accel.devices,
        "precision": accel.precision,
        "max_epochs": args.epochs,
        "num_nodes": 1,
        "num_sanity_val_steps": 0,
        "enable_progress_bar": args.progress_bar,
        "callbacks": [
            ConsoleLogger(),
            checkpoint_cb,
            FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=args.freeze_epochs),
        ],
        "resume_from_checkpoint": existing_checkpoint,
    }

    if on_slurm:
        trainer_params["enable_progress_bar"] = False
        trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
        trainer_params["strategy"] = accel.strategy

    return pl.Trainer(**trainer_params)