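"""Hydra entry point for training and evaluating DiffGeolocalizer.

Example invocations (the script filename is an assumption; ``mode`` is the
config key consumed in ``main`` below):

    python main.py mode=train
    python main.py mode=eval
"""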
import os
from os.path import isfile, join
from shutil import copyfile

import hydra
import torch
import wandb
from hydra.utils import instantiate
from lightning_fabric.utilities.rank_zero import _get_rank
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import LearningRateMonitor

from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch
from models.module import DiffGeolocalizer

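# "high" lets PyTorch use TensorFloat32 (reduced-precision) matmul kernels
# on GPUs that support them, trading a little precision for speed.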
torch.set_float32_matmul_precision("high") # TODO do we need that?
# Registering the "eval" resolver allows for advanced config
# interpolation with arithmetic operations in Hydra:
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html
OmegaConf.register_new_resolver("eval", eval)
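# For example, a config entry like (key names here are hypothetical)
#   warmup_steps: ${eval:'${total_steps} // 10'}
# is evaluated as Python arithmetic when the config is resolved.
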
def wandb_init(cfg):
    """Reuse the wandb run id stored in the checkpoint directory if one
    exists, so resumed runs keep logging to the same wandb run; otherwise
    generate a fresh id and let rank 0 persist it."""
    directory = cfg.checkpoints.dirpath
    if isfile(join(directory, "wandb_id.txt")) and cfg.logger_suffix == "":
        with open(join(directory, "wandb_id.txt"), "r") as f:
            wandb_id = f.readline()
    else:
        rank = _get_rank()
        wandb_id = wandb.util.generate_id()
        print(f"Generated wandb id: {wandb_id}")
        # Only rank 0 (or a single-process run, where rank is None) writes
        # the id, to avoid concurrent writes under DDP.
        if rank == 0 or rank is None:
            with open(join(directory, "wandb_id.txt"), "w") as f:
                f.write(str(wandb_id))
    return wandb_id

def load_model(cfg, dict_config, wandb_id, callbacks):
    """Build the trainer and model, resuming from last.ckpt if present."""
    directory = cfg.checkpoints.dirpath
    if isfile(join(directory, "last.ckpt")):
        ckpt_path = join(directory, "last.ckpt")
        logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
        model = DiffGeolocalizer.load_from_checkpoint(ckpt_path, cfg=cfg.model)
        print(f"Loading from checkpoint ... {ckpt_path}")
    else:
        ckpt_path = None
        logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
        log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]}
        logger._wandb_init.update({"config": log_dict})
        model = DiffGeolocalizer(cfg.model)
    trainer, strategy = cfg.trainer, cfg.trainer.strategy
    # from pytorch_lightning.profilers import PyTorchProfiler
    trainer = instantiate(
        trainer,
        strategy=strategy,
        logger=logger,
        callbacks=callbacks,
        # profiler=PyTorchProfiler(
        #     dirpath="logs",
        #     schedule=torch.profiler.schedule(wait=1, warmup=3, active=3, repeat=1),
        #     on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"),
        #     record_shapes=True,
        #     with_stack=True,
        #     with_flops=True,
        #     with_modules=True,
        # ),
    )
    return trainer, model, ckpt_path

def project_init(cfg):
    print("Working directory set to {}".format(os.getcwd()))
    directory = cfg.checkpoints.dirpath
    os.makedirs(directory, exist_ok=True)
    # Keep a copy of the resolved Hydra config next to the checkpoints.
    copyfile(".hydra/config.yaml", join(directory, "config.yaml"))

def callback_init(cfg):
    checkpoint_callback = instantiate(cfg.checkpoints)
    progress_bar = instantiate(cfg.progress_bar)
    lr_monitor = LearningRateMonitor()
    ema_callback = EMACallback(
        "network",
        "ema_network",
        decay=cfg.model.ema_decay,
        start_ema_step=cfg.model.start_ema_step,
        init_ema_random=False,
    )
    fix_nan_callback = FixNANinGrad(
        monitor=["train/loss"],
    )
    increase_data_epoch_callback = IncreaseDataEpoch()
    callbacks = [
        checkpoint_callback,
        progress_bar,
        lr_monitor,
        ema_callback,
        fix_nan_callback,
        increase_data_epoch_callback,
    ]
    return callbacks

def init_datamodule(cfg):
    datamodule = instantiate(cfg.datamodule)
    return datamodule

def hydra_boilerplate(cfg):
    dict_config = OmegaConf.to_container(cfg, resolve=True)
    callbacks = callback_init(cfg)
    datamodule = init_datamodule(cfg)
    project_init(cfg)
    wandb_id = wandb_init(cfg)
    trainer, model, ckpt_path = load_model(cfg, dict_config, wandb_id, callbacks)
    return trainer, model, datamodule, ckpt_path

@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
    if "stage" in cfg and cfg.stage == "debug":
        import lovely_tensors as lt

        lt.monkey_patch()
    trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg)
    model.datamodule = datamodule
    # model = torch.compile(model)
    if cfg.mode == "train":
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    elif cfg.mode == "eval":
        trainer.test(model, datamodule=datamodule)
    elif cfg.mode == "traineval":
        cfg.mode = "train"
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
        cfg.mode = "test"
        trainer.test(model, datamodule=datamodule)


if __name__ == "__main__":
    main()