Spaces:
Runtime error
Runtime error
import argparse
import os
import os.path as osp
import sys

# Make the parent of this script's directory (presumably the repository root
# that contains the ``mmpl`` package) importable ahead of installed packages.
# Derived from ``__file__`` instead of ``sys.path[0]``: the latter is '' for
# ``python -c`` / the REPL (which would insert '/..') and is not the script
# directory at all under ``python -P`` / PYTHONSAFEPATH. ``realpath`` mirrors
# the symlink resolution Python applies when it fills ``sys.path[0]``.
sys.path.insert(0, osp.dirname(osp.dirname(osp.realpath(__file__))))

# These imports must stay below the sys.path tweak above.
from mmengine.config import Config, DictAction  # noqa: E402
from mmengine.logging import print_log  # noqa: E402
from mmengine.runner import Runner  # noqa: E402
from mmpl.engine.runner import PLRunner  # noqa: E402
from mmpl.registry import RUNNERS  # noqa: E402
from mmpl.utils import register_all_modules  # noqa: E402

# Presumably registers mmpl's modules into its registries (e.g. RUNNERS, which
# main() builds from) — must run before any config-driven construction.
register_all_modules()
def parse_args(argv=None):
    """Parse the command-line options for running a pl model.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which preserves the
            original behaviour. Passing a list lets tests or other code drive
            the parser.

    Returns:
        argparse.Namespace with attributes ``config``, ``status``,
        ``ckpt_path`` and ``work_dir``.

    Raises:
        SystemExit: raised by argparse on invalid options, e.g. a
            ``--status`` value outside the allowed choices.
    """
    parser = argparse.ArgumentParser(
        description='Run a pl model (fit, test, predict or validate)')
    parser.add_argument(
        '--config',
        default='configs/rsprompter/rsprompter_anchor_whu_config.py',
        help='train config file path')
    parser.add_argument(
        '--status',
        default='predict',
        choices=['fit', 'test', 'predict', 'validate'],
        # The help text previously listed only "fit or test".
        help='runner stage to execute: fit, test, predict or validate')
    parser.add_argument(
        '--ckpt-path',
        default='pretrain/whu/last.ckpt',
        help='checkpoint path')
    parser.add_argument(
        '--work-dir',
        default=None,
        help='the dir to save logs and models')
    return parser.parse_args(argv)
def main():
    """Build a runner from the CLI-selected config and execute one stage.

    Steps:
      1. Load the config file given by ``--config`` via ``Config.fromfile``.
      2. Resolve ``trainer_cfg['default_root_dir']``:
         - ``--work-dir`` wins when supplied.
         - Otherwise an existing value in the config is kept.
         - If neither is set, it falls back to ``./work_dirs/<config stem>``.
      3. Force ``trainer_cfg['logger']`` to ``None``.
      4. Build the runner with ``PLRunner.from_cfg`` unless the config
         declares ``runner_type``, in which case ``RUNNERS.build`` is used.
      5. Call ``runner.run`` with the ``--status`` stage and ``--ckpt-path``.
    """
    cli = parse_args()
    cfg = Config.fromfile(cli.config)

    if cli.work_dir is None:
        # No CLI override: only fill the root dir when the config left it unset.
        if cfg.trainer_cfg.get('default_root_dir', None) is None:
            config_stem = osp.splitext(osp.basename(cli.config))[0]
            cfg.trainer_cfg['default_root_dir'] = osp.join('./work_dirs', config_stem)
    else:
        cfg.trainer_cfg['default_root_dir'] = cli.work_dir

    # Any logger configured in trainer_cfg is discarded for this run.
    cfg.trainer_cfg['logger'] = None

    runner = RUNNERS.build(cfg) if 'runner_type' in cfg else PLRunner.from_cfg(cfg)
    runner.run(cli.status, ckpt_path=cli.ckpt_path)
if __name__ == '__main__':
    # Only start a run when this file is executed as a script, not on import.
    main()