File size: 1,344 Bytes
5775f48
 
 
0db4000
5775f48
 
 
 
 
0db4000
5775f48
0db4000
 
5775f48
 
 
 
0db4000
 
5775f48
 
 
 
 
2e90bae
0db4000
 
5775f48
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
import train_ode
import train_resnet
import train_cnf


def main(args):
    """Run the training routine selected by ``args.model``.

    ``args`` is the parsed command-line namespace. Dispatch is as follows:

    * ``'odenet'`` -- ``train_ode.train_and_evaluate`` is called with
      ``lr``, ``n_epoch``, ``batch_size`` and ``tol``.
    * ``'resnet'`` -- ``train_resnet.train_and_evaluate`` is called with
      ``lr``, ``n_epoch`` and ``batch_size``.
    * ``'cnf'`` -- ``train_cnf.train`` is called with fixed literal
      settings followed by ``viz`` and ``sample_dataset``.

    Any other model name is ignored and nothing is run. Whatever the
    training functions return is discarded; this function returns ``None``.
    """
    model = args.model

    if model == 'cnf':
        # NOTE(review): the leading eight values are hard-coded here, so
        # --lr, --n_epoch and --batch_size have no effect for this model.
        # Their meaning is defined by train_cnf.train -- confirm there.
        train_cnf.train(
            0.001, 1000, 512, 2, 32, 64, 0., 10.,
            args.viz, args.sample_dataset,
        )
        return

    if model == 'resnet':
        train_resnet.train_and_evaluate(
            args.lr, args.n_epoch, args.batch_size,
        )
        return

    if model == 'odenet':
        # Only this branch forwards the --tol option.
        train_ode.train_and_evaluate(
            args.lr, args.n_epoch, args.batch_size, args.tol,
        )


if __name__ == '__main__':
    # Script entry point: declare the command-line options, parse the
    # arguments and hand the resulting namespace to main().
    parser = argparse.ArgumentParser(description='main.py')

    # (flag, add_argument keyword options) pairs. They are registered in
    # this exact order so the generated --help listing keeps its layout.
    option_table = [
        ("--model", dict(type=str, choices=['odenet', 'resnet', 'cnf'],
                         default="odenet", help="Type of model")),
        ("--tol", dict(type=float, default=1e-1,
                       help="Error tolerance for ODE solver. This only works with odenet")),
        ("--lr", dict(type=float, default=1e-4, help="Learning rate")),
        ("--n_epoch", dict(type=int, default=10, help="Total number of epoch")),
        ("--batch_size", dict(type=int, default=32, help="Number of images in batch")),
        ("--sample_dataset", dict(type=str, choices=['circles', 'moons'],
                                  default="circles", help="Sample dataset")),
        # Boolean switch: False unless --viz is given on the command line.
        ("--viz", dict(action='store_true')),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)