# asammoud
# Re-add large CSVs using Git LFS
# b265364
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from util.time import *
from util.env import *
from sklearn.metrics import mean_squared_error
from pipeline.test import *
import torch.nn.functional as F
import numpy as np
from pipeline.evaluate import get_best_performance_data, get_val_performance_data, get_full_err_scores
from sklearn.metrics import precision_score, recall_score, roc_auc_score, f1_score
from torch.utils.data import DataLoader, random_split, Subset
from scipy.stats import iqr
def loss_func(y_pred, y_true):
    """Return the mean-reduced mean-squared error between two tensors.

    Args:
        y_pred: Predicted values.
        y_true: Target values, passed to ``F.mse_loss`` as the target.

    Returns:
        The result of ``F.mse_loss(y_pred, y_true, reduction='mean')``.
    """
    return F.mse_loss(y_pred, y_true, reduction='mean')
def train(model = None, save_path = '', config={}, train_dataloader=None, val_dataloader=None, feature_map={}, test_dataloader=None, test_dataset=None, dataset_name='swat', train_dataset=None):
    """Train ``model`` with Adam on the ``loss_func`` objective, checkpointing the best weights.

    For each epoch the model is put in training mode, every batch of
    ``train_dataloader`` is run through a forward/backward/optimizer step,
    and a one-line summary is printed. After the epoch:

    * If ``val_dataloader`` is given, ``test(model, val_dataloader)`` is
      called and the first element of its result is treated as the
      validation loss. The model's ``state_dict`` is saved to ``save_path``
      whenever that loss improves, and training stops after 15 consecutive
      epochs without improvement.
    * Otherwise the ``state_dict`` is saved whenever the epoch's accumulated
      training loss improves; there is no early stopping in this mode.

    Args:
        model: Module to train; called as ``model(x, edge_index)``.
        save_path: Destination passed to ``torch.save`` for the best weights.
        config: Mapping that must provide ``'decay'`` (Adam weight decay) and
            ``'epoch'`` (number of epochs). Other keys are not read here.
        train_dataloader: Iterable yielding
            ``(x, labels, attack_labels, edge_index)`` batches; the
            ``attack_labels`` element is not used during training.
        val_dataloader: Optional loader used for model selection / early
            stopping as described above.
        feature_map, test_dataloader, test_dataset, dataset_name,
        train_dataset: Accepted for interface compatibility; not used by
            this function.

    Returns:
        list[float]: The loss of every training batch, in iteration order
        across all epochs that were run.
    """
    # NOTE: the mutable defaults in the signature are never mutated here; they
    # are kept unchanged so existing callers see the exact same interface.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=config['decay'])
    device = get_device()

    epoch = config['epoch']
    early_stop_win = 15  # consecutive non-improving validation epochs tolerated
    min_loss = 1e+8
    stop_improve_count = 0
    train_loss_list = []
    dataloader = train_dataloader

    model.train()

    for i_epoch in range(epoch):
        acu_loss = 0
        num_batches = 0
        # Re-enter training mode every epoch: the validation call below may
        # have switched the model to a different mode (not visible here).
        model.train()

        for x, labels, _attack_labels, edge_index in dataloader:
            x, labels, edge_index = [item.float().to(device) for item in (x, labels, edge_index)]

            optimizer.zero_grad()
            out = model(x, edge_index).float().to(device)
            loss = loss_func(out, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once; every .item() call copies the value to host.
            batch_loss = loss.item()
            train_loss_list.append(batch_loss)
            acu_loss += batch_loss
            num_batches += 1

        # Average over the batches actually seen so that an empty loader (or
        # one without __len__) cannot raise while reporting.
        mean_loss = acu_loss / num_batches if num_batches else 0.0

        # each epoch
        print('epoch ({} / {}) (Loss:{:.8f}, ACU_loss:{:.8f})'.format(
            i_epoch, epoch,
            mean_loss, acu_loss), flush=True
        )

        # use val dataset to judge
        if val_dataloader is not None:
            val_loss, val_result = test(model, val_dataloader)

            if val_loss < min_loss:
                torch.save(model.state_dict(), save_path)
                min_loss = val_loss
                stop_improve_count = 0
            else:
                stop_improve_count += 1

            if stop_improve_count >= early_stop_win:
                break
        else:
            # No validation data: fall back to the accumulated training loss
            # as the model-selection criterion.
            if acu_loss < min_loss:
                torch.save(model.state_dict(), save_path)
                min_loss = acu_loss

    return train_loss_list