| import os |
| import hydra |
| import wandb |
| from os.path import isfile, join |
| from shutil import copyfile |
|
|
| import torch |
|
|
| from omegaconf import OmegaConf |
| from hydra.core.hydra_config import HydraConfig |
| from hydra.utils import instantiate |
| from pytorch_lightning.callbacks import LearningRateMonitor |
| from lightning_fabric.utilities.rank_zero import _get_rank |
| from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch |
| from models.module import DiffGeolocalizer |
|
|
# Select the internal precision used for float32 matrix multiplications.
# Per the PyTorch documentation, "high" permits faster reduced-precision
# kernels (e.g. TF32) on hardware that supports them.
torch.set_float32_matmul_precision("high")

# Expose Python's builtin eval() to configs as the "eval" resolver, i.e.
# ${eval:...} in a config is evaluated as a Python expression.
# NOTE(review): eval runs arbitrary code, so only trusted configs and
# command-line overrides should ever be loaded by this script.
OmegaConf.register_new_resolver("eval", eval)
|
|
|
|
def wandb_init(cfg):
    """Return the wandb run id for this run, reusing a stored one if possible.

    The id is persisted in ``<cfg.checkpoints.dirpath>/wandb_id.txt``. When
    that file exists, holds a non-blank id, and ``cfg.logger_suffix`` is the
    empty string, the stored id is returned. Otherwise a fresh id is
    generated, printed, and written to the file by the rank-0 process (or by
    the only process when no rank is set).

    Args:
        cfg: Config exposing ``checkpoints.dirpath`` and ``logger_suffix``.

    Returns:
        The wandb run id as a string, without surrounding whitespace.
    """
    directory = cfg.checkpoints.dirpath
    id_file = join(directory, "wandb_id.txt")

    if isfile(id_file) and cfg.logger_suffix == "":
        with open(id_file, "r") as f:
            # Strip so a trailing newline (e.g. from a hand-edited file) does
            # not end up inside the run id.
            stored_id = f.readline().strip()
        if stored_id:
            return stored_id
        # A blank file carries no usable id: fall through and create one.

    rank = _get_rank()
    # NOTE(review): every process generates its own id here; only the one
    # written by rank 0 is persisted. Confirm whether ranks need to agree on
    # the id for a fresh run.
    wandb_id = wandb.util.generate_id()
    print(f"Generated wandb id: {wandb_id}")
    if rank == 0 or rank is None:
        with open(id_file, "w") as f:
            f.write(str(wandb_id))

    return wandb_id
|
|
|
|
def load_model(cfg, dict_config, wandb_id, callbacks):
    """Build the logger, model and trainer, resuming from ``last.ckpt`` if present.

    If ``<cfg.checkpoints.dirpath>/last.ckpt`` exists, the model is restored
    from it and its path is returned so the caller can resume training.
    Otherwise a fresh model is built from ``cfg.model`` and the resolved
    ``model`` and ``dataset`` config sections are attached to the logger's
    init arguments.

    Args:
        cfg: Config exposing ``checkpoints.dirpath``, ``logger``, ``model``
            and ``trainer`` (with ``trainer.strategy``).
        dict_config: Plain-container version of the config; its ``"model"``
            and ``"dataset"`` entries are logged for fresh runs.
        wandb_id: Run id passed to the logger as ``id``.
        callbacks: Callbacks passed to the trainer.

    Returns:
        Tuple ``(trainer, model, ckpt_path)`` where ``ckpt_path`` is the
        checkpoint path when resuming and ``None`` otherwise.
    """
    last_ckpt = join(cfg.checkpoints.dirpath, "last.ckpt")
    resuming = isfile(last_ckpt)

    # The logger is built the same way whether or not we resume.
    logger = instantiate(cfg.logger, id=wandb_id, resume="allow")

    if resuming:
        model = DiffGeolocalizer.load_from_checkpoint(last_ckpt, cfg=cfg.model)
        ckpt_path = last_ckpt
        print(f"Loading from checkpoint ... {ckpt_path}")
    else:
        ckpt_path = None
        log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]}
        # Relies on the logger's private `_wandb_init` dict
        # (presumably a Lightning WandbLogger - TODO confirm).
        logger._wandb_init.update({"config": log_dict})
        model = DiffGeolocalizer(cfg.model)

    trainer = instantiate(
        cfg.trainer,
        strategy=cfg.trainer.strategy,
        logger=logger,
        callbacks=callbacks,
    )
    return trainer, model, ckpt_path
|
|
|
|
def project_init(cfg):
    """Create the checkpoint directory and copy the Hydra config into it.

    Prints the current working directory, makes sure
    ``cfg.checkpoints.dirpath`` exists, then copies ``.hydra/config.yaml``
    (resolved against the current working directory) to ``config.yaml``
    inside that directory.
    """
    cwd = os.getcwd()
    print(f"Working directory set to {cwd}")

    checkpoint_dir = cfg.checkpoints.dirpath
    os.makedirs(checkpoint_dir, exist_ok=True)

    destination = join(checkpoint_dir, "config.yaml")
    copyfile(".hydra/config.yaml", destination)
|
|
|
|
def callback_init(cfg):
    """Build the list of callbacks handed to the trainer.

    The list contains, in order: the checkpoint callback and progress bar
    instantiated from ``cfg.checkpoints`` and ``cfg.progress_bar``, a
    learning-rate monitor, an ``EMACallback`` configured from
    ``cfg.model.ema_decay`` and ``cfg.model.start_ema_step``, a
    ``FixNANinGrad`` callback monitoring ``"train/loss"``, and an
    ``IncreaseDataEpoch`` callback.
    """
    return [
        instantiate(cfg.checkpoints),
        instantiate(cfg.progress_bar),
        LearningRateMonitor(),
        EMACallback(
            "network",
            "ema_network",
            decay=cfg.model.ema_decay,
            start_ema_step=cfg.model.start_ema_step,
            init_ema_random=False,
        ),
        FixNANinGrad(monitor=["train/loss"]),
        IncreaseDataEpoch(),
    ]
|
|
|
|
def init_datamodule(cfg):
    """Instantiate and return the datamodule described by ``cfg.datamodule``."""
    return instantiate(cfg.datamodule)
|
|
|
|
def hydra_boilerplate(cfg):
    """Run all setup steps and return what is needed to fit or test.

    Steps, in order: resolve the config to plain containers, build the
    callbacks, build the datamodule, prepare the checkpoint directory,
    obtain the wandb run id, then build the trainer and model.

    Returns:
        Tuple ``(trainer, model, datamodule, ckpt_path)``.
    """
    resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
    trainer_callbacks = callback_init(cfg)
    data = init_datamodule(cfg)

    project_init(cfg)
    run_id = wandb_init(cfg)

    trainer, model, resume_path = load_model(
        cfg, resolved_cfg, run_id, trainer_callbacks
    )
    return trainer, model, data, resume_path
|
|
|
|
@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
    """Entry point: set everything up, then train and/or test per ``cfg.mode``.

    Supported modes:
        * ``"train"``: fit, resuming from the checkpoint path if one was found.
        * ``"eval"``: run the test loop only.
        * ``"traineval"``: fit, then run the test loop.

    Raises:
        ValueError: If ``cfg.mode`` is not one of the supported modes. This is
            checked before any setup so a mistyped mode fails immediately
            instead of silently doing nothing.
    """
    if "stage" in cfg and cfg.stage == "debug":
        # Imported lazily so the package is only needed for debug runs.
        import lovely_tensors as lt

        lt.monkey_patch()

    mode = cfg.mode
    if mode not in ("train", "eval", "traineval"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of 'train', 'eval', 'traineval'."
        )

    trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg)
    model.datamodule = datamodule

    if mode == "train":
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    elif mode == "eval":
        trainer.test(model, datamodule=datamodule)
    else:  # "traineval"
        # cfg.mode is rewritten before each phase; presumably other components
        # read it while running - TODO confirm.
        cfg.mode = "train"
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
        cfg.mode = "test"
        trainer.test(model, datamodule=datamodule)
|
|
|
|
# Run only when executed as a script. main() is called without arguments
# because it is wrapped by @hydra.main, which supplies `cfg`.
if __name__ == "__main__":
    main()
|
|