| import argparse | |
| import random | |
| import time | |
| import torch | |
| import numpy as np | |
| from network import GNet | |
| from trainer import Trainer | |
| from utils.data_loader import FileLoader | |
def get_args(argv=None):
    """Parse the command-line options for a graph-prediction run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (which argparse uses when this is ``None``).
            Existing callers that pass nothing are unaffected.

    Returns:
        An ``argparse.Namespace`` with one attribute per option below.
        Unrecognised arguments are ignored rather than rejected, because
        ``parse_known_args`` is used.
    """
    parser = argparse.ArgumentParser(description='Args for graph predition')
    parser.add_argument('-seed', type=int, default=1, help='seed')
    parser.add_argument('-data', default='DD', help='data folder name')
    parser.add_argument('-fold', type=int, default=1, help='fold (1..10)')
    parser.add_argument('-num_epochs', type=int, default=2, help='epochs')
    parser.add_argument('-batch', type=int, default=8, help='batch size')
    parser.add_argument('-lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('-deg_as_tag', type=int, default=0, help='1 or degree')
    parser.add_argument('-l_num', type=int, default=3, help='layer num')
    parser.add_argument('-h_dim', type=int, default=512, help='hidden dim')
    parser.add_argument('-l_dim', type=int, default=48, help='layer dim')
    parser.add_argument('-drop_n', type=float, default=0.3, help='drop net')
    parser.add_argument('-drop_c', type=float, default=0.2, help='drop output')
    parser.add_argument('-act_n', type=str, default='ELU', help='network act')
    parser.add_argument('-act_c', type=str, default='ELU', help='output act')
    # The default must already be a parsed list. argparse runs a *string*
    # default through ``type`` as a single token, so the previous default
    # '0.9 0.8 0.7' became float('0.9 0.8 0.7'), which raised and made the
    # parser exit whenever -ks was omitted.
    parser.add_argument('-ks', nargs='+', type=float, default=[0.9, 0.8, 0.7])
    parser.add_argument('-acc_file', type=str, default='re', help='acc file')
    # Unknown arguments are collected and discarded, not treated as errors.
    args, _ = parser.parse_known_args(argv)
    return args
def set_random(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with one value.

    The seeders run in a fixed order: ``random``, then NumPy, then PyTorch.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def app_run(args, G_data, fold_idx):
    """Build a fresh model and train it on a single cross-validation fold.

    Args:
        args: Parsed options, forwarded unchanged to ``GNet`` and ``Trainer``.
        G_data: Loaded dataset; switched to ``fold_idx`` before the model is
            built, and its ``feat_dim`` / ``num_class`` size the model.
        fold_idx: Zero-based fold index passed to ``G_data.use_fold_data``.
    """
    G_data.use_fold_data(fold_idx)
    model = GNet(G_data.feat_dim, G_data.num_class, args)
    Trainer(args, model, G_data).train()
def main():
    """Parse options, load the dataset once, then train on one or all folds.

    ``-fold 0`` trains on every fold in turn. ``-fold k`` (1-based, 1..10)
    trains only on fold ``k``. Any other value exits with an error message
    before the dataset is loaded.
    """
    # Both the all-folds loop below and the ``-fold`` help text assume 10 folds.
    num_folds = 10
    args = get_args()
    print(args)
    # Validate before the potentially slow data load. Without this check,
    # e.g. ``-fold -1`` became index -2, which (if use_fold_data indexes a
    # Python sequence -- TODO confirm) silently selects a different fold.
    if not 0 <= args.fold <= num_folds:
        raise SystemExit(
            f'-fold must be 0 (all folds) or in 1..{num_folds}, got {args.fold}')
    set_random(args.seed)
    # perf_counter is monotonic and meant for measuring elapsed time.
    start = time.perf_counter()
    G_data = FileLoader(args).load_data()
    print('load data using ------>', time.perf_counter() - start)
    if args.fold == 0:
        for fold_idx in range(num_folds):
            print('start training ------> fold', fold_idx + 1)
            app_run(args, G_data, fold_idx)
    else:
        print('start training ------> fold', args.fold)
        app_run(args, G_data, args.fold - 1)
# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()