# tooncrafter / main / utils_train.py
# Uploaded by multimodalart ("Upload folder using huggingface_hub"), commit 0366b8b (verified).
# Hugging Face file-view metadata: raw | history | blame · No virus · 7.14 kB
import os, re
from omegaconf import OmegaConf
import logging
# Module-level logger used by the helpers in this file; set_logger() attaches handlers to it by default.
mainlogger = logging.getLogger(name='mainlogger')
import torch
from collections import OrderedDict
def init_workspace(name, logdir, model_config, lightning_config, rank=0):
    """Create the run directory tree and, on rank 0, persist the configs.

    Layout: ``<logdir>/<name>/{checkpoints,configs,loginfo}``.

    Args:
        name: run name, used as the sub-directory of ``logdir``.
        logdir: root log directory.
        model_config: model config written to ``configs/model.yaml`` (rank 0 only).
        lightning_config: Lightning config written to ``configs/lightning.yaml``
            under a top-level ``lightning`` key (rank 0 only).
        rank: process rank; only rank 0 writes config files.

    Returns:
        Tuple ``(workdir, ckptdir, cfgdir, loginfo)`` of directory paths.
    """
    workdir = os.path.join(logdir, name)
    ckptdir, cfgdir, loginfo = (
        os.path.join(workdir, sub) for sub in ("checkpoints", "configs", "loginfo")
    )
    paths = (workdir, ckptdir, cfgdir, loginfo)

    # Every rank creates the directories so that a slow rank 0 cannot cause
    # "missing directory" errors on the other ranks.
    for path in paths:
        os.makedirs(path, exist_ok=True)

    if rank != 0:
        return paths

    wants_trainstep_ckpts = (
        "callbacks" in lightning_config
        and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks
    )
    if wants_trainstep_ckpts:
        os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'), exist_ok=True)

    OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
    OmegaConf.save(OmegaConf.create({"lightning": lightning_config}), os.path.join(cfgdir, "lightning.yaml"))
    return paths
def check_config_attribute(config, name):
    """Return ``config.<name>`` if ``name`` is a key of ``config``, else ``None``.

    Membership is tested with ``in`` and the value is fetched with ``getattr``,
    so ``config`` must support both (as attribute-style config objects do).
    """
    if name not in config:
        return None
    return getattr(config, name)
def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
    """Build the trainer callback configuration.

    Built-in defaults (an epoch-level ``ModelCheckpoint``, ``ImageLogger``,
    ``LearningRateMonitor`` and ``CUDACallback``) are merged with
    ``lightning_config.callbacks``. User values win on key clashes because they
    are the second argument to ``OmegaConf.merge``.

    Args:
        lightning_config: Lightning section of the experiment config. Its
            optional ``callbacks`` entry overrides or extends the defaults.
        config: full experiment config. If ``config.model.params.monitor`` is
            set, the checkpoint callback switches to metric-based top-3 saving.
        logdir: ``save_dir`` for the image logger callback.
        ckptdir: ``dirpath`` for the checkpoint callbacks.
        logger: not used by this function; kept for call-site compatibility.

    Returns:
        The merged callback config (callback name -> ``{"target", "params"}``).
    """
    # Resolve the user's callbacks once. A config without a ``callbacks``
    # section must not crash (init_workspace guards the same way).
    if "callbacks" in lightning_config:
        user_callbacks_cfg = lightning_config.callbacks
    else:
        user_callbacks_cfg = OmegaConf.create()

    default_callbacks_cfg = {
        "model_checkpoint": {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch}",
                "verbose": True,
                "save_last": False,
            }
        },
        "batch_logger": {
            "target": "callbacks.ImageLogger",
            "params": {
                "save_dir": logdir,
                "batch_frequency": 1000,
                "max_images": 4,
                "clamp": True,
            }
        },
        "learning_rate_logger": {
            "target": "pytorch_lightning.callbacks.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                "log_momentum": False
            }
        },
        "cuda_callback": {
            "target": "callbacks.CUDACallback"
        },
    }

    ## optional setting for saving checkpoints
    monitor_metric = check_config_attribute(config.model.params, "monitor")
    if monitor_metric is not None:
        mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
        # Keep the 3 checkpoints with the lowest value of the monitored metric.
        default_callbacks_cfg["model_checkpoint"]["params"].update(
            monitor=monitor_metric,
            save_top_k=3,
            mode="min",
        )

    if 'metrics_over_trainsteps_checkpoint' in user_callbacks_cfg:
        mainlogger.info('Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
        default_callbacks_cfg['metrics_over_trainsteps_checkpoint'] = {
            "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
            'params': {
                "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                "filename": "{epoch}-{step}",
                "verbose": True,
                # -1 keeps every saved checkpoint (nothing is deleted).
                'save_top_k': -1,
                'every_n_train_steps': 10000,
                'save_weights_only': True
            }
        }

    return OmegaConf.merge(default_callbacks_cfg, user_callbacks_cfg)
def get_trainer_logger(lightning_config, logdir, on_debug):
    """Build the trainer logger config.

    The TensorBoard logger config is the default; ``lightning_config.logger``
    (if present) is merged over it, with user values winning. As a side effect,
    ``<logdir>/tensorboard`` is created.

    Args:
        lightning_config: Lightning section of the experiment config.
        logdir: ``save_dir`` for the logger.
        on_debug: not used by this function.

    Returns:
        The merged logger config (``{"target", "params"}``).
    """
    logger_presets = {
        "tensorboard": {
            "target": "pytorch_lightning.loggers.TensorBoardLogger",
            "params": {
                "save_dir": logdir,
                "name": "tensorboard",
            }
        },
        # CSV alternative; defined here but not selected below.
        "testtube": {
            "target": "pytorch_lightning.loggers.CSVLogger",
            "params": {
                "name": "testtube",
                "save_dir": logdir,
            }
        },
    }
    os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)

    user_logger_cfg = lightning_config.logger if "logger" in lightning_config else OmegaConf.create()
    return OmegaConf.merge(logger_presets["tensorboard"], user_logger_cfg)
def get_trainer_strategy(lightning_config):
    """Return the trainer strategy config.

    A user-supplied ``lightning_config.strategy`` is returned unchanged: it is
    not merged with the default, so a plain string is passed through as-is.
    Otherwise a config targeting ``DDPShardedStrategy`` is returned.
    """
    if "strategy" in lightning_config:
        return lightning_config.strategy

    # NOTE(review): DDPShardedStrategy is not available in newer Lightning
    # releases; confirm the pinned pytorch_lightning version provides it.
    fallback_strategy = {"target": "pytorch_lightning.strategies.DDPShardedStrategy"}
    return OmegaConf.merge(fallback_strategy, OmegaConf.create())
def load_checkpoints(model, model_cfg):
    """Optionally initialise ``model`` from ``model_cfg.pretrained_checkpoint``.

    Three checkpoint layouts are recognised:

    * a mapping with a ``'state_dict'`` entry, which is loaded as-is;
    * a mapping with a ``'module'`` entry (DeepSpeed), whose keys have their
      first 16 characters stripped before loading;
    * anything else, treated as a plain state dict.

    All loads use strict key matching. A mismatch raises the original error;
    it is no longer swallowed and retried with the raw checkpoint.

    Args:
        model: module to load weights into (modified in place).
        model_cfg: config; only ``pretrained_checkpoint`` is read. A missing or
            falsy value means training starts from scratch.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: if the configured checkpoint path does not exist.
        RuntimeError: from ``load_state_dict`` on missing or unexpected keys.
    """
    if not check_config_attribute(model_cfg, "pretrained_checkpoint"):
        mainlogger.info(">>> Start training from scratch")
        return model

    pretrained_ckpt = model_cfg.pretrained_checkpoint
    assert os.path.exists(pretrained_ckpt), "Error: Pre-trained checkpoint NOT found at:%s"%pretrained_ckpt
    mainlogger.info(">>> Load weights from pretrained checkpoint")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(pretrained_ckpt, map_location="cpu")

    if 'state_dict' in pl_sd:
        state_dict = pl_sd["state_dict"]
    elif 'module' in pl_sd:
        # deepspeed: presumably strips a 16-character wrapper prefix such as
        # '_forward_module.' from every key -- TODO confirm against the saver.
        state_dict = OrderedDict((key[16:], value) for key, value in pl_sd['module'].items())
    else:
        # Plain state dict; this was the only case the old bare-except fallback could load.
        state_dict = pl_sd

    model.load_state_dict(state_dict, strict=True)
    mainlogger.info(">>> Loaded weights from pretrained checkpoint: %s"%pretrained_ckpt)
    return model
def set_logger(logfile, name='mainlogger'):
    """Attach a file handler and a console handler to the logger ``name``.

    The logger level is set to INFO. The file handler truncates ``logfile``
    (mode ``'w'``) and writes timestamped records. The console handler writes
    bare messages. Each call adds a new pair of handlers.

    Args:
        logfile: path of the log file to (re)create.
        name: logger name; defaults to ``'mainlogger'``.

    Returns:
        The configured ``logging.Logger``.
    """
    target = logging.getLogger(name)
    target.setLevel(logging.INFO)

    file_handler = logging.FileHandler(logfile, mode='w')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(logging.Formatter("%(message)s"))

    # File handler first, then console, matching the original registration order.
    for handler in (file_handler, console_handler):
        target.addHandler(handler)
    return target