|
|
import numpy as np
|
|
|
import torch
|
|
|
import matplotlib.pyplot as plt
|
|
|
import torch.nn as nn
|
|
|
import time
|
|
|
from util.time import *
|
|
|
from util.env import *
|
|
|
from sklearn.metrics import mean_squared_error
|
|
|
from pipeline.test import *
|
|
|
import torch.nn.functional as F
|
|
|
import numpy as np
|
|
|
from pipeline.evaluate import get_best_performance_data, get_val_performance_data, get_full_err_scores
|
|
|
from sklearn.metrics import precision_score, recall_score, roc_auc_score, f1_score
|
|
|
from torch.utils.data import DataLoader, random_split, Subset
|
|
|
from scipy.stats import iqr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def loss_func(y_pred, y_true):
    """Return the mean squared error between ``y_pred`` and ``y_true``.

    Args:
        y_pred: Tensor of model predictions.
        y_true: Tensor of target values, compared element-wise with ``y_pred``.

    Returns:
        The element-wise squared error averaged over all elements
        (``reduction='mean'``).
    """
    return F.mse_loss(input=y_pred, target=y_true, reduction='mean')
|
|
|
|
|
|
|
|
|
|
|
|
def train(model = None, save_path = '', config={}, train_dataloader=None, val_dataloader=None, feature_map={}, test_dataloader=None, test_dataset=None, dataset_name='swat', train_dataset=None):
    """Train ``model`` with Adam on an MSE objective, checkpointing the best weights.

    For every epoch the model is run over ``train_dataloader``; each batch is
    unpacked as ``(x, labels, attack_labels, edge_index)`` and the loss is
    ``loss_func(model(x, edge_index), labels)``.  After each epoch a progress
    line is printed and the model's ``state_dict`` is written to ``save_path``
    whenever the tracked loss improves:

    * with ``val_dataloader``: the tracked loss is the first value returned by
      ``test(model, val_dataloader)``; training stops early after 15
      consecutive epochs without improvement;
    * without ``val_dataloader``: the tracked loss is the summed training loss
      of the epoch, and no early stopping is applied.

    Args:
        model: Module called as ``model(x, edge_index)``.
        save_path: Destination passed to ``torch.save`` for the best weights.
        config: Mapping providing ``'decay'`` (Adam weight decay) and
            ``'epoch'`` (number of epochs).  Never mutated here.
        train_dataloader: Sized iterable of training batches.
        val_dataloader: Optional validation batches; enables early stopping.
        feature_map, test_dataloader, test_dataset, dataset_name,
        train_dataset: Accepted for call-site compatibility; not used here.

    Returns:
        List with the scalar loss of every training batch, in order.

    Raises:
        ValueError: If ``train_dataloader`` contains no batches.
    """
    learning_rate = 0.001
    early_stop_win = 15  # epochs without validation improvement before stopping

    dataloader = train_dataloader

    # Fail fast: an empty loader would otherwise train nothing and then crash
    # with a ZeroDivisionError while computing the epoch's mean loss.
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('train_dataloader contains no batches; nothing to train on')

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=config['decay'])

    # NOTE(review): get_device() and test() are not defined in this block;
    # presumably they come from the util.env / pipeline.test wildcard imports.
    device = get_device()

    epoch = config['epoch']

    train_loss_list = []
    min_loss = 1e+8
    stop_improve_count = 0

    for i_epoch in range(epoch):
        acu_loss = 0

        # Re-enter training mode every epoch; the validation pass below may
        # have switched the model to eval mode (assumption - confirm in test()).
        model.train()

        for x, labels, _attack_labels, edge_index in dataloader:
            x, labels, edge_index = [item.float().to(device) for item in [x, labels, edge_index]]

            optimizer.zero_grad()
            out = model(x, edge_index).float().to(device)
            loss = loss_func(out, labels)

            loss.backward()
            optimizer.step()

            # Read the scalar once; each .item() call copies the value to host.
            batch_loss = loss.item()
            train_loss_list.append(batch_loss)
            acu_loss += batch_loss

        # Per-epoch progress: mean batch loss and accumulated loss.
        print('epoch ({} / {}) (Loss:{:.8f}, ACU_loss:{:.8f})'.format(
            i_epoch, epoch,
            acu_loss/num_batches, acu_loss), flush=True
        )

        if val_dataloader is not None:
            val_loss, val_result = test(model, val_dataloader)

            if val_loss < min_loss:
                torch.save(model.state_dict(), save_path)

                min_loss = val_loss
                stop_improve_count = 0
            else:
                stop_improve_count += 1

            if stop_improve_count >= early_stop_win:
                break
        else:
            # No validation data: checkpoint on the best summed training loss.
            if acu_loss < min_loss:
                torch.save(model.state_dict(), save_path)
                min_loss = acu_loss

    return train_loss_list
|
|
|
|