|
import argparse |
|
import os |
|
import random |
|
import torch |
|
import pandas as pd |
|
import numpy as np |
|
import time |
|
import torch.optim as optim |
|
|
|
from matplotlib import cm |
|
import matplotlib.pyplot as plt |
|
import json |
|
from model import GFusion |
|
import torch.nn.functional as F |
|
from torch_geometric.data import Data |
|
from torch_geometric.loader import DataLoader |
|
|
|
from torch_geometric.utils import add_self_loops |
|
from torch.nn.functional import softmax |
|
from torch_geometric.nn import knn_graph |
|
import copy |
|
|
|
torch.autograd.set_detect_anomaly(True) |
|
from sklearn.metrics import explained_variance_score,mean_squared_error,mean_absolute_error,r2_score,precision_score,recall_score,f1_score,roc_auc_score,roc_curve, auc |
|
from sklearn.feature_selection import r_regression |
|
import pickle |
|
from utils.utils import triplets,unique,pos2key |
|
from torch.utils.tensorboard import SummaryWriter |
|
from datetime import datetime |
|
import dataset |
|
|
|
def count_parameters(model): |
|
return sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
|
blue = lambda x: '\033[94m' + x + '\033[0m' |
|
red = lambda x: '\033[31m' + x + '\033[0m' |
|
green = lambda x: '\033[32m' + x + '\033[0m' |
|
yellow = lambda x: '\033[33m' + x + '\033[0m' |
|
greenline = lambda x: '\033[42m' + x + '\033[0m' |
|
yellowline = lambda x: '\033[43m' + x + '\033[0m' |
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--log', type=str, default="True") |
|
parser.add_argument('--loadmodel', type=str, default="False") |
|
parser.add_argument('--split_dataset', type=str, default="False") |
|
parser.add_argument('--model', type=str, default="GFusion") |
|
|
|
|
|
parser.add_argument('--edge_rep', type=str, default="True") |
|
parser.add_argument('--single_high', type=str, default="False") |
|
parser.add_argument('--fidelity_train', type=str, default="True") |
|
parser.add_argument('--fidelity_low_weight', type=float, default=-1.0) |
|
parser.add_argument('--share', type=str, default="101") |
|
|
|
parser.add_argument('--dataset', type=str, default='flu') |
|
parser.add_argument('--manualSeed', type=str, default="False") |
|
parser.add_argument('--man_seed', type=int, default=12345) |
|
parser.add_argument('--test_per_round', type=int, default=10) |
|
parser.add_argument('--patience', type=int, default=30) |
|
parser.add_argument('--nepoch', type=int, default=201) |
|
parser.add_argument('--lr', type=float, default=1e-3) |
|
parser.add_argument('--activation', type=str, default='relu') |
|
parser.add_argument('--batchSize', type=int, default=512) |
|
|
|
parser.add_argument('--num_neighbors', type=int, default=3) |
|
parser.add_argument('--regression_loss', type=str, default='l2') |
|
|
|
parser.add_argument('--h_ch', type=int, default=16) |
|
parser.add_argument('--localdepth', type=int, default=1) |
|
parser.add_argument('--num_interactions', type=int, default=1) |
|
parser.add_argument('--finaldepth', type=int, default=3) |
|
|
|
args = parser.parse_args() |
|
args.log=True if args.log=="True" else False |
|
args.loadmodel=True if args.loadmodel=="True" else False |
|
args.split_dataset=True if args.split_dataset=="True" else False |
|
args.edge_rep=True if args.edge_rep=="True" else False |
|
args.single_high=True if args.single_high=="True" else False |
|
args.fidelity_train=True if args.fidelity_train=="True" and args.single_high is False and args.fidelity_low_weight==-1.0 else False |
|
args.manualSeed=True if args.manualSeed=="True" else False |
|
args.save_dir=os.path.join('./save/',args.dataset) |
|
return args |
|
|
|
def main(args,train_Loader,val_Loader,test_Loader): |
|
if flag: |
|
return |
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
measure_Pearsonr=r_regression |
|
criterion_l1 = torch.nn.L1Loss() |
|
criterion_l2 = torch.nn.MSELoss() |
|
criterion=criterion_l1 if args.regression_loss=='l1' else criterion_l2 |
|
if args.model in ['GFusion']: |
|
def myL1(pred,true,weight=None,reduction='mean'): |
|
loss=(abs(pred-true)) |
|
num=len(pred) |
|
if weight is not None: |
|
loss=[weight[i]*loss[i] for i in range(num)] |
|
loss=sum(loss) |
|
if reduction=='mean': |
|
loss=loss/num |
|
return loss |
|
def myL2(pred,true,weight=None,reduction='mean'): |
|
loss=((pred-true)**2) |
|
num=len(pred) |
|
if weight is not None: |
|
loss=[weight[i]*loss[i] for i in range(num)] |
|
loss=sum(loss) |
|
if reduction=='mean': |
|
loss=loss/num |
|
return loss |
|
criterion=myL1 if args.regression_loss=='l1' else myL2 |
|
num_of_fidelities=len(train_graphs[0]) |
|
|
|
def reweight_fidelity(): |
|
if args.single_high: |
|
weighted_fidelity_weight[0]=1 |
|
weighted_fidelity_weight[1]=0 |
|
elif args.fidelity_low_weight!=-1.0: |
|
weighted_fidelity_weight[0]=1 |
|
weighted_fidelity_weight[1]=args.fidelity_low_weight |
|
else: |
|
exped_f=[torch.exp(fidelity_weight[i]) for i in range(num_of_fidelities)] |
|
fsum=sum(exped_f) |
|
for i in range(num_of_fidelities): |
|
weighted_fidelity_weight[i]=exped_f[i]/fsum |
|
fidelity_weight,weighted_fidelity_weight=[],[] |
|
if args.dataset in ['south',"north","flu"]: |
|
for i in range(num_of_fidelities): |
|
fidelity_weight+=[torch.tensor(1.0/num_of_fidelities,dtype=torch.float32).requires_grad_()] |
|
weighted_fidelity_weight+=[0] |
|
elif args.dataset in ["syn"]: |
|
fidelity_weight=[torch.tensor(1,dtype=torch.float32).requires_grad_(),torch.tensor(0.0,dtype=torch.float32).requires_grad_()] |
|
for i in range(num_of_fidelities): |
|
|
|
weighted_fidelity_weight+=[0] |
|
reweight_fidelity() |
|
if args.dataset in ['south',"north"]: |
|
x_in=30 |
|
elif args.dataset in ['flu']: |
|
x_in=0 |
|
elif args.dataset=='syn': |
|
x_in=1 |
|
else: |
|
raise Exception('Dataset not recognized.') |
|
if args.model=="GFusion": |
|
GFusion_model=GFusion(h_channel=args.h_ch,input_featuresize=x_in,\ |
|
localdepth=args.localdepth,num_interactions=args.num_interactions,finaldepth=args.finaldepth,share=args.share) |
|
GFusion_model.to(device) |
|
optimizer = torch.optim.Adam( list(GFusion_model.parameters()), lr=args.lr) |
|
if args.fidelity_train: |
|
optimizer2 = torch.optim.Adam(fidelity_weight, lr=optimizer.param_groups[0]['lr']*10) |
|
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer2, factor=0.1, patience=args.patience, min_lr=1e-8) |
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=args.patience, min_lr=1e-8) |
|
|
|
def train(GFusion_model): |
|
epochloss=0 |
|
y_hat, y_true,y_hat_logit = [], [], [] |
|
optimizer.zero_grad() |
|
if args.fidelity_train: optimizer2.zero_grad() |
|
if args.model=="GFusion": |
|
GFusion_model.train() |
|
for i,data in enumerate(train_Loader): |
|
if num_of_fidelities==2: |
|
x1, pos1,edge_index1, batch1,target_index1,target1,is_source1 = data[0].x, data[0].pos,data[0].edge_index, data[0].batch,data[0].target_index,data[0].target,data[0].is_source |
|
x2, pos2,edge_index2, batch2,target_index2,target2,is_source2 = data[1].x, data[1].pos,data[1].edge_index, data[1].batch,data[1].target_index,data[1].target,data[1].is_source |
|
if args.dataset=='syn': |
|
x1[:,1]=x1[:,1]+x1[:,2] |
|
x1=x1[:,[0,1,3,4]] |
|
x2[:,1]=x2[:,1]+x2[:,2] |
|
x2=x2[:,[0,1,3,4]] |
|
x1,pos1,target1,x2,pos2,target2=x1.to(torch.float32),pos1.to(torch.float32),target1.to(torch.float32),x2.to(torch.float32),pos2.to(torch.float32),target2.to(torch.float32) |
|
x2[x2[:,0]>6666,0]=6666 |
|
|
|
datasource=data[0].datasource |
|
Y = target1 |
|
assert(torch.equal(target1,target2)) |
|
Y[Y>6666]=6666 |
|
x1, pos1,edge_index1, batch1, target_index1,is_source1 = x1.to(device),pos1.to(device), edge_index1.to(device), batch1.to(device),target_index1.to(device),is_source1.to(device) |
|
x2, pos2,edge_index2, batch2, target_index2,is_source2 = x2.to(device),pos2.to(device),edge_index2.to(device), batch2.to(device),target_index2.to(device),is_source2.to(device) |
|
""" |
|
triplets are not the same for graphs when training |
|
""" |
|
num_nodes1=x1.shape[0] |
|
num_nodes2=x2.shape[0] |
|
edge_index_2rd_1, _, _, edx_2nd_1 = triplets(edge_index1, num_nodes1) |
|
edge_index_2rd_2, _, _, edx_2nd_2 = triplets(edge_index2, num_nodes2) |
|
|
|
pm25_1,pm25_2=GFusion_model([pos1,pos2],[edge_index1,edge_index2],[edge_index_2rd_1,edge_index_2rd_2],\ |
|
[edx_2nd_1,edx_2nd_2],[batch1,batch2],[x1,x2],[is_source1,is_source2],args.edge_rep) |
|
pm25_1,pm25_2=pm25_1[target_index1],pm25_2[target_index2] |
|
|
|
if args.dataset=='syn': |
|
pred=((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu()) |
|
else: |
|
pred=F.relu((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu()) |
|
|
|
loss_weight= [weighted_fidelity_weight[i] for i in datasource] |
|
loss1 = criterion(pred.reshape(-1, 1), Y.reshape(-1, 1),loss_weight) |
|
""" |
|
record predictions |
|
""" |
|
y_hat += list(pred.detach().numpy().reshape(-1)) |
|
y_true += list(Y.detach().numpy().reshape(-1)) |
|
loss=loss1 |
|
loss.backward() |
|
epochloss+=loss |
|
optimizer.step() |
|
optimizer.zero_grad() |
|
if args.fidelity_train: |
|
optimizer2.step() |
|
optimizer2.zero_grad() |
|
reweight_fidelity() |
|
return epochloss.item()/len(train_Loader),y_hat, y_true |
|
|
|
def test(loader,GFusion_model,fidelity_weight): |
|
if not args.single_high: |
|
weighted_fidelity_weight=[i.detach() for i in fidelity_weight] |
|
exped_f=[torch.exp(fidelity_weight[i]) for i in range(num_of_fidelities)] |
|
fsum=sum(exped_f) |
|
for i in range(num_of_fidelities): |
|
weighted_fidelity_weight[i]=exped_f[i]/fsum |
|
else: |
|
weighted_fidelity_weight=[1,0] |
|
y_hat, y_true,y_hat_logit = [], [], [] |
|
loss_total, pred_num = 0, 0 |
|
GFusion_model.eval() |
|
for i,data in enumerate(loader): |
|
if num_of_fidelities==2: |
|
x1, pos1,edge_index1, batch1,target_index1,target1,is_source1 = data[0].x, data[0].pos,data[0].edge_index, data[0].batch,data[0].target_index,data[0].target,data[0].is_source |
|
x2, pos2,edge_index2, batch2,target_index2,target2,is_source2 = data[1].x, data[1].pos,data[1].edge_index, data[1].batch,data[1].target_index,data[1].target,data[1].is_source |
|
if args.dataset=='syn': |
|
x1[:,1]=x1[:,1]+x1[:,2] |
|
x1=x1[:,[0,1,3,4]] |
|
x2[:,1]=x2[:,1]+x2[:,2] |
|
x2=x2[:,[0,1,3,4]] |
|
x1,pos1,target1,x2,pos2,target2=x1.to(torch.float32),pos1.to(torch.float32),target1.to(torch.float32),x2.to(torch.float32),pos2.to(torch.float32),target2.to(torch.float32) |
|
x2[x2[:,0]>6666,0]=6666 |
|
|
|
datasource=data[0].datasource |
|
Y = target1 |
|
assert(torch.equal(target1,target2)) |
|
Y[Y>6666]=6666 |
|
x1, pos1,edge_index1, batch1, target_index1,is_source1 = x1.to(device),pos1.to(device), edge_index1.to(device), batch1.to(device),target_index1.to(device),is_source1.to(device) |
|
x2, pos2,edge_index2, batch2, target_index2,is_source2 = x2.to(device),pos2.to(device),edge_index2.to(device), batch2.to(device),target_index2.to(device),is_source2.to(device) |
|
|
|
num_nodes1=x1.shape[0] |
|
num_nodes2=x2.shape[0] |
|
edge_index_2rd_1, num_2nd_neighbors_1, edx_1st_1, edx_2nd_1 = triplets(edge_index1, num_nodes1) |
|
edge_index_2rd_2, num_2nd_neighbors_2, edx_1st_2, edx_2nd_2 = triplets(edge_index2, num_nodes2) |
|
pm25_1,pm25_2=GFusion_model([pos1,pos2],[edge_index1,edge_index2],[edge_index_2rd_1,edge_index_2rd_2],\ |
|
[edx_2nd_1,edx_2nd_2],[batch1,batch2],[x1,x2],[is_source1,is_source2],args.edge_rep) |
|
pm25_1,pm25_2=pm25_1[target_index1],pm25_2[target_index2] |
|
with torch.no_grad(): |
|
if args.dataset=='syn': |
|
pred=((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu()) |
|
else: |
|
pred=F.relu((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu()) |
|
assert(all(datasource==0)) |
|
loss1 = criterion(pred.reshape(-1, 1), Y.reshape(-1, 1))*weighted_fidelity_weight[0] |
|
""" |
|
record predictions |
|
""" |
|
y_hat += list(pred.detach().numpy().reshape(-1)) |
|
y_true += list(Y.detach().numpy().reshape(-1)) |
|
pred_num += len(Y.reshape(-1, 1)) |
|
loss=loss1 |
|
loss_total += loss.detach() * len(Y.reshape(-1, 1)) |
|
return loss_total/pred_num, y_hat, y_true |
|
if args.loadmodel: |
|
try: |
|
suffix='Oct31-11:50:30' |
|
GFusion_model.load_state_dict(torch.load(os.path.join("save",args.dataset,'model','best_GFusion_model_'+suffix+'.pth')),strict=True) |
|
best_GFusion_model = copy.deepcopy(GFusion_model) |
|
except OSError: |
|
pass |
|
else: |
|
best_val_trigger = 1e3 |
|
old_lr=1e3 |
|
suffix="{}{}-{}:{}:{}".format(datetime.now().strftime("%h"), |
|
datetime.now().strftime("%d"), |
|
datetime.now().strftime("%H"), |
|
datetime.now().strftime("%M"), |
|
datetime.now().strftime("%S")) |
|
if args.log: |
|
writer = SummaryWriter(os.path.join(tensorboard_dir,suffix)) |
|
for epoch in range(args.nepoch): |
|
if args.model in ['GFusion']: train_loss,y_hat, y_true=train(GFusion_model) |
|
if args.log: |
|
writer.add_scalar('loss/Train', train_loss, epoch) |
|
if args.dataset in ['south',"north",'syn','flu']: |
|
train_mae=mean_absolute_error(y_true, y_hat) |
|
train_rmse = np.sqrt(mean_squared_error(y_true, y_hat)) |
|
if args.log: |
|
writer.add_scalar('mae/Train', train_mae, epoch) |
|
writer.add_scalar('rmse/Train', train_rmse, epoch) |
|
print(( f"epoch[{epoch:d}] train_loss : {train_loss:.3f} train_mae : {train_mae:.3f} train_rmse : {train_rmse:.3f}" )) |
|
if args.model in ['GFusion']: |
|
if args.fidelity_train==True: |
|
print(f"fidelity weight: {fidelity_weight[0]:.3f}, {fidelity_weight[1]:.3f}") |
|
print(f"weighted_fidelity_weight: {weighted_fidelity_weight[0]:.3f}, {weighted_fidelity_weight[1]:.3f}") |
|
if epoch % args.test_per_round == 0: |
|
if args.model in ['GFusion']: |
|
val_loss, yhat_val, ytrue_val = test(val_Loader,GFusion_model,fidelity_weight) |
|
test_loss, yhat_test, ytrue_test = test(test_Loader,GFusion_model,fidelity_weight) |
|
if args.log: |
|
writer.add_scalar('loss/val', val_loss, epoch) |
|
writer.add_scalar('loss/test', test_loss, epoch) |
|
if args.dataset in ['south',"north",'syn','flu']: |
|
val_mae=mean_absolute_error(ytrue_val, yhat_val) |
|
val_rmse = np.sqrt(mean_squared_error(ytrue_val, yhat_val)) |
|
if args.log: |
|
writer.add_scalar('mae/val', val_mae, epoch) |
|
writer.add_scalar('rmse/val', val_rmse, epoch) |
|
print(blue( f"epoch[{epoch:d}] val_mae : {val_mae:.3f} val_rmse : {val_rmse:.3f}" )) |
|
test_mae = mean_absolute_error(ytrue_test, yhat_test) |
|
test_rmse = np.sqrt(mean_squared_error(ytrue_test, yhat_test)) |
|
test_var=explained_variance_score(ytrue_test,yhat_test) |
|
test_coefOfDetermination=r2_score(ytrue_test,yhat_test) |
|
test_Pearsonr=measure_Pearsonr(np.array(yhat_test).reshape(-1, 1),np.array(ytrue_test).reshape(-1))[0] |
|
if args.log: |
|
writer.add_scalar('mae/test', test_mae, epoch) |
|
writer.add_scalar('rmse/test', test_rmse, epoch) |
|
print(blue( f"epoch[{epoch:d}] test_mae: {test_mae:.3f} test_rmse: {test_rmse:.3f} test_Pearsonr: {test_Pearsonr:.3f} test_coefOfDetermination: {test_coefOfDetermination:.3f}" )) |
|
if args.model in ['GFusion']: |
|
if args.fidelity_train==True: |
|
print(f"fidelity weight: {fidelity_weight[0]:.3f}, {fidelity_weight[1]:.3f}") |
|
print(f"weighted_fidelity_weight: {weighted_fidelity_weight[0]:.3f}, {weighted_fidelity_weight[1]:.3f}") |
|
val_trigger=val_mae |
|
if val_trigger < best_val_trigger: |
|
best_val_trigger = val_trigger |
|
best_GFusion_model = copy.deepcopy(GFusion_model) |
|
best_fidelity=copy.deepcopy(fidelity_weight) |
|
best_info=[epoch,val_trigger] |
|
""" |
|
update lr when epoch≥30 |
|
""" |
|
if epoch >= 30: |
|
lr = scheduler.optimizer.param_groups[0]['lr'] |
|
if old_lr!=lr: |
|
print(red('lr'), epoch, (lr), sep=', ') |
|
old_lr=lr |
|
scheduler.step(val_trigger) |
|
if args.fidelity_train: |
|
scheduler2.step(val_trigger) |
|
val_loss, yhat_val, ytrue_val = test(val_Loader,best_GFusion_model,best_fidelity) |
|
test_loss, yhat_test, ytrue_test = test(test_Loader,best_GFusion_model,best_fidelity) |
|
if args.dataset in ['south',"north",'syn','flu']: |
|
val_mae = mean_absolute_error(ytrue_val, yhat_val) |
|
val_rmse=np.sqrt(mean_squared_error(ytrue_val,yhat_val)) |
|
val_var=explained_variance_score(ytrue_val,yhat_val) |
|
print(blue( f"best_val val_mae: {val_mae:.3f} val_rmse: {val_rmse:.3f} val_var: {val_var:.3f}" )) |
|
|
|
test_mae=mean_absolute_error(ytrue_test,yhat_test) |
|
test_rmse=np.sqrt(mean_squared_error(ytrue_test,yhat_test)) |
|
test_var=explained_variance_score(ytrue_test,yhat_test) |
|
test_coefOfDetermination=r2_score(ytrue_test,yhat_test) |
|
test_Pearsonr=measure_Pearsonr(np.array(yhat_test).reshape(-1, 1),np.array(ytrue_test).reshape(-1))[0] |
|
print(blue( f"best_test test_mae: {test_mae:.3f} test_rmse: {test_rmse:.3f} test_var: {test_var:.3f}" )) |
|
if not args.loadmodel: |
|
""" |
|
save training info and best result |
|
""" |
|
result_file=os.path.join(info_dir, suffix) |
|
with open(result_file, 'w') as f: |
|
print(args.num_neighbors,args.nepoch,sep=' ',file=f) |
|
print(f"fidelity weight: {best_fidelity[0]:.3f}, {best_fidelity[1]:.3f}",file=f) |
|
print("Random Seed: ", Seed,file=f) |
|
if args.dataset in ['south',"north",'syn','flu']: |
|
print(f"MAE val : {val_mae:.3f}, Test : {test_mae:.3f}", file=f) |
|
print(f"rmse val : {val_rmse:.3f}, Test : {test_rmse:.3f}", file=f) |
|
print(f"var val : {val_var:.3f}, Test : {test_var:.3f}", file=f) |
|
print(f"test_coefOfDetermination: {test_coefOfDetermination:.3f}, test_Pearsonr : {test_Pearsonr:.3f}", file=f) |
|
print(f"Best info: {best_info}", file=f) |
|
for i in [[a,getattr(args, a)] for a in args.__dict__]: |
|
print(i,sep='\n',file=f) |
|
with open(os.path.join(model_dir,'best_f_weight'+"_"+suffix+".pkl"), 'wb') as handle: |
|
pickle.dump(fidelity_weight, handle) |
|
torch.save(best_GFusion_model.state_dict(), os.path.join(model_dir,'best_GFusion_model'+"_"+suffix+'.pth') ) |
|
print("done") |
|
|
|
if __name__ == '__main__': |
|
args = get_args() |
|
if not os.path.exists(args.save_dir): |
|
os.makedirs(args.save_dir,exist_ok=True) |
|
tensorboard_dir=os.path.join(args.save_dir,'log') |
|
if not os.path.exists(tensorboard_dir): |
|
os.makedirs(tensorboard_dir,exist_ok=True) |
|
model_dir=os.path.join(args.save_dir,'model') |
|
if not os.path.exists(model_dir): |
|
os.makedirs(model_dir,exist_ok=True) |
|
info_dir=os.path.join(args.save_dir,'info') |
|
if not os.path.exists(info_dir): |
|
os.makedirs(info_dir,exist_ok=True) |
|
Seed = args.man_seed if args.manualSeed else random.randint(1, 10000) |
|
print("Random Seed: ", Seed) |
|
random.seed(Seed) |
|
torch.manual_seed(Seed) |
|
np.random.seed(Seed) |
|
flag=0 |
|
if args.dataset in ['south',"north",'syn',"flu"]: |
|
graphs1,graphs2=dataset.load_point(args.dataset,args.num_neighbors,[False,200,500]) |
|
np.random.shuffle(graphs1) |
|
val_test_split = int(np.around( 2 / 10 * len(graphs1) )) |
|
train_val_split = int(len(graphs1)-2*val_test_split) |
|
if args.single_high: |
|
train_graphs = graphs1[:train_val_split] |
|
else: |
|
train_graphs = graphs1[:train_val_split]+graphs2 |
|
val_graphs = graphs1[train_val_split:train_val_split+val_test_split] |
|
test_graphs = graphs1[train_val_split+val_test_split:] |
|
|
|
np.random.shuffle(train_graphs) |
|
train_Loader=DataLoader(train_graphs, batch_size=args.batchSize) |
|
val_Loader=DataLoader(val_graphs, batch_size=args.batchSize) |
|
test_Loader=DataLoader(test_graphs, batch_size=args.batchSize) |
|
print(f"train_pair_num: {len(train_graphs)}, val_pair_num: {len(val_graphs)}, test_pair_num: {len(test_graphs)}") |
|
else: |
|
raise Exception('Dataset not recognized.') |
|
main(args,train_Loader,val_Loader,test_Loader) |
|
|