# Provenance (Hugging Face Hub file view): uploaded by xzxzxiaoo
# ("Upload folder using huggingface_hub"), commit 94e8ee8 (verified),
# 4.81 kB, malware scan: no virus detected.
import sys
import argparse
import os
import time
import logging
from datetime import datetime
def _build_parser():
    """Build the command-line parser for this launcher.

    Returns:
        argparse.ArgumentParser: parser with config / GPU / resume / output-dir
        options and a required, mutually exclusive run mode
        (``--train`` / ``--validate`` / ``--test`` / ``--predict``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, help='path to config file')
    parser.add_argument('--gpu', default='0', help='GPU(s) to be used')
    parser.add_argument('--resume', default=None, help='path to the weights to be resumed')
    parser.add_argument(
        '--resume_weights_only',
        action='store_true',
        help='specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only'
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--train', action='store_true')
    group.add_argument('--validate', action='store_true')
    group.add_argument('--test', action='store_true')
    group.add_argument('--predict', action='store_true')
    # group.add_argument('--export', action='store_true') # TODO: a separate export action

    parser.add_argument('--exp_dir', default='./exp')
    parser.add_argument('--runs_dir', default='./runs')
    parser.add_argument('--verbose', action='store_true', help='if true, set logging level to DEBUG')
    return parser


def main():
    """Parse CLI arguments, build the data module and system, and run one Lightning stage.

    Arguments the parser does not recognise (``extras``) are forwarded to
    ``load_config`` as ``cli_args``. Output directories are derived from the
    config when not set explicitly, and then the selected stage (fit+test,
    validate, test or predict) is run on a ``Trainer``.

    Raises:
        SystemExit: via ``parser.error`` for invalid argument combinations:
            ``--resume_weights_only`` without ``--resume``, an empty ``--gpu``
            list, or more than one GPU on Windows.
    """
    parser = _build_parser()
    args, extras = parser.parse_known_args()

    # Validate argument combinations before touching the environment or doing
    # the heavyweight imports, so mistakes fail fast with a usage message.
    if args.resume_weights_only and not args.resume:
        parser.error('--resume_weights_only requires --resume')
    gpu_ids = [gpu.strip() for gpu in args.gpu.split(',') if gpu.strip()]
    n_gpus = len(gpu_ids)
    if n_gpus == 0:
        parser.error('--gpu must list at least one GPU index')
    if sys.platform == 'win32' and n_gpus != 1:
        # multi-GPU is not supported on Windows (only the 'dp' strategy is used there)
        parser.error('multi-GPU runs are not supported on Windows; pass a single --gpu')

    # set CUDA_VISIBLE_DEVICES, then import pytorch-lightning, so the device
    # mask is in place before the framework is loaded
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpu_ids)

    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import CodeSnapshotCallback, ConfigSnapshotCallback, CustomProgressBar
    from utils.misc import load_config

    # parse YAML config to OmegaConf; unrecognised CLI args are passed through as cli_args
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)

    # Each output location may be pinned in the config; otherwise it is derived
    # from the experiment name and a timestamped trial name.
    config.trial_name = config.get('trial_name') or (config.tag + datetime.now().strftime('@%Y%m%d-%H%M%S'))
    config.exp_dir = config.get('exp_dir') or os.path.join(args.exp_dir, config.name)
    trial_dir = os.path.join(config.exp_dir, config.trial_name)
    config.save_dir = config.get('save_dir') or os.path.join(trial_dir, 'save')
    config.ckpt_dir = config.get('ckpt_dir') or os.path.join(trial_dir, 'ckpt')
    config.code_dir = config.get('code_dir') or os.path.join(trial_dir, 'code')
    config.config_dir = config.get('config_dir') or os.path.join(trial_dir, 'config')

    logger = logging.getLogger('pytorch_lightning')
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if 'seed' not in config:
        # no configured seed: derive one in [0, 1000) from the current time
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    dm = datasets.make(config.dataset.name, config.dataset)
    # With --resume_weights_only the checkpoint path goes to systems.make here;
    # in --train mode fit() below is then called without ckpt_path.
    system = systems.make(
        config.system.name,
        config,
        load_from_checkpoint=args.resume if args.resume_weights_only else None,
    )

    # Checkpointing, snapshots, progress bar and loggers are only used for training.
    callbacks = []
    loggers = []
    if args.train:
        callbacks += [
            ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint),
            LearningRateMonitor(logging_interval='step'),
            CodeSnapshotCallback(config.code_dir, use_version=False),
            ConfigSnapshotCallback(config, config.config_dir, use_version=False),
            CustomProgressBar(refresh_rate=1),
        ]
        loggers += [
            TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name),
            CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs'),
        ]

    # Windows: 'dp' (single GPU enforced above); elsewhere: DDP.
    strategy = 'dp' if sys.platform == 'win32' else 'ddp_find_unused_parameters_false'

    trainer = Trainer(
        devices=n_gpus,
        accelerator='gpu',
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
        **config.trainer
    )

    if args.train:
        if args.resume and not args.resume_weights_only:
            # resume weights and training state from the checkpoint
            # FIXME: different behavior in pytorch-lightning>1.9 ?
            trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
        else:
            trainer.fit(system, datamodule=dm)
        trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)
if __name__ == "__main__":
    # Run the launcher only when this file is executed as a script, not on import.
    main()