import pickle
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import one_hot, softmax
import matplotlib.pyplot as plt
import random
import torch.utils.data as data

teacher_forcing_ratio = 0.5


def get_data_set(mode="train"):
    NT, AT = None, None
    with open(r'dataset/' + "NT.pkl", "rb") as f1:
        NT = pickle.load(f1)
    with open(r'dataset/' + "AT.pkl", "rb") as f2:
        AT = pickle.load(f2)

    def onehot_encode(char, vocab):
        # one-hot encode a single symbol, e.g. [0, 0, 1, 0, ..., 0];
        # the padding symbol 0 maps to the all-zero vector
        encoded = [0 for _ in range(len(vocab))]
        if char != 0:
            encoded[vocab.index(char)] = 1
        return encoded

    from read_graph import get_graph
    import networkx as nx
    G_new = get_graph()
    voc = list(G_new)
    with open('nodes.pkl', 'wb') as f:
        pickle.dump(voc, f, pickle.HIGHEST_PROTOCOL)
    voc = None
    with open('nodes.pkl', 'rb') as f:
        voc = pickle.load(f)
    voc.append(0)    # padding symbol
    voc.append('s')  # START
    voc.append('e')  # EOF
    total_word_count = len(voc)

    # trajectory labels
    samples = []
    labels = []
    if mode == "train":
        for tr in NT:
            samples.append(tr)
            labels.append(1)  # normal
    else:
        for tr in NT:
            samples.append(tr)
            labels.append(1)  # normal
        for tr in AT:
            samples.append(tr)
            labels.append(0)  # anomalous

    def padding(x, max_length):
        if len(x) > max_length:
            text = x[:max_length]
        else:
            text = x + [[0, 0]] * (max_length - len(x))
        return text

    # length of the longest trajectory
    max_len = 10
    for tr in samples:
        max_len = max(max_len, len(tr))

    # pad every trajectory to the maximum length
    samples_padded = []
    for tr in samples:
        tr = padding(tr, max_len)
        samples_padded.append(tr)

    # one-hot encoding
    samples_one_hot = []
    samples_index = []
    for tr in samples_padded:
        tr_rep = []
        tr_rep_index = []
        for pt in tr:
            spatial = onehot_encode(pt[0], voc)
            temporal = int(pt[1])  # temporal component (currently unused)
            tr_rep.append(spatial)
            tr_rep_index.append(voc.index(pt[0]))
        samples_one_hot.append(tr_rep)
        samples_index.append(tr_rep_index)

    sampletensor = torch.Tensor(samples_one_hot)
    sampletensor_index = torch.Tensor(samples_index)
    labeltensor = torch.Tensor(labels)
    # print("sampletensor.shape", sampletensor.shape)
    # print("labeltensor.shape", labeltensor.shape)
    return sampletensor, sampletensor_index, labeltensor, max_len


if torch.cuda.is_available():
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda:0")
    torch.cuda.set_device(0)
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    print("Working on GPU")
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")

import torch.nn as nn
# from VAE import AE, RNN
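
# ---------------------------------------------------------------------------
# The training loop below is commented out and imports `RNN` from VAE.py,
# which is not part of this file. The class sketched here is only a minimal,
# hypothetical stand-in illustrating the interface that loop assumes:
# forward(x, target, mode, hidden) -> (xhat, kld, h_hat), where `xhat` holds
# per-step logits over the vocabulary, `kld` is the VAE KL-divergence term,
# and `h_hat` is a latent summary later accumulated into h_hat_avg. The real
# model in VAE.py may differ in architecture and arguments.
# ---------------------------------------------------------------------------
class LSTMVAESketch(nn.Module):
    def __init__(self, input_size, hidden_size, batch_size, maxlen, latent_size=16):
        super().__init__()
        self.encoder = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.to_mu = nn.Linear(hidden_size, latent_size)
        self.to_logvar = nn.Linear(hidden_size, latent_size)
        self.decoder = nn.LSTM(latent_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, input_size)
        self.maxlen = maxlen

    def forward(self, x, target, mode, hidden):
        # `target`, `mode` and `hidden` are accepted only to mirror the assumed
        # call signature net(x, x, "train", None); this sketch ignores them.
        _, (h_n, _) = self.encoder(x)            # h_n: (1, batch, hidden_size)
        h_last = h_n[-1]                         # (batch, hidden_size)
        mu, logvar = self.to_mu(h_last), self.to_logvar(h_last)
        # reparameterisation trick
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        # decode by feeding the latent code at every time step
        z_seq = z.unsqueeze(1).repeat(1, self.maxlen, 1)
        dec_out, _ = self.decoder(z_seq)
        xhat = self.out(dec_out)                 # (batch, maxlen, vocab) logits
        kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return xhat, kld, h_last.sum(dim=0)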

if __name__ == '__main__':
    sampletensor, sampletensor_index, labeltensor, max_len = get_data_set("train")

    batch_size = 2
    train_set = data.TensorDataset(sampletensor, sampletensor_index, labeltensor)
    train_iter = data.DataLoader(train_set, batch_size, shuffle=False, drop_last=False)

    # rnn = RNN(input_size=2694, hidden_size=64, batch_size=2, maxlen=max_len)
    # loss = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adamax(rnn.parameters(), lr=1e-2)
    #
    # net = rnn.to(device)
    # num_epochs = 120
    #
    # h_hat_avg = None
    #
    # from tqdm import tqdm
    # for epoch in tqdm(range(num_epochs)):
    #     epoch_total_loss = 0
    #     for x, x_label, y in train_iter:
    #         # RNN
    #         xhat, kld, h_hat = net(x, x, "train", None)
    #         # print(xhat.shape)
    #         # print(x_label.shape)
    #         len_all = (x_label.shape[0]) * (x_label.shape[1])
    #         xhat = xhat.reshape(len_all, -1)
    #         x_label = x_label.reshape(len_all).long().to(device)
    #         # print(x_label)
    #         # print("xhat", xhat.shape)
    #         # print("x_label", x_label.shape)
    #         l = loss(xhat, x_label)
    #         # print("reconstruction loss:", l, "kld loss:", kld)
    #         total_loss = l + kld
    #         epoch_total_loss += total_loss
    #         optimizer.zero_grad()
    #         total_loss.backward()
    #         optimizer.step()
    #         if epoch == num_epochs - 1:
    #             if h_hat_avg is None:
    #                 h_hat_avg = h_hat / torch.full(h_hat.shape, len(sampletensor)).to(device)
    #             else:
    #                 h_hat_avg += h_hat / torch.full(h_hat.shape, len(sampletensor)).to(device)
    #             print(">>> h_hat_avg", h_hat_avg.shape)
    #     print(" epoch_total_loss = ", epoch_total_loss)
    #
    # print("training ends")
    # torch.save(net, "LSTM-VAE.pth")
    # torch.save(h_hat_avg, 'h_hat_avg.pt')
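
    # -----------------------------------------------------------------------
    # Hypothetical follow-up (an assumption, not part of the original script):
    # once the training loop above is re-enabled and has run, the saved
    # artifacts could be reloaded for evaluation roughly like this. Only the
    # file names come from the torch.save calls above; the rest is
    # illustrative.
    # -----------------------------------------------------------------------
    # net = torch.load("LSTM-VAE.pth", map_location=device)
    # h_hat_avg = torch.load("h_hat_avg.pt", map_location=device)
    # net.eval()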