Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,977 Bytes
b5ce381 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl
import os
from omegaconf import OmegaConf
from pytorch_lightning.utilities import rank_zero_only
# Switch for multi-node workarounds consulted by SetupCallback.on_fit_start:
# when True, the rank-0 branch sleeps 5 seconds before writing the project
# config, and the non-zero-rank relocation of the log directory is skipped.
MULTINODE_HACKS = True
class SetupCallback(Callback):
    """Callback that prepares run directories, dumps configs and saves a
    checkpoint when an exception is reported.

    On fit start (global rank 0) it creates the log, checkpoint and config
    directories, prints both configs and writes them as YAML files into
    ``cfgdir``. On exception (global rank 0, ``debug`` falsy) it saves a
    checkpoint into ``ckptdir``.
    """

    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        """Store the run settings on the instance; nothing else happens here.

        Args:
            resume: Truthy value disables the log-directory relocation on
                non-zero ranks in ``on_fit_start``.
            now: Value used as the prefix of the saved YAML file names.
            logdir: Log directory of the run.
            ckptdir: Directory checkpoints are written to.
            cfgdir: Directory the YAML configs are written to.
            config: Project config, printed and saved via OmegaConf.
            lightning_config: Lightning config; checked for a ``"callbacks"``
                entry and saved under the top-level key ``"lightning"``.
            debug: Truthy value disables checkpointing in ``on_exception``.
            ckpt_name: Optional checkpoint file name; ``"last.ckpt"`` is used
                when it is ``None``.
        """
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config
        self.debug = debug
        self.ckpt_name = ckpt_name

    @rank_zero_only
    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        """Report the exception and, outside debug mode, save a checkpoint.

        The checkpoint is written to ``ckptdir`` under ``ckpt_name`` (or
        ``"last.ckpt"`` when no name was given), and only when
        ``trainer.global_rank`` is 0.
        """
        print(f"Exception occurred: {exception}")

        # Nothing to save in debug mode or on any rank other than 0.
        if self.debug or trainer.global_rank != 0:
            return

        print("Summoning checkpoint.")
        filename = "last.ckpt" if self.ckpt_name is None else self.ckpt_name
        trainer.save_checkpoint(os.path.join(self.ckptdir, filename))

    @rank_zero_only
    def on_fit_start(self, trainer, pl_module):
        """Create the run directories and persist both configs (rank 0).

        NOTE(review): this hook is wrapped in ``rank_zero_only`` and also
        branches on ``trainer.global_rank``; the non-zero-rank branch below
        presumably never executes under the decorator -- confirm intent.
        """
        if trainer.global_rank != 0:
            # ModelCheckpoint callback created log directory --- remove it
            if MULTINODE_HACKS or self.resume or not os.path.exists(self.logdir):
                return
            parent, run_name = os.path.split(self.logdir)
            target = os.path.join(parent, "child_runs", run_name)
            os.makedirs(os.path.split(target)[0], exist_ok=True)
            try:
                os.rename(self.logdir, target)
            except FileNotFoundError:
                # Best effort: the directory may already be gone.
                pass
            return

        # Create logdirs and save configs
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        if (
            "callbacks" in self.lightning_config
            and "metrics_over_trainsteps_checkpoint"
            in self.lightning_config["callbacks"]
        ):
            trainstep_dir = os.path.join(self.ckptdir, "trainstep_checkpoints")
            os.makedirs(trainstep_dir, exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        if MULTINODE_HACKS:
            # Fixed 5 second pause before the first config file is written.
            import time

            time.sleep(5)
        project_yaml = os.path.join(self.cfgdir, f"{self.now}-project.yaml")
        OmegaConf.save(self.config, project_yaml)

        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        lightning_yaml = os.path.join(self.cfgdir, f"{self.now}-lightning.yaml")
        OmegaConf.save(
            OmegaConf.create({"lightning": self.lightning_config}),
            lightning_yaml,
        )
|