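# Training and evaluation script for GFusion: a multi-fidelity graph model that
# fuses predictions from paired high- and low-fidelity graphs using learnable,
# softmax-normalized fidelity weights.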
import argparse
import os
import random
import torch
import pandas as pd
import numpy as np
import time
import torch.optim as optim
from matplotlib import cm
import matplotlib.pyplot as plt
import json
from model import GFusion
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import add_self_loops
from torch.nn.functional import softmax
from torch_geometric.nn import knn_graph
import copy
torch.autograd.set_detect_anomaly(True)  # helps localize NaN/inf gradients; noticeably slows training
from sklearn.metrics import (explained_variance_score, mean_squared_error,
                             mean_absolute_error, r2_score, precision_score,
                             recall_score, f1_score, roc_auc_score, roc_curve, auc)
from sklearn.feature_selection import r_regression
import pickle
from utils.utils import triplets,unique,pos2key
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import dataset
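# Number of trainable parameters in a model.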
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
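# ANSI escape helpers for colored console output.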
blue = lambda x: '\033[94m' + x + '\033[0m'
red = lambda x: '\033[31m' + x + '\033[0m'
green = lambda x: '\033[32m' + x + '\033[0m'
yellow = lambda x: '\033[33m' + x + '\033[0m'
greenline = lambda x: '\033[42m' + x + '\033[0m'
yellowline = lambda x: '\033[43m' + x + '\033[0m'
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--log', type=str, default="True")
parser.add_argument('--loadmodel', type=str, default="False")
parser.add_argument('--split_dataset', type=str, default="False")
parser.add_argument('--model', type=str, default="GFusion")
# ablation
parser.add_argument('--edge_rep', type=str, default="True")
parser.add_argument('--single_high', type=str, default="False")
parser.add_argument('--fidelity_train', type=str, default="True")
parser.add_argument('--fidelity_low_weight', type=float, default=-1.0)
parser.add_argument('--share', type=str, default="101")
parser.add_argument('--dataset', type=str, default='flu')
parser.add_argument('--manualSeed', type=str, default="False")
parser.add_argument('--man_seed', type=int, default=12345)
parser.add_argument('--test_per_round', type=int, default=10)
parser.add_argument('--patience', type=int, default=30) #scheduler
parser.add_argument('--nepoch', type=int, default=201)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--activation', type=str, default='relu')#'lrelu'
parser.add_argument('--batchSize', type=int, default=512)
parser.add_argument('--num_neighbors', type=int, default=3)
parser.add_argument('--regression_loss', type=str, default='l2')
parser.add_argument('--h_ch', type=int, default=16)
parser.add_argument('--localdepth', type=int, default=1) # mlp(distance) mlp(theta) >=1
parser.add_argument('--num_interactions', type=int, default=1) #>=1
parser.add_argument('--finaldepth', type=int, default=3) # mlp(concat node_attr and geo_encoding)
args = parser.parse_args()
    # argparse passes booleans through as strings; normalize them here.
    args.log = args.log == "True"
    args.loadmodel = args.loadmodel == "True"
    args.split_dataset = args.split_dataset == "True"
    args.edge_rep = args.edge_rep == "True"
    args.single_high = args.single_high == "True"
    # Learning the fidelity weights only makes sense when neither the
    # single-high nor the fixed-low-weight ablation is active.
    args.fidelity_train = (args.fidelity_train == "True"
                           and not args.single_high
                           and args.fidelity_low_weight == -1.0)
    args.manualSeed = args.manualSeed == "True"
args.save_dir=os.path.join('./save/',args.dataset)
return args
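# main() relies on module-level globals set in the __main__ block below:
# train_graphs, flag, tensorboard_dir, model_dir, info_dir, and Seed.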
def main(args, train_Loader, val_Loader, test_Loader):
    if flag:  # module-level kill switch set in __main__ (0 by default, so training runs)
        return
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
measure_Pearsonr=r_regression
criterion_l1 = torch.nn.L1Loss() #reduction='sum'
criterion_l2 = torch.nn.MSELoss()
criterion=criterion_l1 if args.regression_loss=='l1' else criterion_l2
if args.model in ['GFusion']:
        # Regression losses with optional per-sample weights. The weights may
        # carry gradients, which is how the learnable fidelity weights receive
        # updates through the training loss. Without weights the loss reduces
        # to the plain (summed or mean) L1/L2 loss.
        def myL1(pred, true, weight=None, reduction='mean'):
            loss = (pred - true).abs()
            num = len(pred)
            if weight is not None:
                loss = sum(weight[i] * loss[i] for i in range(num))
            else:
                loss = loss.sum()
            if reduction == 'mean':
                loss = loss / num
            return loss

        def myL2(pred, true, weight=None, reduction='mean'):
            loss = (pred - true) ** 2
            num = len(pred)
            if weight is not None:
                loss = sum(weight[i] * loss[i] for i in range(num))
            else:
                loss = loss.sum()
            if reduction == 'mean':
                loss = loss / num
            return loss
criterion=myL1 if args.regression_loss=='l1' else myL2
num_of_fidelities=len(train_graphs[0])
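    # Each training sample is a tuple of graphs, one per fidelity level
    # (two in all current datasets).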
    def reweight_fidelity():
        """Refresh weighted_fidelity_weight in place: softmax over the learnable
        fidelity_weight scalars, or pinned values for the ablations."""
        if args.single_high:
            weighted_fidelity_weight[0] = 1
            weighted_fidelity_weight[1] = 0
        elif args.fidelity_low_weight != -1.0:
            weighted_fidelity_weight[0] = 1
            weighted_fidelity_weight[1] = args.fidelity_low_weight
        else:
            exped_f = [torch.exp(fidelity_weight[i]) for i in range(num_of_fidelities)]
            fsum = sum(exped_f)
            for i in range(num_of_fidelities):
                weighted_fidelity_weight[i] = exped_f[i] / fsum
    # One learnable scalar per fidelity level; the normalized versions live in
    # weighted_fidelity_weight and are refreshed by reweight_fidelity().
    fidelity_weight, weighted_fidelity_weight = [], []
    if args.dataset in ['south', "north", "flu"]:
        for i in range(num_of_fidelities):
            fidelity_weight += [torch.tensor(1.0 / num_of_fidelities, dtype=torch.float32).requires_grad_()]
            weighted_fidelity_weight += [0]
    elif args.dataset in ["syn"]:
        fidelity_weight = [torch.tensor(1.0, dtype=torch.float32).requires_grad_(),
                           torch.tensor(0.0, dtype=torch.float32).requires_grad_()]
        weighted_fidelity_weight = [0] * num_of_fidelities
reweight_fidelity()
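    # Node-feature width per dataset ('flu' graphs carry positions only).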
if args.dataset in ['south',"north"]:
x_in=30
elif args.dataset in ['flu']:
x_in=0
elif args.dataset=='syn':
x_in=1
else:
raise Exception('Dataset not recognized.')
if args.model=="GFusion":
        GFusion_model = GFusion(h_channel=args.h_ch, input_featuresize=x_in,
                                localdepth=args.localdepth, num_interactions=args.num_interactions,
                                finaldepth=args.finaldepth, share=args.share)
GFusion_model.to(device)
optimizer = torch.optim.Adam( list(GFusion_model.parameters()), lr=args.lr)
    if args.fidelity_train:
        # The fidelity weights get their own optimizer, with a 10x higher lr.
        optimizer2 = torch.optim.Adam(fidelity_weight, lr=optimizer.param_groups[0]['lr'] * 10)
        scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer2, factor=0.1, patience=args.patience, min_lr=1e-8)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=args.patience, min_lr=1e-8)
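    # One pass over train_Loader: fuse the two fidelity predictions, weight each
    # sample's loss by the (learnable) fidelity weight of its data source, and
    # step both the model and the fidelity-weight optimizers.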
def train(GFusion_model):
        epochloss = 0
        y_hat, y_true = [], []
optimizer.zero_grad()
if args.fidelity_train: optimizer2.zero_grad()
if args.model=="GFusion":
GFusion_model.train()
for i,data in enumerate(train_Loader):
if num_of_fidelities==2:
x1, pos1,edge_index1, batch1,target_index1,target1,is_source1 = data[0].x, data[0].pos,data[0].edge_index, data[0].batch,data[0].target_index,data[0].target,data[0].is_source
x2, pos2,edge_index2, batch2,target_index2,target2,is_source2 = data[1].x, data[1].pos,data[1].edge_index, data[1].batch,data[1].target_index,data[1].target,data[1].is_source
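                # For the synthetic dataset, columns 1 and 2 are summed into a
                # single feature and the redundant column is dropped (an
                # assumption about the synthetic feature layout; see dataset.load_point).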
if args.dataset=='syn':
x1[:,1]=x1[:,1]+x1[:,2]
x1=x1[:,[0,1,3,4]]
x2[:,1]=x2[:,1]+x2[:,2]
x2=x2[:,[0,1,3,4]]
x1,pos1,target1,x2,pos2,target2=x1.to(torch.float32),pos1.to(torch.float32),target1.to(torch.float32),x2.to(torch.float32),pos2.to(torch.float32),target2.to(torch.float32)
                x2[x2[:, 0] > 6666, 0] = 6666  # clip extreme low-fidelity values at the dataset cap
datasource=data[0].datasource
                Y = target1  # both fidelities carry the same targets
                assert torch.equal(target1, target2)
                Y[Y > 6666] = 6666  # clip targets at the same cap
x1, pos1,edge_index1, batch1, target_index1,is_source1 = x1.to(device),pos1.to(device), edge_index1.to(device), batch1.to(device),target_index1.to(device),is_source1.to(device)
x2, pos2,edge_index2, batch2, target_index2,is_source2 = x2.to(device),pos2.to(device),edge_index2.to(device), batch2.to(device),target_index2.to(device),is_source2.to(device)
"""
triplets are not the same for graphs when training
"""
num_nodes1=x1.shape[0]
num_nodes2=x2.shape[0]
edge_index_2rd_1, _, _, edx_2nd_1 = triplets(edge_index1, num_nodes1)
edge_index_2rd_2, _, _, edx_2nd_2 = triplets(edge_index2, num_nodes2)
pm25_1,pm25_2=GFusion_model([pos1,pos2],[edge_index1,edge_index2],[edge_index_2rd_1,edge_index_2rd_2],\
[edx_2nd_1,edx_2nd_2],[batch1,batch2],[x1,x2],[is_source1,is_source2],args.edge_rep)
pm25_1,pm25_2=pm25_1[target_index1],pm25_2[target_index2]
if args.dataset=='syn':
pred=((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu())
else:
pred=F.relu((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu())
loss_weight= [weighted_fidelity_weight[i] for i in datasource]
loss1 = criterion(pred.reshape(-1, 1), Y.reshape(-1, 1),loss_weight)
"""
record predictions
"""
y_hat += list(pred.detach().numpy().reshape(-1))
y_true += list(Y.detach().numpy().reshape(-1))
loss=loss1
loss.backward()
                epochloss += loss.detach()
optimizer.step()
optimizer.zero_grad()
if args.fidelity_train:
optimizer2.step()
optimizer2.zero_grad()
reweight_fidelity()
return epochloss.item()/len(train_Loader),y_hat, y_true
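    # Evaluate a loader with frozen (detached) fidelity weights; returns the
    # mean loss plus flat lists of predictions and targets.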
def test(loader,GFusion_model,fidelity_weight):
        # Evaluation uses detached fidelity weights: the current softmax of the
        # learned weights, or [1, 0] under the single-high ablation.
        if not args.single_high:
            exped_f = [torch.exp(fidelity_weight[i].detach()) for i in range(num_of_fidelities)]
            fsum = sum(exped_f)
            weighted_fidelity_weight = [exped_f[i] / fsum for i in range(num_of_fidelities)]
        else:
            weighted_fidelity_weight = [1, 0]
        y_hat, y_true = [], []
        loss_total, pred_num = 0, 0
GFusion_model.eval()
for i,data in enumerate(loader):
if num_of_fidelities==2:
x1, pos1,edge_index1, batch1,target_index1,target1,is_source1 = data[0].x, data[0].pos,data[0].edge_index, data[0].batch,data[0].target_index,data[0].target,data[0].is_source
x2, pos2,edge_index2, batch2,target_index2,target2,is_source2 = data[1].x, data[1].pos,data[1].edge_index, data[1].batch,data[1].target_index,data[1].target,data[1].is_source
if args.dataset=='syn':
x1[:,1]=x1[:,1]+x1[:,2]
x1=x1[:,[0,1,3,4]]
x2[:,1]=x2[:,1]+x2[:,2]
x2=x2[:,[0,1,3,4]]
x1,pos1,target1,x2,pos2,target2=x1.to(torch.float32),pos1.to(torch.float32),target1.to(torch.float32),x2.to(torch.float32),pos2.to(torch.float32),target2.to(torch.float32)
                x2[x2[:, 0] > 6666, 0] = 6666  # clip outliers at the same cap used in training
datasource=data[0].datasource
                Y = target1
                assert torch.equal(target1, target2)
                Y[Y > 6666] = 6666
x1, pos1,edge_index1, batch1, target_index1,is_source1 = x1.to(device),pos1.to(device), edge_index1.to(device), batch1.to(device),target_index1.to(device),is_source1.to(device)
x2, pos2,edge_index2, batch2, target_index2,is_source2 = x2.to(device),pos2.to(device),edge_index2.to(device), batch2.to(device),target_index2.to(device),is_source2.to(device)
num_nodes1=x1.shape[0]
num_nodes2=x2.shape[0]
                edge_index_2rd_1, _, _, edx_2nd_1 = triplets(edge_index1, num_nodes1)
                edge_index_2rd_2, _, _, edx_2nd_2 = triplets(edge_index2, num_nodes2)
pm25_1,pm25_2=GFusion_model([pos1,pos2],[edge_index1,edge_index2],[edge_index_2rd_1,edge_index_2rd_2],\
[edx_2nd_1,edx_2nd_2],[batch1,batch2],[x1,x2],[is_source1,is_source2],args.edge_rep)
pm25_1,pm25_2=pm25_1[target_index1],pm25_2[target_index2]
with torch.no_grad():
if args.dataset=='syn':
pred=((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu())
else:
pred=F.relu((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu())
                    assert (datasource == 0).all()  # val/test batches are high-fidelity only
                    loss1 = criterion(pred.reshape(-1, 1), Y.reshape(-1, 1)) * weighted_fidelity_weight[0]
"""
record predictions
"""
y_hat += list(pred.detach().numpy().reshape(-1))
y_true += list(Y.detach().numpy().reshape(-1))
pred_num += len(Y.reshape(-1, 1))
loss=loss1
loss_total += loss.detach() * len(Y.reshape(-1, 1))
return loss_total/pred_num, y_hat, y_true
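    # Either restore a saved checkpoint or run the training loop below.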
    if args.loadmodel:
        # Timestamp of the checkpoint to restore; must match files produced by a
        # previous run under save/<dataset>/model.
        suffix = 'Oct31-11:50:30'
        GFusion_model.load_state_dict(
            torch.load(os.path.join("save", args.dataset, 'model', 'best_GFusion_model_' + suffix + '.pth')),
            strict=True)
        best_GFusion_model = copy.deepcopy(GFusion_model)
        # Restore the fidelity weights saved alongside the model, so the later
        # test() calls have a matching best_fidelity.
        with open(os.path.join("save", args.dataset, 'model', 'best_f_weight_' + suffix + ".pkl"), 'rb') as handle:
            best_fidelity = pickle.load(handle)
else:
        best_val_trigger = 1e3
        old_lr = 1e3
        # Single strftime call so the timestamp cannot straddle a second boundary.
        suffix = datetime.now().strftime("%h%d-%H:%M:%S")
if args.log:
writer = SummaryWriter(os.path.join(tensorboard_dir,suffix))
for epoch in range(args.nepoch):
if args.model in ['GFusion']: train_loss,y_hat, y_true=train(GFusion_model)
if args.log:
writer.add_scalar('loss/Train', train_loss, epoch)
if args.dataset in ['south',"north",'syn','flu']:
train_mae=mean_absolute_error(y_true, y_hat)
train_rmse = np.sqrt(mean_squared_error(y_true, y_hat))
if args.log:
writer.add_scalar('mae/Train', train_mae, epoch)
writer.add_scalar('rmse/Train', train_rmse, epoch)
print(( f"epoch[{epoch:d}] train_loss : {train_loss:.3f} train_mae : {train_mae:.3f} train_rmse : {train_rmse:.3f}" ))
            if args.model in ['GFusion'] and args.fidelity_train:
                print(f"fidelity weight: {fidelity_weight[0]:.3f}, {fidelity_weight[1]:.3f}")
                print(f"weighted_fidelity_weight: {weighted_fidelity_weight[0]:.3f}, {weighted_fidelity_weight[1]:.3f}")
if epoch % args.test_per_round == 0:
if args.model in ['GFusion']:
val_loss, yhat_val, ytrue_val = test(val_Loader,GFusion_model,fidelity_weight)
test_loss, yhat_test, ytrue_test = test(test_Loader,GFusion_model,fidelity_weight)
if args.log:
writer.add_scalar('loss/val', val_loss, epoch)
writer.add_scalar('loss/test', test_loss, epoch)
if args.dataset in ['south',"north",'syn','flu']:
val_mae=mean_absolute_error(ytrue_val, yhat_val)
val_rmse = np.sqrt(mean_squared_error(ytrue_val, yhat_val))
if args.log:
writer.add_scalar('mae/val', val_mae, epoch)
writer.add_scalar('rmse/val', val_rmse, epoch)
print(blue( f"epoch[{epoch:d}] val_mae : {val_mae:.3f} val_rmse : {val_rmse:.3f}" ))
test_mae = mean_absolute_error(ytrue_test, yhat_test)
test_rmse = np.sqrt(mean_squared_error(ytrue_test, yhat_test))
test_var=explained_variance_score(ytrue_test,yhat_test)
test_coefOfDetermination=r2_score(ytrue_test,yhat_test)
test_Pearsonr=measure_Pearsonr(np.array(yhat_test).reshape(-1, 1),np.array(ytrue_test).reshape(-1))[0]
if args.log:
writer.add_scalar('mae/test', test_mae, epoch)
writer.add_scalar('rmse/test', test_rmse, epoch)
print(blue( f"epoch[{epoch:d}] test_mae: {test_mae:.3f} test_rmse: {test_rmse:.3f} test_Pearsonr: {test_Pearsonr:.3f} test_coefOfDetermination: {test_coefOfDetermination:.3f}" ))
                    if args.model in ['GFusion'] and args.fidelity_train:
                        print(f"fidelity weight: {fidelity_weight[0]:.3f}, {fidelity_weight[1]:.3f}")
                        print(f"weighted_fidelity_weight: {weighted_fidelity_weight[0]:.3f}, {weighted_fidelity_weight[1]:.3f}")
val_trigger=val_mae
if val_trigger < best_val_trigger:
best_val_trigger = val_trigger
best_GFusion_model = copy.deepcopy(GFusion_model)
best_fidelity=copy.deepcopy(fidelity_weight)
best_info=[epoch,val_trigger]
"""
update lr when epoch≥30
"""
if epoch >= 30:
lr = scheduler.optimizer.param_groups[0]['lr']
if old_lr!=lr:
print(red('lr'), epoch, (lr), sep=', ')
old_lr=lr
scheduler.step(val_trigger)
if args.fidelity_train:
scheduler2.step(val_trigger)
val_loss, yhat_val, ytrue_val = test(val_Loader,best_GFusion_model,best_fidelity)
test_loss, yhat_test, ytrue_test = test(test_Loader,best_GFusion_model,best_fidelity)
if args.dataset in ['south',"north",'syn','flu']:
val_mae = mean_absolute_error(ytrue_val, yhat_val)
val_rmse=np.sqrt(mean_squared_error(ytrue_val,yhat_val))
val_var=explained_variance_score(ytrue_val,yhat_val)
print(blue( f"best_val val_mae: {val_mae:.3f} val_rmse: {val_rmse:.3f} val_var: {val_var:.3f}" ))
test_mae=mean_absolute_error(ytrue_test,yhat_test)
test_rmse=np.sqrt(mean_squared_error(ytrue_test,yhat_test))
test_var=explained_variance_score(ytrue_test,yhat_test)
test_coefOfDetermination=r2_score(ytrue_test,yhat_test)
test_Pearsonr=measure_Pearsonr(np.array(yhat_test).reshape(-1, 1),np.array(ytrue_test).reshape(-1))[0]
print(blue( f"best_test test_mae: {test_mae:.3f} test_rmse: {test_rmse:.3f} test_var: {test_var:.3f}" ))
if not args.loadmodel:
"""
save training info and best result
"""
result_file=os.path.join(info_dir, suffix)
with open(result_file, 'w') as f:
print(args.num_neighbors,args.nepoch,sep=' ',file=f)
print(f"fidelity weight: {best_fidelity[0]:.3f}, {best_fidelity[1]:.3f}",file=f)
print("Random Seed: ", Seed,file=f)
if args.dataset in ['south',"north",'syn','flu']:
print(f"MAE val : {val_mae:.3f}, Test : {test_mae:.3f}", file=f)
print(f"rmse val : {val_rmse:.3f}, Test : {test_rmse:.3f}", file=f)
print(f"var val : {val_var:.3f}, Test : {test_var:.3f}", file=f)
print(f"test_coefOfDetermination: {test_coefOfDetermination:.3f}, test_Pearsonr : {test_Pearsonr:.3f}", file=f)
print(f"Best info: {best_info}", file=f)
            for name, value in vars(args).items():
                print([name, value], file=f)
        with open(os.path.join(model_dir, 'best_f_weight' + "_" + suffix + ".pkl"), 'wb') as handle:
            pickle.dump(best_fidelity, handle)  # fidelity weights that match the saved best model
torch.save(best_GFusion_model.state_dict(), os.path.join(model_dir,'best_GFusion_model'+"_"+suffix+'.pth') )
print("done")
if __name__ == '__main__':
args = get_args()
    os.makedirs(args.save_dir, exist_ok=True)
    tensorboard_dir = os.path.join(args.save_dir, 'log')
    os.makedirs(tensorboard_dir, exist_ok=True)
    model_dir = os.path.join(args.save_dir, 'model')
    os.makedirs(model_dir, exist_ok=True)
    info_dir = os.path.join(args.save_dir, 'info')
    os.makedirs(info_dir, exist_ok=True)
Seed = args.man_seed if args.manualSeed else random.randint(1, 10000)
print("Random Seed: ", Seed)
random.seed(Seed)
torch.manual_seed(Seed)
np.random.seed(Seed)
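    # torch.backends.cudnn.deterministic is left at its default, so GPU runs
    # may still differ slightly between repetitions despite the seeds.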
    flag = 0  # kill switch read by main(); set non-zero to skip training entirely
if args.dataset in ['south',"north",'syn',"flu"]:
graphs1,graphs2=dataset.load_point(args.dataset,args.num_neighbors,[False,200,500])
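        # graphs1/graphs2 are the high- and low-fidelity graph lists; the
        # trailing [False, 200, 500] options are passed through to
        # dataset.load_point (their meaning is defined there).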
np.random.shuffle(graphs1)
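        # 60/20/20 train/val/test split over the shuffled high-fidelity graphs;
        # low-fidelity graphs are added to training only (unless --single_high).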
val_test_split = int(np.around( 2 / 10 * len(graphs1) ))
train_val_split = int(len(graphs1)-2*val_test_split)
if args.single_high:
train_graphs = graphs1[:train_val_split]
else:
train_graphs = graphs1[:train_val_split]+graphs2
val_graphs = graphs1[train_val_split:train_val_split+val_test_split]
test_graphs = graphs1[train_val_split+val_test_split:]
np.random.shuffle(train_graphs)
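        # The training list is shuffled once here; the loaders themselves keep
        # a fixed iteration order (shuffle defaults to False).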
train_Loader=DataLoader(train_graphs, batch_size=args.batchSize)
val_Loader=DataLoader(val_graphs, batch_size=args.batchSize)
test_Loader=DataLoader(test_graphs, batch_size=args.batchSize)
print(f"train_pair_num: {len(train_graphs)}, val_pair_num: {len(val_graphs)}, test_pair_num: {len(test_graphs)}")
else:
raise Exception('Dataset not recognized.')
main(args,train_Loader,val_Loader,test_Loader)
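
# Example invocation (a sketch; the values shown are the argparse defaults, and
# the script name "train.py" is assumed):
#   python train.py --dataset flu --num_neighbors 3 --batchSize 512 --nepoch 201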