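"""Config-driven entry point for training, validating, testing, or running
prediction with a PyTorch Lightning system.

Example invocation (the script and config file names below are illustrative):

    python launch.py --config configs/example.yaml --gpu 0 --train
"""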
import sys
import argparse
import os
import time
import logging
from datetime import datetime
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument("--gpu", default="0", help="GPU(s) to be used")
    parser.add_argument(
        "--resume", default=None, help="path to the weights to be resumed"
    )
    parser.add_argument(
        "--resume_weights_only",
        action="store_true",
        help="specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--predict", action="store_true")
    # group.add_argument('--export', action='store_true')  # TODO: a separate export action

    parser.add_argument("--exp_dir", default="./exp")
    parser.add_argument("--runs_dir", default="./runs")
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )

    args, extras = parser.parse_known_args()
    # set CUDA_VISIBLE_DEVICES before importing pytorch-lightning so that
    # device selection takes effect for all CUDA-dependent modules
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    n_gpus = len(args.gpu.split(","))

    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
    )
    from utils.misc import load_config
    # parse the YAML config into an OmegaConf object (extra CLI args are forwarded to load_config)
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)

    # derive the trial name and output directory layout; values already set
    # in the config take precedence over the defaults below
    config.trial_name = config.get("trial_name") or (
        config.tag + datetime.now().strftime("@%Y%m%d-%H%M%S")
    )
    config.exp_dir = config.get("exp_dir") or os.path.join(args.exp_dir, config.name)
    config.save_dir = config.get("save_dir") or os.path.join(
        config.exp_dir, config.trial_name, "save"
    )
    config.ckpt_dir = config.get("ckpt_dir") or os.path.join(
        config.exp_dir, config.trial_name, "ckpt"
    )
    config.code_dir = config.get("code_dir") or os.path.join(
        config.exp_dir, config.trial_name, "code"
    )
    config.config_dir = config.get("config_dir") or os.path.join(
        config.exp_dir, config.trial_name, "config"
    )
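
    # raise pytorch_lightning's logger verbosity when --verbose is given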
    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    # seed everything; fall back to a time-derived seed when the config does not set one
    if "seed" not in config:
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    # construct the datamodule and system by name from the config
    dm = datasets.make(config.dataset.name, config.dataset)
    system = systems.make(
        config.system.name,
        config,
        load_from_checkpoint=None if not args.resume_weights_only else args.resume,
    )
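
    # callbacks and loggers are only attached for training runs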
    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint),
            LearningRateMonitor(logging_interval="step"),
            # CodeSnapshotCallback(
            #     config.code_dir, use_version=False
            # ),
            ConfigSnapshotCallback(config, config.config_dir, use_version=False),
            CustomProgressBar(refresh_rate=1),
        ]

    loggers = []
    if args.train:
        loggers += [
            TensorBoardLogger(
                args.runs_dir, name=config.name, version=config.trial_name
            ),
            CSVLogger(config.exp_dir, name=config.trial_name, version="csv_logs"),
        ]
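
    # choose the parallelism strategy based on the platform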
    if sys.platform == "win32":
        # multi-GPU (DDP) is not supported on Windows, so fall back to DP
        # and require a single GPU
        strategy = "dp"
        assert n_gpus == 1
    else:
        strategy = "ddp_find_unused_parameters_false"

    trainer = Trainer(
        devices=n_gpus,
        accelerator="gpu",
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
        **config.trainer
    )
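
    # run the requested action; a training run is followed by a test pass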
    if args.train:
        if args.resume and not args.resume_weights_only:
            # FIXME: different behavior in pytorch-lightning>1.9 ?
            trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
        else:
            trainer.fit(system, datamodule=dm)
        trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)


if __name__ == "__main__":
    main()