"""Training entry point.

Parses command-line options, merges OmegaConf config files with dotlist
overrides, builds the model, data module, loggers, callbacks and a PyTorch
Lightning ``Trainer``, then runs ``fit`` and (optionally) ``test``.
"""

import argparse
import datetime
import glob
import inspect
import os
import sys
from inspect import Parameter
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only

from sgm.util import (
    exists,
    instantiate_from_config,
    isheatmap,
)

# Enables multi-node workarounds used below:
# - a short sleep before saving configs;
# - never relocating log directories created on non-zero ranks;
# - extra host diagnostics when a RuntimeError escapes training.
MULTINODE_HACKS = True


def default_trainer_args():
    """Return ``{name: default}`` for every ``Trainer.__init__`` parameter with a default.

    Used to expose each Trainer argument as a ``--<name>`` CLI option and to
    detect which of those options the user changed.
    """
    argspec = dict(inspect.signature(Trainer.__init__).parameters)
    argspec.pop("self")
    # Compare the parameter's *default* against the sentinel. Comparing the
    # Parameter object itself is always unequal, which would let parameters
    # without a default (and *args/**kwargs) through with ``Parameter.empty``.
    return {
        name: param.default
        for name, param in argspec.items()
        if param.default is not Parameter.empty
    }


def get_parser(**parser_kwargs):
    """Build the command-line parser.

    Besides the script's own options, every ``Trainer.__init__`` argument with
    a default is registered as ``--<name>`` so it can be overridden from the
    command line.
    """

    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "--no_date",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=True,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p", "--project", help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "--projectname",
        type=str,
        default="stablediffusion",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging dat shit",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    parser.add_argument(
        "--legacy_naming",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="name run based on config file name if true, else by whole path",
    )
    parser.add_argument(
        "--enable_tf32",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
    )
    parser.add_argument(
        "--startup",
        type=str,
        default=None,
        help="Startuptime from distributed script",
    )
    parser.add_argument(
        "--wandb",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,  # TODO: later default to True
        help="log to wandb",
    )
    parser.add_argument(
        "--no_base_name",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,  # TODO: later default to True
        help="if True, do not derive the run name from the base config file name",
    )

    default_args = default_trainer_args()
    # Depending on the installed Lightning version, ``Trainer.__init__`` may
    # or may not accept ``resume_from_checkpoint``. Register our own option
    # only when the Trainer does not already provide one. Otherwise argparse
    # raises a conflicting-option error, and without any option the script
    # later fails reading ``opt.resume_from_checkpoint``.
    if "resume_from_checkpoint" not in default_args:
        parser.add_argument(
            "--resume_from_checkpoint",
            type=str,
            default=None,
            help="single checkpoint file to resume from",
        )
    for key in default_args:
        parser.add_argument("--" + key, default=default_args[key])
    return parser


def get_checkpoint_name(logdir):
    """Find the checkpoint to resume from and name the emergency ("melk") checkpoint.

    Looks for ``<logdir>/checkpoints/last*.ckpt``. When several exist, the
    most recently modified one is chosen and its path is written to
    ``<logdir>/most_recent_ckpt.txt``.

    Returns:
        ``(ckpt, melk_ckpt_name)``.

        - ``ckpt`` is the chosen checkpoint path, or ``None`` when no
          ``last*.ckpt`` file exists.
        - ``melk_ckpt_name`` is the file name used for checkpoints saved on
          crash / SIGUSR1.
    """
    pattern = os.path.join(logdir, "checkpoints", "last**.ckpt")
    ckpt = natsorted(glob.glob(pattern))
    print('available "last" checkpoints:')
    print(ckpt)
    if len(ckpt) > 1:
        print("got most recent checkpoint")
        ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
        print(f"Most recent ckpt is {ckpt}")
        with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
            f.write(ckpt + "\n")
        try:
            # "last-v3.ckpt" -> 3; a plain "last.ckpt" does not parse.
            ckpt_version = int(
                os.path.basename(ckpt).split("-v")[-1].split(".")[0]
            )
        except ValueError as e:
            print("version confusion but not bad")
            print(e)
            ckpt_version = 1
        # version = last_version + 1
    elif ckpt:
        # in this case, we only have one "last.ckpt"
        ckpt = ckpt[0]
        ckpt_version = 1
    else:
        # No "last" checkpoint at all; let the caller decide what to do.
        ckpt = None
        ckpt_version = 1

    melk_ckpt_name = f"last-v{ckpt_version}.ckpt"
    print(f"Current melk ckpt name: {melk_ckpt_name}")
    return ckpt, melk_ckpt_name


def _count_devices(devices):
    """Best-effort number of devices described by a Trainer ``devices`` value.

    Accepted forms:
    - int: a device count (non-positive means "all");
    - string with a comma: a list of device ids (e.g. ``"0,1,"``);
    - any other integer string: a count;
    - sequence: a list of device ids.

    ``"auto"``, ``None`` or anything unparsable falls back to the number of
    visible CUDA devices (at least 1).
    """
    fallback = max(torch.cuda.device_count(), 1)
    if devices is None:
        return fallback
    if isinstance(devices, int):
        return devices if devices > 0 else fallback
    if isinstance(devices, str):
        spec = devices.strip()
        if "," in spec:
            return len([d for d in spec.split(",") if d.strip()])
        try:
            count = int(spec)
        except ValueError:  # e.g. "auto"
            return fallback
        return count if count > 0 else fallback
    try:
        return len(devices)
    except TypeError:
        return fallback


class SetupCallback(Callback):
    """Creates the run directories, dumps configs, and saves a checkpoint on exceptions."""

    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config
        self.debug = debug
        self.ckpt_name = ckpt_name

    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        # Outside debug mode, rank 0 saves a checkpoint before the run dies.
        if not self.debug and trainer.global_rank == 0:
            print("Summoning checkpoint.")
            if self.ckpt_name is None:
                ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            else:
                ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
            trainer.save_checkpoint(ckpt_path)

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if (
                    "metrics_over_trainsteps_checkpoint"
                    in self.lightning_config["callbacks"]
                ):
                    os.makedirs(
                        os.path.join(self.ckptdir, "trainstep_checkpoints"),
                        exist_ok=True,
                    )
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            if MULTINODE_HACKS:
                import time

                time.sleep(5)
            OmegaConf.save(
                self.config,
                os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
            )

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({"lightning": self.lightning_config}),
                os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
            )

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    """Periodically calls ``pl_module.log_images`` and saves the results.

    Images are written as PNGs under ``<save_dir>/images/<split>``. When the
    logger is a ``WandbLogger``, each image is also sent to wandb.
    """

    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
        log_before_first_step=False,
        enable_autocast=True,
    ):
        super().__init__()
        self.enable_autocast = enable_autocast
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        # Extra early logging points at powers of two up to batch_freq.
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step
        self.log_before_first_step = log_before_first_step

    @rank_zero_only
    def log_local(
        self,
        save_dir,
        split,
        images,
        global_step,
        current_epoch,
        batch_idx,
        pl_module: Union[None, pl.LightningModule] = None,
    ):
        """Write each entry of ``images`` to disk (and to wandb when ``pl_module`` is given)."""
        root = os.path.join(save_dir, "images", split)
        for k in images:
            if isheatmap(images[k]):
                fig, ax = plt.subplots()
                ax = ax.matshow(
                    images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
                )
                plt.colorbar(ax)
                plt.axis("off")

                filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                    k, global_step, current_epoch, batch_idx
                )
                os.makedirs(root, exist_ok=True)
                path = os.path.join(root, filename)
                plt.savefig(path)
                plt.close()
                # TODO: support wandb
            else:
                grid = torchvision.utils.make_grid(images[k], nrow=4)
                if self.rescale:
                    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
                grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
                grid = grid.numpy()
                grid = (grid * 255).astype(np.uint8)
                filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                    k, global_step, current_epoch, batch_idx
                )
                path = os.path.join(root, filename)
                os.makedirs(os.path.split(path)[0], exist_ok=True)
                img = Image.fromarray(grid)
                img.save(path)
                if exists(pl_module):
                    assert isinstance(
                        pl_module.logger, WandbLogger
                    ), "logger_log_image only supports WandbLogger currently"
                    pl_module.logger.log_image(
                        key=f"{split}/{k}",
                        images=[
                            img,
                        ],
                        step=pl_module.global_step,
                    )

    @rank_zero_only
    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """Generate images with ``pl_module.log_images`` and save them if this step is due."""
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # check_frequency has a side effect (consumes log_steps), so keep it first.
        if (
            self.check_frequency(check_idx)
            and hasattr(pl_module, "log_images")
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            gpu_autocast_kwargs = {
                "enabled": self.enable_autocast,  # torch.is_autocast_enabled(),
                "dtype": torch.get_autocast_gpu_dtype(),
                "cache_enabled": torch.is_autocast_cache_enabled(),
            }
            with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
                images = pl_module.log_images(
                    batch, split=split, **self.log_images_kwargs
                )

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                if not isheatmap(images[k]):
                    images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().float().cpu()
                    if self.clamp and not isheatmap(images[k]):
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
                pl_module=pl_module
                if isinstance(pl_module.logger, WandbLogger)
                else None,
            )

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        """Return True if images should be logged at ``check_idx``.

        Logging happens every ``batch_freq`` steps and at each entry of
        ``log_steps``. Index 0 is skipped unless ``log_first_step`` is set.
        Each positive check consumes the first pending entry of ``log_steps``.
        """
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
            check_idx > 0 or self.log_first_step
        ):
            # An empty list just means all warm-up steps were consumed.
            if self.log_steps:
                self.log_steps.pop(0)
            return True
        return False

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.log_before_first_step and pl_module.global_step == 0:
            print(f"{self.__class__.__name__}: logging before training")
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
    ):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if (
            getattr(pl_module, "calibrate_grad_norm", False)
            and batch_idx % 25 == 0
            and batch_idx > 0
        ):
            # ``log_gradients`` is not defined on this class; call it only when a
            # subclass provides it instead of failing with AttributeError.
            if hasattr(self, "log_gradients"):
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


@rank_zero_only
def init_wandb(save_dir, opt, config, group_name, name_str):
    """Point WANDB_DIR at ``save_dir`` and start a wandb run (offline in debug mode)."""
    print(f"setting WANDB_DIR to {save_dir}")
    os.makedirs(save_dir, exist_ok=True)

    os.environ["WANDB_DIR"] = save_dir
    if opt.debug:
        wandb.init(project=opt.projectname, mode="offline", group=group_name)
    else:
        wandb.init(
            project=opt.projectname,
            config=config,
            settings=wandb.Settings(code_dir="./sgm"),
            group=group_name,
            name=name_str,
        )


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()

    opt, unknown = parser.parse_known_args()

    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot both be specified. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )

    melk_ckpt_name = None
    name = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # Assumes the file sits two levels below the log dir,
            # i.e. <logdir>/checkpoints/<file>.ckpt.
            paths = opt.resume.split("/")
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
            _, melk_ckpt_name = get_checkpoint_name(logdir)
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt, melk_ckpt_name = get_checkpoint_name(logdir)
            if ckpt is None:
                raise ValueError(
                    "Cannot find a 'last*.ckpt' checkpoint in {}".format(
                        os.path.join(logdir, "checkpoints")
                    )
                )

        print("#" * 100)
        print(f'Resuming from checkpoint "{ckpt}"')
        print("#" * 100)

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            if opt.no_base_name:
                name = ""
            else:
                if opt.legacy_naming:
                    cfg_fname = os.path.split(opt.base[0])[-1]
                    cfg_name = os.path.splitext(cfg_fname)[0]
                else:
                    cfg_dir = os.path.split(opt.base[0])[0]
                    assert "configs" in cfg_dir, cfg_dir
                    dir_parts = cfg_dir.split(os.sep)
                    # cut away the first one (we assert all configs are in "configs")
                    cfg_path = dir_parts[dir_parts.index("configs") + 1 :]
                    cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
                    cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
                name = "_" + cfg_name
        else:
            name = ""
        if not opt.no_date:
            nowname = now + name + opt.postfix
        else:
            nowname = name + opt.postfix
            if nowname.startswith("_"):
                nowname = nowname[1:]
        logdir = os.path.join(opt.logdir, nowname)
        print(f"LOGDIR: {logdir}")

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    # move before model init, in case a torch.compile(...) is called somewhere
    seed_everything(opt.seed, workers=True)
    if opt.enable_tf32:
        # pt_version = version.parse(torch.__version__)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print(f"Enabling TF32 for PyTorch {torch.__version__}")
    else:
        print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
        print(
            f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
        )
        print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")

    # Bound before the try so the except/finally handlers can tell whether the
    # Trainer was ever constructed (otherwise an early failure becomes a NameError).
    trainer = None
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        # default to gpu
        trainer_config["accelerator"] = "gpu"
        # CLI Trainer options that differ from the Trainer defaults win.
        standard_args = default_trainer_args()
        for k in standard_args:
            if getattr(opt, k) != standard_args[k]:
                trainer_config[k] = getattr(opt, k)

        ckpt_resume_path = opt.resume_from_checkpoint

        if "devices" not in trainer_config and trainer_config["accelerator"] != "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            # "devices" may be absent, in which case the Trainer's own default applies.
            gpuinfo = trainer_config.get("devices", "auto")
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    # "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                    "project": opt.projectname,
                    "log_model": False,
                    # "dir": logdir,
                },
            },
            "csv": {
                "target": "pytorch_lightning.loggers.CSVLogger",
                "params": {
                    "name": "testtube",  # hack for sbord fanatics
                    "save_dir": logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
        if opt.wandb:
            # TODO change once leaving "swiffer" config directory
            try:
                group_name = nowname.split(now)[-1].split("-")[1]
            except IndexError:
                group_name = nowname
            default_logger_cfg["params"]["group"] = group_name
            init_wandb(
                os.path.join(os.getcwd(), logdir),
                opt=opt,
                group_name=group_name,
                config=config,
                name_str=nowname,
            )
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - the model's ``monitor`` attribute, if present,
        # selects the metric used to keep the best checkpoints
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            },
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")

        # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
        # default to ddp if not further specified
        default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}

        if "strategy" in lightning_config:
            strategy_cfg = lightning_config.strategy
        else:
            strategy_cfg = OmegaConf.create()
            default_strategy_config["params"] = {
                "find_unused_parameters": False,
                # "static_graph": True,
                # "ddp_comm_hook": default.fp16_compress_hook  # TODO: experiment with this, also for DDPSharded
            }
        strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
        print(
            f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
        )
        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                    "debug": opt.debug,
                    "ckpt_name": melk_ckpt_name,
                },
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
            },
            "learning_rate_logger": {
                "target": "pytorch_lightning.callbacks.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                },
            },
        }
        if version.parse(pl.__version__) >= version.parse("1.4.0"):
            default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
            print(
                "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
            )
            default_metrics_over_trainsteps_ckpt_dict = {
                "metrics_over_trainsteps_checkpoint": {
                    "target": "pytorch_lightning.callbacks.ModelCheckpoint",
                    "params": {
                        "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
                        "filename": "{epoch:06}-{step:09}",
                        "verbose": True,
                        "save_top_k": -1,
                        "every_n_train_steps": 10000,
                        "save_weights_only": True,
                    },
                }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
            callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
        elif "ignore_keys_callback" in callbacks_cfg:
            del callbacks_cfg["ignore_keys_callback"]

        trainer_kwargs["callbacks"] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]
        if "plugins" not in trainer_kwargs:
            trainer_kwargs["plugins"] = list()

        # cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
        trainer_opt = vars(trainer_opt)
        trainer_kwargs = {
            key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
        }
        trainer = Trainer(**trainer_opt, **trainer_kwargs)

        trainer.logdir = logdir  ###

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        # data.setup()
        print("#### Data #####")
        try:
            for k in data.datasets:
                print(
                    f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
                )
        except Exception:
            print("datasets not yet initialized.")

        # configure learning rate
        if "batch_size" in config.data.params:
            bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        else:
            bs, base_lr = (
                config.data.params.train.loader.batch_size,
                config.model.base_learning_rate,
            )
        if not cpu:
            # Handles int, "0,1,"-style strings, lists and a missing/"auto" value.
            ngpu = _count_devices(lightning_config.trainer.get("devices", None))
        else:
            ngpu = 1
        if "accumulate_grad_batches" in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
                )
            )
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")

        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                if melk_ckpt_name is None:
                    ckpt_path = os.path.join(ckptdir, "last.ckpt")
                else:
                    ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        if opt.train:
            try:
                trainer.fit(model, data, ckpt_path=ckpt_resume_path)
            except Exception:
                if not opt.debug:
                    melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except RuntimeError as err:
        if MULTINODE_HACKS:
            import socket

            import requests

            device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
            hostname = socket.gethostname()
            ts = datetime.datetime.now(datetime.timezone.utc).strftime(
                "%Y-%m-%d %H:%M:%S"
            )
            # This looks like a cloud instance-metadata endpoint, which is not
            # reachable everywhere. Bound the request and never let a lookup
            # failure replace the original error.
            try:
                resp = requests.get(
                    "http://169.254.169.254/latest/meta-data/instance-id", timeout=2
                )
                instance_id = resp.text
            except requests.RequestException:
                instance_id = "?"
            print(
                f"ERROR at {ts} on {hostname}/{instance_id} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
                flush=True,
            )
        raise
    except Exception:
        if opt.debug and trainer is not None and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if (
            opt.debug
            and not opt.resume
            and trainer is not None
            and trainer.global_rank == 0
            and os.path.exists(logdir)
        ):
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)

        if opt.wandb:
            wandb.finish()
        # if trainer.global_rank == 0:
        #    print(trainer.profiler.summary())