NCERL-Diverse-PCG / train.py
baiyanlali-zhao's picture
添加注释
3582c8a
raw
history blame
1.93 kB
import argparse
from src.drl.train_async import *
from src.gan.adversarial_train import *
from src.drl.train_sinproc import set_SAC_parser, train_SAC
from src.drl.egsac.train_egsac import set_EGSAC_parser, train_EGSAC
from src.drl.sunrise.train_sunrise import train_SUNRISE, set_SUNRISE_args
from src.drl.dvd import set_DvDSAC_parser, train_DvDSAC
if __name__ == '__main__':
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
parser_gan = subparsers.add_parser('gan', help='Train GAN')
set_GAN_parser(parser_gan)
parser_gan.set_defaults(entry=train_GAN)
parser_sac = subparsers.add_parser('sac', help='Train SAC')
set_SAC_parser(parser_sac)
parser_sac.set_defaults(entry=train_SAC)
parser_asyncsac = subparsers.add_parser('asyncsac', help='Train Asynchronous SAC')
set_AsyncSAC_parser(parser_asyncsac)
parser_asyncsac.set_defaults(entry=train_AsyncSAC)
parser_egsac = subparsers.add_parser('egsac', help='Train Episodic Generative SAC')
set_EGSAC_parser(parser_egsac)
parser_egsac.set_defaults(entry=train_EGSAC)
parser_ncesac = subparsers.add_parser('ncesac', help='Train Negatively Correlated Ensemble SAC')
set_NCESAC_parser(parser_ncesac)
parser_ncesac.set_defaults(entry=train_NCESAC)
parser_pmoesac = subparsers.add_parser('pmoe', help='Train PMOE')
set_PMOESAC_parser(parser_pmoesac)
parser_pmoesac.set_defaults(entry=train_PMOESAC)
parser_sunrise = subparsers.add_parser('sunrise', help='Train SUNRISE')
set_SUNRISE_args(parser_sunrise)
parser_sunrise.set_defaults(entry=train_SUNRISE)
parser_dvd = subparsers.add_parser('dvd', help='Train DvD')
set_DvDSAC_parser(parser_dvd)
parser_dvd.set_defaults(entry=train_DvDSAC)
args = parser.parse_args()
entry = args.entry
# entry是每一个模型的训练入口,具体函数在各个subparser内定义
entry(args)