import sys
import argparse
import os
import time
import logging
from datetime import datetime


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument("--gpu", default="0", help="GPU(s) to be used")
    parser.add_argument(
        "--resume", default=None, help="path to the weights to be resumed"
    )
    parser.add_argument(
        "--resume_weights_only",
        action="store_true",
        help="specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--predict", action="store_true")
    # group.add_argument('--export', action='store_true')  # TODO: a separate export action

    parser.add_argument("--exp_dir", default="./exp")
    parser.add_argument("--runs_dir", default="./runs")
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )

    # unrecognized arguments are kept and later passed to the config loader
    # as command-line overrides
    args, extras = parser.parse_known_args()

    # set CUDA_VISIBLE_DEVICES before importing pytorch-lightning so that the
    # GPU selection takes effect
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    n_gpus = len(args.gpu.split(","))

    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
    )
    from utils.misc import load_config

    # parse YAML config to OmegaConf
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)

    # derive per-trial output directories unless they are given explicitly in the config
    config.trial_name = config.get("trial_name") or (
        config.tag + datetime.now().strftime("@%Y%m%d-%H%M%S")
    )
    config.exp_dir = config.get("exp_dir") or os.path.join(args.exp_dir, config.name)
    config.save_dir = config.get("save_dir") or os.path.join(
        config.exp_dir, config.trial_name, "save"
    )
    config.ckpt_dir = config.get("ckpt_dir") or os.path.join(
        config.exp_dir, config.trial_name, "ckpt"
    )
    config.code_dir = config.get("code_dir") or os.path.join(
        config.exp_dir, config.trial_name, "code"
    )
    config.config_dir = config.get("config_dir") or os.path.join(
        config.exp_dir, config.trial_name, "config"
    )

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    # fall back to a time-based seed when the config does not provide one
    if "seed" not in config:
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    # instantiate the datamodule and the system from their registries;
    # with --resume_weights_only, only the model weights are restored here
    dm = datasets.make(config.dataset.name, config.dataset)
    system = systems.make(
        config.system.name,
        config,
        load_from_checkpoint=None if not args.resume_weights_only else args.resume,
    )

    # checkpointing, LR monitoring, config snapshots and the progress bar
    # are only needed when training
    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint),
            LearningRateMonitor(logging_interval="step"),
            # CodeSnapshotCallback(
            #     config.code_dir, use_version=False
            # ),
            ConfigSnapshotCallback(config, config.config_dir, use_version=False),
            CustomProgressBar(refresh_rate=1),
        ]

    loggers = []
    if args.train:
        loggers += [
            TensorBoardLogger(
                args.runs_dir, name=config.name, version=config.trial_name
            ),
            CSVLogger(config.exp_dir, name=config.trial_name, version="csv_logs"),
        ]

    if sys.platform == "win32":
        # multi-GPU training is not supported on Windows; fall back to DP on a single GPU
        strategy = "dp"
        assert n_gpus == 1
    else:
        strategy = "ddp_find_unused_parameters_false"
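
    # NOTE: the explicit kwargs passed to Trainer below are merged with
    # **config.trainer, so config.trainer must not repeat any of the keys set
    # here (devices, accelerator, callbacks, logger, strategy), otherwise
    # Trainer() raises a "got multiple values for keyword argument" TypeError.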
accelerator="gpu", callbacks=callbacks, logger=loggers, strategy=strategy, **config.trainer ) if args.train: if args.resume and not args.resume_weights_only: # FIXME: different behavior in pytorch-lighting>1.9 ? trainer.fit(system, datamodule=dm, ckpt_path=args.resume) else: trainer.fit(system, datamodule=dm) trainer.test(system, datamodule=dm) elif args.validate: trainer.validate(system, datamodule=dm, ckpt_path=args.resume) elif args.test: trainer.test(system, datamodule=dm, ckpt_path=args.resume) elif args.predict: trainer.predict(system, datamodule=dm, ckpt_path=args.resume) if __name__ == "__main__": main()