"""Command-line entry point for training an ODE-Net, ResNet, or CNF model.

Run ``python main.py --help`` for the available options.
"""
import argparse

import train_ode
import train_resnet
import train_cnf


def main(args):
    """Dispatch to the training routine selected by ``args.model``.

    Args:
        args: An ``argparse.Namespace`` as produced by ``build_parser()``.
            It must provide ``model``. Depending on the model it must also
            provide ``lr``, ``n_epoch``, ``batch_size``, ``tol``, ``viz``
            and ``sample_dataset``.

    Raises:
        ValueError: If ``args.model`` is not ``'odenet'``, ``'resnet'`` or
            ``'cnf'``. The CLI already restricts this via ``choices=``, but
            ``main`` may also be called directly with a hand-built namespace.
    """
    if args.model == 'odenet':
        train_ode.train_and_evaluate(args.lr, args.n_epoch, args.batch_size, args.tol)
    elif args.model == 'resnet':
        train_resnet.train_and_evaluate(args.lr, args.n_epoch, args.batch_size)
    elif args.model == 'cnf':
        # NOTE(review): the CNF run ignores --lr/--n_epoch/--batch_size and
        # passes these hard-coded positional values instead. Their meaning
        # is defined by train_cnf.train's signature -- confirm there before
        # wiring them to CLI options.
        train_cnf.train(0.001, 1000, 512, 2, 32, 64, 0., 10., args.viz, args.sample_dataset)
    else:
        raise ValueError(
            f"Unknown model {args.model!r}; expected 'odenet', 'resnet' or 'cnf'"
        )


def build_parser():
    """Return the ``ArgumentParser`` describing this script's CLI options."""
    parser = argparse.ArgumentParser(description='main.py')
    parser.add_argument("--model", type=str, choices=['odenet', 'resnet', 'cnf'], default="odenet", help="Type of model")
    parser.add_argument("--tol", type=float, default=1e-1, help="Error tolerance for ODE solver. This only works with odenet")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--n_epoch", type=int, default=10, help="Total number of epoch")
    parser.add_argument("--batch_size", type=int, default=32, help="Number of images in batch")
    # --sample_dataset and --viz are only consumed by the 'cnf' branch of main().
    parser.add_argument("--sample_dataset", type=str, choices=['circles', 'moons'], default="circles", help="Sample dataset")
    parser.add_argument("--viz", action='store_true')
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())